3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
8 from caffe2.python
import core, context
9 from caffe2.python.task
import Task, TaskGroup
15 Scope-driven mechanism for building nets, loops and conditional blocks. 17 from caffe2.python.net_builder import NetBuilder, ops 18 with NetBuilder() as nb: 22 ops.stop_if(ops.LE([c, ops.Const(0)])) 23 ops.Add([c, ops.Const(-1)], [c]) 24 with ops.If(ops.GE([c, ops.Const(3)])): 25 ops.Add([d, ops.Const(10)]) 28 step = core.to_execution_step(nb) 30 def __init__(self, name=None, _stop_blob_required=False,
31 _stop_blob=None, _fullname=None):
32 nb = NetBuilder.current(required=
False)
33 assert not _fullname
or not name,
'Cannot set both _fullname and name' 34 self.
name = _fullname
or '/'.join(filter(
lambda x: x, (
35 nb.name
if nb
else None, name)))
44 Returns the BlobReference to the stop_blob of this NetBuilder. 45 If one is not yet available, creates one. 46 This function assumes that the stop_blob() will be used immediatelly 47 in the current net, so it doesn't initialize it if the current net is 48 the first of the builder. 53 net.NextName(
'stop_blob'), net=net)
59 def stop_if(self, blob):
63 def _assert_mutable(self):
65 'This NetBuilder (%s) has been built already.' % self.
name)
76 def current_net(self, name=None):
84 if hasattr(child,
'freeze'):
93 def __exit__(self, etype, *args):
98 'This NetBuilder (%s) requires a stop condition ' % self.
name +
99 'to be set with `stop` or `stop_if`')
102 return self.
name or 'Un-named NetBuilder' 107 Operations to be used in the context of a NetBuilder. 109 def net(self, net=None, name=None):
111 Retrieves the current net, or add a new net to the builder. 113 net: If provided, add the given net to the active builder. 114 Else, returns the current Net or creates a new one as needed. 115 name: if provided, creates a new Net with given name and makes 116 it the new current net of the active builder. Cannot 117 be provided if net is provided. 119 assert name
is None or net
is None, (
120 'Cannot provide both `net` and `name`.')
122 NetBuilder.current().add(net)
124 return NetBuilder.current().current_net(name=name)
128 Adds an operator call to the currently active Net. 130 if op_type.startswith(
'__'):
131 raise AttributeError()
133 if NetBuilder.current(required=
False)
is None:
134 raise AttributeError(
'No active NetBuilder.')
135 return getattr(self.
net(), op_type)
139 Creates a local task group which will execute as the next step of 140 the current NetBuilder. 142 from caffe2.python
import task
143 group = NetBuilder.current()
152 Stop execution of the current execution step. 157 In the example, 'b' will never be printed. 159 return self.
stop_if(ops.Const(
True))
163 Stop execution of the current execution step if the 164 condition `blob` is met. 167 ops.stop_if(ops.LE([x, ops.Const(0)])) 169 In the example, 'b' will only be printed if the value of scalar 170 tensor 'x' lower or equal to 0. 172 return NetBuilder.current().
stop_if(blob)
174 def loop(self, iters=None, name=None):
176 Creates a NetBuilder that will execute in a loop as the next step of 177 the current NetBuilder. If `iters` is provided, the loop will execute 178 for `iters` iterations and then stop. `iters` can be a constant or a 179 BlobReference. If `iters` is not provided, the loop will execute 180 until `ops.stop` or `ops.stop_if` is called. 184 ops.stop_if(ops.LE([a, ops.Const(0)])) 186 ops.Add([a, ops.Const(-1)], [a]) 187 Above, 'a' will be printed 5 times, with values 5 to 1. 189 with ops.loop(10) as loop: 190 ops.LogInfo(loop.iter()) 191 This will print the numbers from 0 to 9. 193 x = ops.Add([ops.Const(10), ops.Const(10)]) 194 with ops.loop(x) as loop: 195 ops.LogInfo(loop.iter()) 196 This will print the numbers from 0 to 19. 198 return NetBuilder.current().add(
_Loop(iters, name=name))
202 Creates a NetBuilder that will execute once as the next step of the 203 current NetBuilder. After execution, a bool tensor will indicate 204 whether the inner execution was halted with `stop` or `stop_if`. 207 with ops.stop_guard() as sg1: 209 ops.Print(ops.Const('did not stop')) 211 with ops.stop_guard() as sg2: 213 ops.Print(ops.Const('did not stop')) 214 ops.Print(sg1.has_stopped(), []) 215 ops.Print(sg2.has_stopped(), []) 216 In the example, 'did not stop' will be printed once, 217 followed by True and False. 219 return NetBuilder.current().add(
220 _StopGuard(has_stopped_blob=has_stopped_blob, name=name))
222 def If(self, cond, name=None):
224 Creates a NetBuilder that will execute once as the next step of the 225 current NetBuilder if the blob `cond` is True. 227 with ops.If(ops.Const(True)): 228 ops.Print(ops.Const('Will print')) 229 with ops.If(ops.Const(False)): 230 ops.Print(ops.Const('Wont print')) 231 The example will print 'Will print' once. 233 return NetBuilder.current().add(
_RunIf(cond, name=name))
237 Defines operations that will be executed once at task startup. 238 Useful when implementing processors, that don't have access to the Task 241 def my_processor(rec): 242 with ops.task_init(): 246 ops.Add(rec[0](), zero), ops.Add(rec[1](), two)) 249 self.
net().add_attribute(Task.TASK_SETUP, setup)
254 Define operations to be executed at task shutdown. 255 Useful when implementing processors, that don't have access to the Task 258 def read_queue(queue): 259 with ops.task_exit(): 260 queue.close(ops.net()) 261 return queue.read(ops.net()) 264 self.
net().add_attribute(Task.TASK_SETUP, setup)
269 Similar to `task_init`, but executes at TaskGroup's startup instead, 270 before any task of the group starts executing. 273 self.
net().add_attribute(TaskGroup.LOCAL_SETUP, setup)
278 Similar to `task_init`, but executes at TaskGroup's exit instead, 279 after all tasks of the group finished execution. 282 self.
net().add_attribute(TaskGroup.LOCAL_SETUP, setup)
287 Define operations to be executed at every time interval from 288 task start-up to finish. These operations are guaranteed to 289 execute at least once after all other operations of the task are 293 with ops.task_reporter(interval_ms=10000): 294 ops.LogInfo('10s elapsed') 300 Similar to task_report, but operations defined within this block 301 will run repeatedly for as long as any of the tasks in the current 302 TaskGroup have not finished. 311 def __init__(self, interval_ms, net=None, name=None):
312 NetBuilder.__init__(self, name)
316 def __exit__(self, etype, *args):
321 self.
_net.add_attribute(Task.REPORT_STEP, step)
323 TaskGroup.current().report_step(
325 NetBuilder.__exit__(self, etype, *args)
332 def __init__(self, type, name=None):
333 NetBuilder.__init__(self, name)
336 def setup(self, net):
337 if self.
type == _SetupBuilder.INIT:
341 if self.
type == _SetupBuilder.EXIT:
346 def __init__(self, name=None):
347 NetBuilder.__init__(self, name)
349 def __exit__(self, etype, *args):
350 if etype
is None and self.
_stop_blob is not None:
352 NetBuilder.__exit__(self, etype, *args)
356 def __init__(self, has_stopped_blob=None, name=None):
357 _RunOnce.__init__(self, name)
362 r = _RunOnce.__enter__(self)
366 def __exit__(self, etype, *args):
369 ops.Const(
False, blob_out=self.
_stopped)
370 _RunOnce.__exit__(self, etype, *args)
374 Return a blob that will be set to scalar bool `True` after 375 this net builder ran, iff it was halted early. 377 assert self.
_ran,
'Context not used yet.' 382 def __init__(self, iters=None, name=None):
383 NetBuilder.__init__(self, name, _stop_blob_required=
True)
384 if iters
is not None:
385 self.
_inc = ops.Const(1)
386 self.
_iter = ops.Const(0)
389 else ops.Const(iters))
395 'This loop does not have a number of iterations.')
396 assert self.
_iter is not None, (
397 'iter() must be called from inside the loop context')
401 builder = NetBuilder.__enter__(self)
406 def __exit__(self, type, *args):
407 if type
is None and self.
_num_iters is not None:
409 NetBuilder.__exit__(self, type, *args)
413 def __init__(self, cond_blob=None, name=None, _already_ran=None):
414 _RunOnce.__init__(self, name)
415 assert cond_blob
or _already_ran
417 if _already_ran
is None:
422 self.
_else_blob = _already_ran
if cond_blob
is None else (
423 ops.Or([_already_ran, ops.Not(cond_blob)]))
426 r = _RunOnce.__enter__(self)
431 def Elif(self, cond, name=None):
432 assert not self.
_is_else,
'Else not allowed for an Else.' 433 return NetBuilder.current().add(
_RunIf(
436 def Else(self, name=None):
437 assert not self.
_is_else,
'Elif not allowed for an Else.' 438 return NetBuilder.current().add(
def task_reporter(self, interval_ms=1000, name=None)
def loop(self, iters=None, name=None)
def local_reporter(self, interval_ms=1000, name=None)
def __getattr__(self, op_type)
def net(self, net=None, name=None)
def _assert_mutable(self)
def If(self, cond, name=None)
def to_execution_step(step_or_nets, default_name=None)
def current_net(self, name=None)
def stop_guard(self, has_stopped_blob=None, name=None)