Caffe2 - Python API
A deep learning, cross platform ML framework
net_builder.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core, context
9 from caffe2.python.task import Task, TaskGroup
10 
11 
# NOTE(review): `NetBuilder.current()` and `NetBuilder.__enter__` are used
# throughout this module but are not defined in the class body below;
# presumably `context.define_context` provides them -- confirm against
# caffe2.python.context.
@context.define_context()
class NetBuilder(object):
    """
    Scope-driven mechanism for building nets, loops and conditional blocks.
    Example:
        from caffe2.python.net_builder import NetBuilder, ops
        with NetBuilder() as nb:
            c = ops.Const(5)
            d = ops.Const(0)
            with ops.loop():
                ops.stop_if(ops.LE([c, ops.Const(0)]))
                ops.Add([c, ops.Const(-1)], [c])
                with ops.If(ops.GE([c, ops.Const(3)])):
                    ops.Add([d, ops.Const(10)], [d])
            ops.Print(c, [])
            ops.Print(d, [])
        step = core.to_execution_step(nb)
    """
    def __init__(self, name=None, _stop_blob_required=False,
                 _stop_blob=None, _fullname=None):
        """
        Args:
            name: optional name; it is appended to the enclosing builder's
                name, separated by '/'.
            _stop_blob_required: internal; if True, exiting the builder
                cleanly without a stop condition fails an assertion.
            _stop_blob: internal; pre-existing stop blob to use.
            _fullname: internal; name used verbatim instead of deriving it
                from the enclosing builder. Mutually exclusive with `name`.
        """
        parent = NetBuilder.current(required=False)
        assert not _fullname or not name, 'Cannot set both _fullname and name'
        # filter(None, ...) drops a missing parent name and/or a missing
        # `name`, so unnamed builders do not create empty path segments.
        self.name = _fullname or '/'.join(filter(None, (
            parent.name if parent else None, name)))
        self._frozen = False
        self._current_net = None
        self._children = []
        self._stop_blob = _stop_blob
        self._stop_blob_required = _stop_blob_required

    def stop_blob(self):
        """
        Returns the BlobReference to the stop_blob of this NetBuilder.
        If one is not yet available, creates one.
        This function assumes that the stop_blob() will be used immediately
        in the current net, so it doesn't initialize it if the current net is
        the first of the builder.
        """
        if self._stop_blob is None:
            net = self.current_net()
            self._stop_blob = core.BlobReference(
                net.NextName('stop_blob'), net=net)
            if self._current_net is not self._children[0]:
                # Earlier children would run before the blob is ever
                # written, so prepend a net that initializes it to False.
                self._children.insert(0, core.Net('stop_blob_init'))
                self._children[0].Const(False, blob_out=self._stop_blob)
        return self._stop_blob

    def stop_if(self, blob):
        """
        Copy `blob` into this builder's stop blob, then close the current net
        so that ops added afterwards go into a new net.
        """
        # Write into this builder's own net rather than through the global
        # `ops`, which targets whichever builder happens to be active.
        net = self.current_net()
        net.Copy(blob, self.stop_blob())
        # NOTE(review): presumably the stop condition is evaluated between
        # nets, hence the fresh net for subsequent ops -- confirm.
        self._current_net = None

    def _assert_mutable(self):
        """Assert that this builder has not been frozen yet."""
        assert not self._frozen, (
            'This NetBuilder (%s) has been built already.' % self.name)

    def add(self, child):
        """
        Append `child` (a Net, a nested builder, a TaskGroup, ...) to this
        builder and return it. A Net child becomes the current net; any other
        child makes the next op start a new net.
        """
        self._assert_mutable()
        self._current_net = None
        self._children.append(child)
        # TODO: check it's not a dag net.
        if isinstance(child, core.Net):
            self._current_net = child
        return child

    def current_net(self, name=None):
        """
        Return the net ops are currently added to. A new Net is created and
        added when there is no current net, or whenever `name` is given.
        """
        self._assert_mutable()
        if self._current_net is None or name is not None:
            self.add(core.Net(name))
        return self._current_net

    def freeze(self):
        """
        Freeze this builder and every child that has a `freeze` method.
        No more children can be added afterwards.
        """
        for child in self._children:
            if hasattr(child, 'freeze'):
                child.freeze()
        self._current_net = None
        self._frozen = True

    def get(self):
        """Freeze the builder and return its list of children."""
        self.freeze()
        return self._children

    def __exit__(self, etype, *args):
        """
        Freeze the builder. On a clean exit, also check that a stop condition
        was set if this builder requires one.
        """
        self.freeze()
        if etype is not None:
            return
        assert (not self._stop_blob_required) or self._stop_blob is not None, (
            'This NetBuilder (%s) requires a stop condition ' % self.name +
            'to be set with `stop` or `stop_if`')

    def __str__(self):
        return self.name or 'Un-named NetBuilder'
103 
104 
class Operations(object):
    """
    Operations to be used in the context of a NetBuilder.

    Besides the helpers defined below, any other attribute (e.g. `ops.Const`,
    `ops.Add`) is looked up on the current Net of the active NetBuilder, so
    calling it adds an operator of that type to that net.
    """
    def net(self, net=None, name=None):
        """
        Retrieve the current net, or add a new net to the builder.
        Args:
            net:   If provided, add the given net to the active builder.
                   Else, returns the current Net or creates a new one as needed.
            name:  if provided, creates a new Net with given name and makes
                   it the new current net of the active builder. Cannot
                   be provided if net is provided.
        """
        assert name is None or net is None, (
            'Cannot provide both `net` and `name`.')
        if net is not None:
            NetBuilder.current().add(net)
            return net
        return NetBuilder.current().current_net(name=name)

    def __getattr__(self, op_type):
        """
        Adds an operator call to the currently active Net.
        """
        if op_type.startswith('__'):
            # Never treat dunder lookups (copy, pickle, ...) as operators.
            raise AttributeError(op_type)
        # We want hasattr to work properly even if no context is active.
        if NetBuilder.current(required=False) is None:
            raise AttributeError('No active NetBuilder.')
        return getattr(self.net(), op_type)

    def task_group(self):
        """
        Creates a local task group which will execute as the next step of
        the current NetBuilder.
        """
        from caffe2.python import task
        builder = NetBuilder.current()
        with task.Cluster():
            with task.Node('local'):
                tg = task.TaskGroup()
                builder.add(tg)
                return tg

    def stop(self):
        """
        Stop execution of the current execution step.
        Example:
            ops.Print(a, 0)
            ops.stop()
            ops.Print(b, 0)
        In the example, 'b' will never be printed.
        """
        return self.stop_if(self.Const(True))

    def stop_if(self, blob):
        """
        Stop execution of the current execution step if the
        condition `blob` is met.
        Example:
            ops.Print(a, 0)
            ops.stop_if(ops.LE([x, ops.Const(0)]))
            ops.Print(b, 0)
        In the example, 'b' will only be printed if the value of scalar
        tensor 'x' is greater than 0.
        """
        return NetBuilder.current().stop_if(blob)

    def loop(self, iters=None, name=None):
        """
        Creates a NetBuilder that will execute in a loop as the next step of
        the current NetBuilder. If `iters` is provided, the loop will execute
        for `iters` iterations and then stop. `iters` can be a constant or a
        BlobReference. If `iters` is not provided, the loop will execute
        until `ops.stop` or `ops.stop_if` is called.
        Examples:
            a = ops.Const(5)
            with ops.loop():
                ops.stop_if(ops.LE([a, ops.Const(0)]))
                ops.Print(a, 0)
                ops.Add([a, ops.Const(-1)], [a])
        Above, 'a' will be printed 5 times, with values 5 to 1.

            with ops.loop(10) as loop:
                ops.LogInfo(loop.iter())
        This will print the numbers from 0 to 9.

            x = ops.Add([ops.Const(10), ops.Const(10)])
            with ops.loop(x) as loop:
                ops.LogInfo(loop.iter())
        This will print the numbers from 0 to 19.
        """
        return NetBuilder.current().add(_Loop(iters, name=name))

    def stop_guard(self, has_stopped_blob=None, name=None):
        """
        Creates a NetBuilder that will execute once as the next step of the
        current NetBuilder. After execution, a bool tensor will indicate
        whether the inner execution was halted with `stop` or `stop_if`.
        Example:
            a = ops.Const(True)
            with ops.stop_guard() as sg1:
                ops.stop_if(a)
                ops.Print(ops.Const('did not stop'))
            b = ops.Const(False)
            with ops.stop_guard() as sg2:
                ops.stop_if(b)
                ops.Print(ops.Const('did not stop'))
            ops.Print(sg1.has_stopped(), [])
            ops.Print(sg2.has_stopped(), [])
        In the example, 'did not stop' will be printed once,
        followed by True and False.
        """
        return NetBuilder.current().add(
            _StopGuard(has_stopped_blob=has_stopped_blob, name=name))

    def If(self, cond, name=None):
        """
        Creates a NetBuilder that will execute once as the next step of the
        current NetBuilder if the blob `cond` is True.
        Example:
            with ops.If(ops.Const(True)):
                ops.Print(ops.Const('Will print'))
            with ops.If(ops.Const(False)):
                ops.Print(ops.Const('Wont print'))
        The example will print 'Will print' once.
        """
        return NetBuilder.current().add(_RunIf(cond, name=name))

    def _add_setup(self, setup_type, attribute):
        """
        Create a `_SetupBuilder` of `setup_type` and attach it to the current
        net under `attribute`; shared by the task/local init/exit helpers.
        """
        setup = _SetupBuilder(setup_type)
        self.net().add_attribute(attribute, setup)
        return setup

    def task_init(self):
        """
        Defines operations that will be executed once at task startup.
        Useful when implementing processors, that don't have access to the Task
        top-level structure.
        Example:
            def my_processor(rec):
                with ops.task_init():
                    one = ops.Const(1)
                    two = ops.Const(2)
                return Tuple(
                    ops.Add(rec[0](), one), ops.Add(rec[1](), two))
        """
        return self._add_setup(_SetupBuilder.INIT, Task.TASK_SETUP)

    def task_exit(self):
        """
        Define operations to be executed at task shutdown.
        Useful when implementing processors, that don't have access to the Task
        top-level structure.
        Example:
            def read_queue(queue):
                with ops.task_exit():
                    queue.close(ops.net())
                return queue.read(ops.net())
        """
        return self._add_setup(_SetupBuilder.EXIT, Task.TASK_SETUP)

    def local_init(self):
        """
        Similar to `task_init`, but executes at TaskGroup's startup instead,
        before any task of the group starts executing.
        """
        return self._add_setup(_SetupBuilder.INIT, TaskGroup.LOCAL_SETUP)

    def local_exit(self):
        """
        Similar to `task_exit`, but executes at TaskGroup's exit instead,
        after all tasks of the group finished execution.
        """
        return self._add_setup(_SetupBuilder.EXIT, TaskGroup.LOCAL_SETUP)

    def task_reporter(self, interval_ms=1000, name=None):
        """
        Define operations to be executed at every time interval from
        task start-up to finish. These operations are guaranteed to
        execute at least once after all other operations of the task are
        finished.

        Example:
            with ops.task_reporter(interval_ms=10000):
                ops.LogInfo('10s elapsed')
        """
        return _ReporterBuilder(interval_ms, net=self.net(), name=name)

    def local_reporter(self, interval_ms=1000, name=None):
        """
        Similar to `task_reporter`, but operations defined within this block
        will run repeatedly for as long as any of the tasks in the current
        TaskGroup have not finished.
        """
        return _ReporterBuilder(interval_ms, name=name)


# Module-level entry point: `ops.<OpType>(...)` adds an operator to the
# current net of the active NetBuilder.
ops = Operations()
308 
309 
311  def __init__(self, interval_ms, net=None, name=None):
312  NetBuilder.__init__(self, name)
313  self._net = net
314  self.interval_ms = interval_ms
315 
316  def __exit__(self, etype, *args):
317  if etype is None:
318  step = core.to_execution_step(self)
319  step.RunEveryMillis(self.interval_ms)
320  if self._net:
321  self._net.add_attribute(Task.REPORT_STEP, step)
322  else:
323  TaskGroup.current().report_step(
324  step, interval_ms=self.interval_ms)
325  NetBuilder.__exit__(self, etype, *args)
326 
327 
329  INIT = 'init'
330  EXIT = 'exit'
331 
332  def __init__(self, type, name=None):
333  NetBuilder.__init__(self, name)
334  self.type = type
335 
336  def setup(self, net):
337  if self.type == _SetupBuilder.INIT:
338  return core.to_execution_step(self)
339 
340  def exit(self, net):
341  if self.type == _SetupBuilder.EXIT:
342  return core.to_execution_step(self)
343 
344 
346  def __init__(self, name=None):
347  NetBuilder.__init__(self, name)
348 
349  def __exit__(self, etype, *args):
350  if etype is None and self._stop_blob is not None:
351  ops.stop()
352  NetBuilder.__exit__(self, etype, *args)
353 
354 
356  def __init__(self, has_stopped_blob=None, name=None):
357  _RunOnce.__init__(self, name)
358  self._stopped = has_stopped_blob
359  self._ran = False
360 
361  def __enter__(self):
362  r = _RunOnce.__enter__(self)
363  self._stopped = ops.Const(True, blob_out=self._stopped)
364  return r
365 
366  def __exit__(self, etype, *args):
367  if etype is None:
368  self._ran = True
369  ops.Const(False, blob_out=self._stopped)
370  _RunOnce.__exit__(self, etype, *args)
371 
372  def has_stopped(self):
373  """
374  Return a blob that will be set to scalar bool `True` after
375  this net builder ran, iff it was halted early.
376  """
377  assert self._ran, 'Context not used yet.'
378  return self._stopped
379 
380 
382  def __init__(self, iters=None, name=None):
383  NetBuilder.__init__(self, name, _stop_blob_required=True)
384  if iters is not None:
385  self._inc = ops.Const(1)
386  self._iter = ops.Const(0)
387  self._num_iters = (
388  iters if isinstance(iters, core.BlobReference)
389  else ops.Const(iters))
390  else:
391  self._num_iters = None
392 
393  def iter(self):
394  assert self._num_iters is not None, (
395  'This loop does not have a number of iterations.')
396  assert self._iter is not None, (
397  'iter() must be called from inside the loop context')
398  return self._iter
399 
400  def __enter__(self):
401  builder = NetBuilder.__enter__(self)
402  if self._num_iters is not None:
403  ops.stop_if(ops.GE([self._iter, self._num_iters]))
404  return builder
405 
406  def __exit__(self, type, *args):
407  if type is None and self._num_iters is not None:
408  self.current_net().Add([self._iter, self._inc], [self._iter])
409  NetBuilder.__exit__(self, type, *args)
410 
411 
413  def __init__(self, cond_blob=None, name=None, _already_ran=None):
414  _RunOnce.__init__(self, name)
415  assert cond_blob or _already_ran
416  self._is_else = cond_blob is None
417  if _already_ran is None:
418  self._else_blob = ops.Not(cond_blob)
419  self._already_ran = ops.Const(False)
420  else:
421  self._already_ran = _already_ran
422  self._else_blob = _already_ran if cond_blob is None else (
423  ops.Or([_already_ran, ops.Not(cond_blob)]))
424 
425  def __enter__(self):
426  r = _RunOnce.__enter__(self)
427  ops.stop_if(self._else_blob)
428  ops.Const(True, blob_out=self._already_ran)
429  return r
430 
431  def Elif(self, cond, name=None):
432  assert not self._is_else, 'Else not allowed for an Else.'
433  return NetBuilder.current().add(_RunIf(
434  cond, name=name or self.name, _already_ran=self._already_ran))
435 
436  def Else(self, name=None):
437  assert not self._is_else, 'Elif not allowed for an Else.'
438  return NetBuilder.current().add(
439  _RunIf(name=name or self.name, _already_ran=self._already_ran))
def stop_if(self, blob)
Definition: net_builder.py:161
def task_reporter(self, interval_ms=1000, name=None)
Definition: net_builder.py:285
def loop(self, iters=None, name=None)
Definition: net_builder.py:174
def add(self, child)
Definition: net_builder.py:67
def local_reporter(self, interval_ms=1000, name=None)
Definition: net_builder.py:298
def __getattr__(self, op_type)
Definition: net_builder.py:126
def net(self, net=None, name=None)
Definition: net_builder.py:109
def _assert_mutable(self)
Definition: net_builder.py:63
def If(self, cond, name=None)
Definition: net_builder.py:222
def to_execution_step(step_or_nets, default_name=None)
Definition: core.py:1996
def current_net(self, name=None)
Definition: net_builder.py:76
def stop_guard(self, has_stopped_blob=None, name=None)
Definition: net_builder.py:200