Caffe2 - Python API
A deep learning, cross platform ML framework
core.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from collections import namedtuple
9 from collections import OrderedDict
10 
11 from caffe2.proto import caffe2_pb2
12 from collections import defaultdict
13 from caffe2.python import scope, utils, workspace
14 import caffe2.python._import_c_extension as C
15 import numpy as np
16 import sys
17 
18 
19 # Mac os specific message
20 if (sys.platform == 'darwin' and 'leveldb' in C.registered_dbs()):
21  print('If you are using homebrew leveldb on a Mac OS, you might see an '
22  'error warning you that malloc_zone_unregister() failed. This is '
23  'not a caffe2 issue but is due to the homebrew leveldb having an '
24  'incompatible memory allocator. It does not affect usage.')
25 
26 # Convenience redirections to functions inside scope.
27 DeviceScope = scope.DeviceScope
28 NameScope = scope.NameScope
29 
30 
31 # Bring datatype enums to the main namespace
32 class DataType:
33  pass
34 
35 
36 def _InitDataType():
37  for name, value in caffe2_pb2.TensorProto.DataType.items():
38  setattr(DataType, name, value)
39 
40 
41 _InitDataType()
42 
43 # Python 2 and 3 compatibility: test if basestring exists
44 try:
45  basestring = basestring # NOQA
46 except NameError:
47  # This is python3 so we define basestring.
48  basestring = str
49 
50 
51 def _GetRegisteredOperators():
52  return set(workspace.RegisteredOperators())
53 
54 
55 _REGISTERED_OPERATORS = _GetRegisteredOperators()
56 
57 
58 def RefreshRegisteredOperators():
59  global _REGISTERED_OPERATORS
60  _REGISTERED_OPERATORS = _GetRegisteredOperators()
61 
62 
63 _GLOBAL_INIT_ARGS = []
64 
65 
66 def GlobalInit(args):
67  _GLOBAL_INIT_ARGS.extend(args[1:])
68  C.global_init(args)
69 
70 
71 def GetGlobalInitArgs():
72  return _GLOBAL_INIT_ARGS[:]
73 
74 
75 _WORKER_INIT_CALLS = []
76 
77 
78 def worker_init_func(func):
79  """
80  By decorating a function with this, each call to the function will be
81  recorded at workflow time and replayed in each of the works at startup.
82  Used for example for registering caffe python operators.
83  """
84  def call(*args, **kwargs):
85  _WORKER_INIT_CALLS.append((func, args, kwargs))
86  return func(*args, **kwargs)
87 
88  return call
89 
90 
91 def GetWorkerInitCalls():
92  return _WORKER_INIT_CALLS[:]
93 
94 
95 def IsOperator(op_type):
96  return (op_type in _REGISTERED_OPERATORS)
97 
98 
99 def IsOperatorWithEngine(op_type, engine):
100  return (op_type + "_ENGINE_" + engine in _REGISTERED_OPERATORS)
101 
102 
103 def DeviceOption(device_type, cuda_gpu_id=0, random_seed=None):
104  option = caffe2_pb2.DeviceOption()
105  option.device_type = device_type
106  option.cuda_gpu_id = cuda_gpu_id
107  if random_seed is not None:
108  option.random_seed = random_seed
109  return option
110 
111 
112 GradientSlice = namedtuple('GradientSlice', ['indices', 'values'])
113 
114 
115 class BlobReference(object):
116  """A wrapper around a blob in a net.
117 
118  BlobReference gives us a way to refer to the network that the blob is
119  generated from. Note that blobs are, essentially, just strings in the
120  current workspace.
121  """
122 
123  def __init__(self, name, net=None):
124  """Initializes a blob reference.
125 
126  Note that this does not prepends the namescope. If needed, use
127  ScopedBlobReference() to prepend the existing namespace.
128  """
129  self._name = name
130  self._from_net = net
131  # meta allows helper functions to put whatever metainformation needed
132  # there.
133  self.meta = {}
134 
135  def __hash__(self):
136  return hash(self._name)
137 
138  def __eq__(self, other):
139  if isinstance(other, basestring):
140  return self._name == other
141  elif isinstance(other, BlobReference):
142  return self._name == other._name
143  else:
144  return False
145 
146  def __ne__(self, other):
147  return not(self == other)
148 
149  def __str__(self):
150  return self._name
151 
152  def __repr__(self):
153  return 'BlobReference("{}")'.format(self._name)
154 
155  def __add__(self, other):
156  if not isinstance(other, basestring):
157  raise RuntimeError('Cannot add BlobReference to a non-string.')
158  return BlobReference(self._name + other, self._from_net)
159 
160  def __radd__(self, other):
161  if not isinstance(other, basestring):
162  raise RuntimeError('Cannot add a non-string to BlobReference.')
163  return BlobReference(other + self._name, self._from_net)
164 
165  def Net(self):
166  return self._from_net
167 
168  def GetNameScope(self):
169  return self._name[:self._name.rfind(scope._NAMESCOPE_SEPARATOR) + 1]
170 
171  def _CreateAndAddToNet(self, op_type, inputs=None, *args, **kwargs):
172  """Internal function that routes the operator generation to the
173  network's __getattr__ function.
174  """
175  inputs = [] if inputs is None else inputs
176  if isinstance(inputs, BlobReference) or isinstance(inputs, str):
177  inputs = [inputs]
178  # add self to the input list.
179  inputs.insert(0, self)
180  return self._from_net.__getattr__(op_type)(inputs, *args, **kwargs)
181 
182  def __getattr__(self, op_type):
183  """A wrapper allowing one to initiate operators from a blob reference.
184 
185  Example: for a blob reference b that comes from network n, doing
186  b.Relu(...)
187  is equivalent to doing
188  net.Relu([b], ...)
189  """
190  if op_type.startswith('__'):
191  raise AttributeError('Attribute {} not found.'.format(op_type))
192  if self._from_net is None:
193  raise RuntimeError(
194  'You cannot use a blob reference that does not have a net '
195  'source to create operators. Create the operator from an '
196  'explicit net object.')
197  if not IsOperator(op_type):
198  raise RuntimeError(
199  'Method ' + op_type + ' is not a registered operator.' +
200  ' Did you mean: [' +
201  ",".join(workspace.C.nearby_opnames(op_type)) + ']'
202  )
203  return lambda *args, **kwargs: self._CreateAndAddToNet(
204  op_type, *args, **kwargs)
205 
206 
207 def ScopedName(name):
208  """prefix the name with the current scope."""
209  return scope.CurrentNameScope() + name
210 
211 
212 def ScopedBlobReference(name, *args, **kwargs):
213  """Returns a blob reference with scope prefixed."""
214  return BlobReference(ScopedName(name), *args, **kwargs)
215 
216 
217 def _RectifyInputOutput(blobs, net=None):
218  """A helper function to rectify the input or output of the CreateOperator
219  interface.
220  """
221  if isinstance(blobs, basestring):
222  # If blobs is a single string, prepend scope.CurrentNameScope()
223  # and put it as a list.
224  # TODO(jiayq): enforce using BlobReference instead of raw strings.
225  return [ScopedBlobReference(blobs, net=net)]
226  elif type(blobs) is BlobReference:
227  # If blob is a BlobReference, simply put it as a list.
228  return [blobs]
229  elif type(blobs) in (list, tuple):
230  # If blob is a list, we go through it and type check.
231  rectified = []
232  for blob in blobs:
233  if isinstance(blob, basestring):
234  rectified.append(ScopedBlobReference(blob, net=net))
235  elif type(blob) is BlobReference:
236  rectified.append(blob)
237  else:
238  raise TypeError(
239  "I/O blob #{} of unsupported type: {} of type {}"
240  .format(len(rectified), str(blob), type(blob)))
241  return rectified
242  else:
243  raise TypeError(
244  "Unknown input/output type: %s of type %s." %
245  (str(blobs), type(blobs))
246  )
247 
248 
249 def CreateOperator(
250  operator_type,
251  inputs,
252  outputs,
253  name='',
254  control_input=None,
255  device_option=None,
256  arg=None,
257  engine=None,
258  **kwargs
259 ):
260  """A function wrapper that allows one to create operators based on the
261  operator type. The type should be a string corresponding to an operator
262  registered with Caffe2.
263  """
264  operator = caffe2_pb2.OperatorDef()
265  operator.type = operator_type
266  operator.name = name
267  # Add rectified inputs and outputs
268  inputs = _RectifyInputOutput(inputs)
269  outputs = _RectifyInputOutput(outputs)
270  operator.input.extend([str(i) for i in inputs])
271  operator.output.extend([str(o) for o in outputs])
272  if control_input:
273  control_input = _RectifyInputOutput(control_input)
274  operator.control_input.extend([str(i) for i in control_input])
275  # Set device option:
276  # (1) If device_option is explicitly set, use device_option.
277  # (2) If not, but scope.CurrentDeviceScope() is set,
278  # then we use scope.CurrentDeviceScope().
279  # (3) Otherwise, do not set device option.
280  if device_option is not None:
281  operator.device_option.CopyFrom(device_option)
282  elif scope.CurrentDeviceScope() is not None:
283  operator.device_option.CopyFrom(scope.CurrentDeviceScope())
284  if engine is not None:
285  operator.engine = engine
286  # random seed is defined in the device option, so we need to do special
287  # care.
288  if 'random_seed' in kwargs:
289  operator.device_option.random_seed = kwargs['random_seed']
290  del kwargs['random_seed']
291  # Add given arguments that do not need parsing
292  if arg is not None:
293  operator.arg.extend(arg)
294  # Add all other arguments
295  for key, value in kwargs.items():
296  operator.arg.add().CopyFrom(utils.MakeArgument(key, value))
297 
300  return operator
301 
302 
303 def _RegisterPythonImpl(f, grad_f=None, pass_workspace=False):
304  if isinstance(f, tuple):
305  f = f[0](*f[1], **f[2])
306  if isinstance(grad_f, tuple):
307  grad_f = grad_f[0](*grad_f[1], **grad_f[2])
308 
309  token = C.register_python_op(f, pass_workspace)
310  if grad_f:
311  C.register_python_gradient_op(token, grad_f)
312  return token
313 
314 
316  f, inputs,
317  outputs,
318  grad_f=None,
319  pass_workspace=False,
320  *args,
321  **kwargs
322 ):
323  """
324  `f` should have a signature (inputs, outputs)
325 
326  If `pass_workspace` is True, the signature is changed to
327  (inputs, outputs, workspace) where `workspace` is the workspace the op
328  is going to run on. This is potentially dangerous (as the op can manipulate
329  the workspace directly), use on your own risk.
330  """
331  kwargs["token"] = _RegisterPythonImpl(
332  f, grad_f, pass_workspace=pass_workspace
333  )
334  return CreateOperator("Python", inputs, outputs, *args, **kwargs)
335 
336 
337 def GetIndexFromGradientList(g_list, name):
338  """A helper function to get the index from a gradient list, None if not
339  matching."""
340  for i, g in enumerate(g_list):
341  if g == name:
342  return i
343  elif type(g) is GradientSlice:
344  if (g.indices == name or g.values == name):
345  return i
346  return None
347 
348 
349 OpSSA = namedtuple('OpSSA', ['op', 'in_versions', 'out_versions'])
350 GradGenMeta = namedtuple('GradGenMeta', ['grad_op', 'idx', 'gradient'])
351 SparseGradGenMeta = namedtuple('SparseGradGenMeta', [
352  'grad_op_indices', 'idx_indices',
353  'grad_op_values', 'idx_values',
354  'gradient',
355 ])
356 
357 
358 class IR(object):
359  """A simple IR class to keep track of all intermediate representations used
360  in the gradient computation.
361  """
362 
363  def __init__(self, operators):
364  # The IR class holds multiple metadata from the forward pass:
365  # a) ssa: a list of [op, in_versions, out_versions] recording the
366  # input and the output version of each operator, similar
367  # to a normal SSA form.
368  # b) input_count: a dictionary specifying for each blob and
369  # each of its version, how many times it is used as input for another
370  # op.
371  # c) frontier: maintaining the current versions of the blobs
372  # we are having in the workspace, after the execution of all the ops
373  # added to the IR so far. This is useful because if a gradient is
374  # trying to access an earlier version of a blob, we can sanity check
375  # that it is no longer there, and thus throw an error.
376  # d) gradient_frontier: maps the names of blobs to its version that the
377  # gradient corresponds to.
378  # e) gradient_generators: for each blob and each of its version, maps to
379  # a list of operators that generates its gradient together with the
380  # gradient name.
381  self.ssa = []
382  self.input_usages = defaultdict(lambda: defaultdict(list))
383  self.frontier = defaultdict(int)
384  self.gradient_frontier = {}
385  self.gradient_generators = defaultdict(lambda: defaultdict(list))
386 
387  for op in operators:
388  self.Play(op)
389 
390  def Play(self, op):
391  """"Adds an op to the current IR, and update the internal states to
392  reflect the blobs and versions after the execution of the op.
393  """
394  # For input, they are the current version in the dict.
395  in_versions = {}
396  for s in op.input:
397  in_versions[s] = self.frontier[s]
398  self.input_usages[s][self.frontier[s]].append(len(self.ssa))
399  # For output, they are the current version plus one. If this is a
400  # newly created blob, its version starts with zero.
401  out_versions = {}
402  for s in op.output:
403  if s in self.frontier:
404  self.frontier[s] += 1
405  out_versions[s] = self.frontier[s]
406  # Add to SSA for bookkeeping.
407  self.ssa.append(OpSSA(op, in_versions, out_versions))
408 
410  self, grad_op_input, g_output, fwd_op_idx, locally_generated_blobs):
411  """Checks if the gradient operators can be correctly carried out."""
412  forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
413  original_index = GetIndexFromGradientList(g_output, grad_op_input)
414  # If it is a dense or sparse gradient name, it should match the
415  # version of the corresponding output.
416  if original_index is not None:
417  original_name = forward_op.output[original_index]
418  if (out_versions[original_name] !=
419  self.gradient_frontier[original_name]):
420  raise RuntimeError(
421  'Gradient name "%s" is expected to correspond '
422  'to version %d of "%s", but currently we have '
423  'version %d.' % (
424  grad_op_input, out_versions[original_name],
425  original_name,
426  self.gradient_frontier[original_name]))
427  # If it is an output name, the current version should match the
428  # version when the operator was run.
429  elif grad_op_input in out_versions:
430  if self.frontier[grad_op_input] != out_versions[grad_op_input]:
431  raise RuntimeError(
432  'Gradient operator needs output "%s" at version'
433  ' %d, but currently we have version %d.' % (
434  grad_op_input, out_versions[grad_op_input],
435  self.frontier[grad_op_input]
436  )
437  )
438  # If it is an input name, the current version should match the
439  # version when the operator was run.
440  elif grad_op_input in in_versions:
441  if (self.frontier[grad_op_input] != in_versions[grad_op_input]):
442  raise RuntimeError(
443  'Gradient operator needs input "%s" at version '
444  '%d, but currently we have version %d.' % (
445  grad_op_input, in_versions[grad_op_input],
446  self.frontier[grad_op_input]
447  )
448  )
449  # If it is none of the above, it should be a blob that is
450  # generated locally by one of the previous gradient operators.
451  else:
452  if grad_op_input not in locally_generated_blobs:
453  raise RuntimeError(
454  'Blob name "%s" not in the scope of operator: '
455  '%s\nand is not generated by any of the local '
456  'gradient operators.' % (grad_op_input, str(forward_op))
457  )
458 
459  def AppendSparseGenerators(self, sparse_generators):
460  # merge indices and values generators for sparse gradients
461  for name, input_generators in sparse_generators.items():
462  for version, generators in input_generators.items():
463  if len(generators) == 1:
464  # either indices or values are generated (but not both)
465  generator = generators[0]
466  else:
467  # both indices and values are generated
468  assert(len(generators) == 2)
469  op1_i, idx1_i, op1_v, idx1_v, g1 = generators[0]
470  op2_i, idx2_i, op2_v, idx2_v, g2 = generators[1]
471  assert(g1 == g2)
472  assert(op1_i is None or op2_i is None)
473  assert(op1_v is None or op2_v is None)
474  assert(idx1_i == 0 or idx2_i == 0)
475  assert(idx1_v == 0 or idx2_v == 0)
476  generator = SparseGradGenMeta(
477  op1_i or op2_i, idx1_i + idx2_i,
478  op1_v or op2_v, idx1_v + idx2_v,
479  g1)
480  self.gradient_generators[name][version].append(generator)
481 
482  def BuildGradientGenerators( # NOQA
483  self, fwd_op_idx, gradient_ops, g_output, g_input):
484  """Updates gradient_generators and gradient_frontier"""
485  forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
486  locally_generated_blobs = []
487  sparse_generators = defaultdict(lambda: defaultdict(list))
488 
489  for grad_op in gradient_ops:
490  # (1) check that inputs are valid
491  for s in grad_op.input:
493  s, g_output, fwd_op_idx, locally_generated_blobs)
494 
495  # (2) add outputs to the locally generated blobs
496  # If an output corresponds to the gradient of an input, we also
497  # record it to gradient_generators
498  locally_generated_blobs.extend([str(s) for s in grad_op.output])
499  for i, output in enumerate(grad_op.output):
500  input_index = GetIndexFromGradientList(g_input, output)
501  if input_index is not None:
502  input_name = forward_op.input[input_index]
503  input_version = in_versions[input_name]
504  g = g_input[input_index]
505  if type(g) is GradientSlice:
506  # the output corresponds either to the indices or the
507  # values of the sparse gradient. In either case we
508  # create a (partial) SparseGradGenMeta. If necessary,
509  # we'll merge indices and values generators
510  # corresponding to the same gradient in step (3)
511  if g.indices == output:
512  m = SparseGradGenMeta(grad_op, i, None, 0, g)
513  else:
514  assert(g.values == output)
515  m = SparseGradGenMeta(None, 0, grad_op, i, g)
516  sparse_generators[input_name][input_version].append(m)
517  else:
518  self.gradient_generators[input_name][input_version] \
519  .append(GradGenMeta(
520  grad_op, i, g))
521 
522  # (3) merge indices and values generators for sparse gradients, and
523  # add them to gradient_generators
524  self.AppendSparseGenerators(sparse_generators)
525 
526  # (4) for ops (e.g., Add, Sum, Sub) which have gradient outputs directly
527  # passed from inputs (not computed from gradient ops), we create an
528  # GradGenMeta with None grad_op and idx so that the gradient_generators
529  # knows where the gradients are coming from. This is needed for creating
530  # Sum op to accumulate the gradients from multiple parents.
531  for input_index, g in enumerate(g_input):
532  input_name = forward_op.input[input_index]
533  input_version = in_versions[input_name]
534  if not g:
535  continue
536  if type(g) is GradientSlice:
537  if str(g.indices) not in locally_generated_blobs and \
538  str(g.values) not in locally_generated_blobs:
539  self.gradient_generators[input_name][input_version].append(
540  SparseGradGenMeta(None, 0, None, 0, g))
541  else:
542  if str(g) not in locally_generated_blobs:
543  self.gradient_generators[input_name][input_version].append(
544  GradGenMeta(None, 0, g))
545 
546  # Finally, for the gradients specified in g_input, we update the
547  # gradient frontier to reflect the input versions that the gradients
548  # correspond to.
549  for i, g in enumerate(g_input):
550  if g is not None:
551  input_name = forward_op.input[i]
552  input_version = in_versions[input_name]
553  self.gradient_frontier[input_name] = input_version
554 
555  def _GetSumOpOutputName(self, generator, input_name):
556  def remove_suffix(s, suffix):
557  if s.endswith(suffix):
558  return s[:-len(suffix)]
559  return s
560 
561  for g in generator:
562  if type(g) is GradGenMeta:
563  grad_op, idx, _ = g
564  if grad_op:
565  return grad_op.output[idx]
566  else:
567  assert(type(g) is SparseGradGenMeta)
568  op_i, idx_i, op_v, idx_v, _ = g
569  if op_i:
570  return remove_suffix(op_i.output[idx_i], '_indices')
571  if op_v:
572  return remove_suffix(op_v.output[idx_v], '_values')
573 
574  return input_name + '_grad'
575 
576  def _SetSumOpsDeviceOption(self, sum_ops, generators):
577  # we already checked that device options are consistent so we can just
578  # use the first one we find
579  for generator in generators:
580  grad_op = generator.grad_op if type(generator) is GradGenMeta \
581  else generator.grad_op_values or generator.grad_op_indices
582  if grad_op:
583  if grad_op.HasField('device_option'):
584  for op in sum_ops:
585  op.device_option.CopyFrom(grad_op.device_option)
586  break
587 
588  def _DisambiguateGradOpOutput(self, grad_op, idx, cnt):
589  grad_op.output[idx] = (
590  '_' + grad_op.output[idx] + '_autosplit_{}'.format(cnt))
591  return grad_op.output[idx], cnt + 1
592 
593  def _CheckSumOpsConflict(self, out_base_name, g):
594  if str(out_base_name) == str(g):
595  # TODO not sure what this message really means
596  raise RuntimeError(
597  'The gradient output of empty gradient op can not '
598  'be the same as the normal name of the current '
599  'input gradient.')
600 
601  def _MakeDenseSumOps(self, generators, out_base_name):
602  sum_op_input = []
603  cnt = 0
604 
605  for generator in generators:
606  grad_op, idx, g = generator
607  assert(type(g) is not GradientSlice)
608  if grad_op:
609  out, cnt = self._DisambiguateGradOpOutput(grad_op, idx, cnt)
610  sum_op_input.append(out)
611  else:
612  self._CheckSumOpsConflict(out_base_name, g)
613  sum_op_input.append(str(g))
614 
615  sum_ops = [CreateOperator(
616  "Sum",
617  map(BlobReference, sum_op_input),
618  BlobReference(out_base_name))]
619  return sum_ops, out_base_name
620 
621  def _MakeSparseSumOps(self, generators, out_base_name):
622  indices_concat_input = []
623  values_concat_input = []
624  cnt_i = 0
625  cnt_v = 0
626 
627  for generator in generators:
628  assert(type(generator) is SparseGradGenMeta)
629  op_i, idx_i, op_v, idx_v, g = generator
630  if op_i:
631  out, cnt_i = self._DisambiguateGradOpOutput(op_i, idx_i, cnt_i)
632  indices_concat_input.append(out)
633  else:
634  self._CheckSumOpsConflict(out_base_name, g.indices)
635  indices_concat_input.append(g.indices)
636  if op_v:
637  out, cnt_v = self._DisambiguateGradOpOutput(op_v, idx_v, cnt_v)
638  values_concat_input.append(out)
639  else:
640  self._CheckSumOpsConflict(out_base_name, g.values)
641  values_concat_input.append(g.values)
642 
643  indices_concat_output = out_base_name + '_indices_concat'
644  indices_concat_split = out_base_name + '_indices_concat_split'
645  values_concat_output = out_base_name + '_values_concat'
646  values_concat_split = out_base_name + '_values_concat_split'
647  # Sum the given sparse representations by simply concatenating the
648  # indices (resp. values) tensors together. We don't do any deduplication
649  # of indices at this point. This will be done as needed before the
650  # optimizer is called
651  sum_ops = [
653  "Concat",
654  map(BlobReference, indices_concat_input),
655  map(BlobReference,
656  [indices_concat_output, indices_concat_split]),
657  axis=0
658  ),
660  "Concat",
661  map(BlobReference, values_concat_input),
662  map(BlobReference, [values_concat_output, values_concat_split]),
663  axis=0
664  ),
665  ]
666  sum_op_output = GradientSlice(
667  indices=indices_concat_output,
668  values=values_concat_output,
669  )
670  return sum_ops, sum_op_output
671 
672  def _MakeSumOps(self, input_name, input_version):
673  generators = self.gradient_generators[input_name][input_version]
674  out_base_name = self._GetSumOpOutputName(generators, input_name)
675  types = list(set(type(x) for x in generators))
676  assert(len(types) == 1)
677  if types[0] is GradGenMeta:
678  sum_ops, g = self._MakeDenseSumOps(generators, out_base_name)
679  else:
680  assert(types[0] is SparseGradGenMeta)
681  sum_ops, g = self._MakeSparseSumOps(generators, out_base_name)
682  self._SetSumOpsDeviceOption(sum_ops, generators)
683  return sum_ops, g
684 
685  def _VerifyGradientGenerators(self, generator):
686  # (1) check if all gradients are of the same type. Aggregating a mix of
687  # sparse and dense gradients is not supported yet
688  if len({type(g) for g in generator}) > 1:
689  raise RuntimeError(
690  'Automatic aggregation of a mix of sparse and dense gradients '
691  'is not supported yet')
692 
693  # If for all the operators that used the operator, none or only one
694  # produced the gradient, then no additional sum needs to be carried
695  # out.
696  if len(generator) < 2:
697  return False
698 
699  all_gradient_names = []
700  all_device_options = []
701  for g in generator:
702  if type(g) is GradGenMeta:
703  if g.grad_op:
704  all_gradient_names.append(g.gradient)
705  all_device_options.append(g.grad_op.device_option)
706  else:
707  assert(type(g) is SparseGradGenMeta)
708  if g.grad_op_indices:
709  all_device_options.append(g.grad_op_indices.device_option)
710  if g.grad_op_values:
711  all_device_options.append(g.grad_op_values.device_option)
712  all_gradient_names.append(g.gradient.values)
713 
714  # Check if all grad names are the same.
715  if len(set(all_gradient_names)) > 1:
716  raise RuntimeError('Unexpected behavior: not all grad output '
717  'names are the same.')
718  # Check if all grad op device options are the same.
719  if len(all_device_options) >= 2 and not all(
720  d == all_device_options[0] for d in all_device_options[1:]):
721  raise RuntimeError('Unexpected behavior: not all grad ops'
722  'have the same device option.')
723  return True
724 
725  def DoGradientAccumulation(self, fwd_op_idx):
726  """For each input name in the forward op, check if we will need to
727  add gradient accumulation. If so, do gradient accumulation and return
728  the list of gradient operators.
729 
730  The criteria for doing gradient accumulation is:
731  (1) the specific input version has been used by multiple operators.
732  (2) the current fwd_op_idx is the first to use that input, i.e. in the
733  backward pass, is the last to optionally generate the gradient for
734  the op.
735  (3) For the operators that used the input, their gradient operators
736  have generated more than 1 gradient.
737 
738  When accumulating operators, our current solution is to rename all the
739  created gradients with an internal intermediate name, and then add a
740  Sum() operator that adds up all the gradients. This may use more memory
741  due to intermediate storage, but is usually the fastest approach as one
742  can do one single sum for multiple intermediate gradients.
743  """
744  forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
745  additional_sum_ops = []
746  grad_map = {}
747  for _i, input_name in enumerate(set(forward_op.input)):
748  input_version = in_versions[input_name]
749  input_usage = self.input_usages[input_name][input_version]
750  if (len(input_usage) <= 1 or fwd_op_idx != input_usage[0]):
751  # We do not need to do gradient accumulation yet.
752  continue
753  generator = self.gradient_generators[input_name][input_version]
754  try:
755  if not self._VerifyGradientGenerators(generator):
756  continue
757  except RuntimeError as err:
758  raise RuntimeError(
759  "Gradients for param ''{}'' failed to verify: {}".format(
760  input_name,
761  err
762  )
763  )
764 
765  # Finally, let's create the sum operator.
766  sum_ops, g = self._MakeSumOps(input_name, input_version)
767  additional_sum_ops.extend(sum_ops)
768  grad_map[input_name] = g
769  return additional_sum_ops, grad_map
770 
771  def _GetInitGradients(self, ys):
772  input_to_grad = {}
773  gradient_ops = []
774  for y, g in ys.items():
775  if g is None:
776  autograd_op = CreateOperator(
777  "ConstantFill", [y], [str(y) + "_autogen_grad"],
778  value=1.0)
779  gradient_ops.append(autograd_op)
780  g = autograd_op.output[0]
781  # Since the C++ gradient registry does not have notion of
782  # NameScopes, we will convert all references to strings.
783  input_to_grad[str(y)] = (
784  GradientSlice(str(g[0]), str(g[1]))
785  if isinstance(g, GradientSlice) else str(g))
786 
787  return input_to_grad, gradient_ops
788 
789  def _GenerateGradientsForForwardOp(
790  self, forward_op_idx, input_to_grad):
791  new_input_to_grad = {}
792  gradient_ops = []
793  forward_op, in_versions, out_versions = self.ssa[forward_op_idx]
794  g_output = list(
795  input_to_grad.get(name, None) for name in forward_op.output)
796  if not all(g is None for g in g_output):
797  gradient_ops, g_input = GradientRegistry.GetGradientForOp(
798  forward_op, g_output)
799  # Check if the gradient operators are legal, and update
800  # gradient_generators and gradient_frontier
802  forward_op_idx, gradient_ops, g_output, g_input)
803  # Record the gradient map to all_input_to_grad.
804  for name, grad in zip(forward_op.input, g_input):
805  # Do not overwrite an existing gradient with a None
806  # unless the input is also an output of the op, since
807  # we update the blob version when blob is output of an
808  # operator.
809  if grad is not None or \
810  name not in input_to_grad or \
811  name in list(forward_op.output):
812  new_input_to_grad[name] = grad
813 
814  return new_input_to_grad, gradient_ops
815 
816  def GetBackwardPass(self, ys):
817  """Gets the backward pass that computes the derivatives of given blobs.
818 
819  Inputs:
820  ys: a list or a dictionary specifying what blobs we want to compute
821  derivatives of. If the input is a list, we will automatically
822  generate their gradients with all-one values; if the input is a
823  dictionary, for any dictionary entries that are not None, we will
824  take the corresponding blobs as their gradients; for all those
825  that are None, we will auto-fill them with 1.
826  """
827  if isinstance(ys, list):
828  ys = dict((y, None) for y in ys)
829  elif not isinstance(ys, dict):
830  raise TypeError("ys should either be a list or a dict.")
831 
832  # Set the gradient frontier with the initialized external
833  # gradients.
834  for y, _ in ys.items():
835  self.gradient_frontier[y] = self.frontier[y]
836 
837  all_input_to_grad, all_gradient_ops = self._GetInitGradients(ys)
838 
839  # (2) Now, after having the virtual play above, we now play the ops
840  # backwards, creating the gradients along the path. Note that although
841  # we are playing it backwards, we cannot refer to variables that are
842  # at a version older than current_versions because it is already been
843  # overwritten.
844  for forward_op_idx in reversed(range(len(self.ssa))):
845  input_to_grad, gradient_ops = self._GenerateGradientsForForwardOp(
846  forward_op_idx, all_input_to_grad)
847  all_input_to_grad.update(input_to_grad)
848  all_gradient_ops += gradient_ops
849 
850  # If there are multiple use blobs, do gradient accumulation.
851  additional_sum_ops, grad_map = self.DoGradientAccumulation(
852  forward_op_idx)
853  # This line is so that if in an accumulation some of the operators
854  # have not produced gradients, they still do not overwrite the
855  # general all_input_to_grad map.
856  all_input_to_grad.update(grad_map)
857  all_gradient_ops += additional_sum_ops
858 
859  # (3) Post-processing.
860  # After we have done computation for each op, we now have the gradient
861  # operators ready. For the output map, we will convert everything to
862  # BlobReferences for easier handling in python.
863  all_input_to_grad_out = {}
864  for key, val in all_input_to_grad.items():
865  if val is not None:
866  all_input_to_grad_out[BlobReference(key)] = (
867  BlobReference(val) if isinstance(val, basestring) else
868  GradientSlice(BlobReference(val[0]), BlobReference(val[1])))
869  return all_gradient_ops, all_input_to_grad_out
870 
871 
872 class GradientRegistry(object):
873  """GradientRegistry holds the mapping from operators to their gradients."""
874  gradient_registry_ = {}
875 
876  @classmethod
877  def RegisterGradient(cls, op_type):
878  """A decorator for registering gradient mappings."""
879 
880  def Wrapper(func):
881  cls.gradient_registry_[op_type] = func
882  return func
883 
884  return Wrapper
885 
886  @classmethod
887  def _GetGradientForOpCC(cls, op_def, g_output):
888  # TODO(tulloch) - Propagate GradientWrapper up through the stack.
889  def from_untyped(grad):
890  if grad is None:
891  w = C.GradientWrapper()
892  assert w.is_empty()
893  return w
894  try:
895  (indices, values) = grad
896  w = C.GradientWrapper()
897  w.indices = indices
898  w.values = values
899  assert w.is_sparse()
900  return w
901  except ValueError:
902  w = C.GradientWrapper()
903  w.dense = grad
904  assert w.is_dense()
905  return w
906 
907  g_output = [from_untyped(grad) for grad in g_output]
908  grad_defs_str, g_input = C.get_gradient_defs(
909  op_def.SerializeToString(), g_output)
910 
911  def to_untyped(grad_wrapper):
912  if grad_wrapper.is_empty():
913  return None
914  if grad_wrapper.is_sparse():
915  return GradientSlice(grad_wrapper.indices, grad_wrapper.values)
916  assert grad_wrapper.is_dense()
917  return grad_wrapper.dense
918 
919  g_input = [to_untyped(grad_wrapper) for grad_wrapper in g_input]
920  grad_defs = []
921  for grad_def_str in grad_defs_str:
922  grad_def = caffe2_pb2.OperatorDef()
923  grad_def.ParseFromString(grad_def_str)
924  grad_defs.append(grad_def)
925  return grad_defs, g_input
926 
927  @classmethod
928  def GetGradientForOp(cls, op, g_output):
929  try:
930  gradient_ops, g_input = cls._GetGradientForOpCC(op, g_output)
931  except Exception as e:
932  # Not supported in C++; will try python registration next.
933 
934  try:
935  gradient_ops, g_input = cls.gradient_registry_[op.type](
936  op, g_output)
937  except KeyError:
938  raise Exception(
939  "No gradient registered for {}. ".format(op.type) +
940  "Exception from creating the gradient op: {}.".format(e))
941 
942  if gradient_ops is None:
943  return [], g_input
944  if type(gradient_ops) is not list:
945  gradient_ops = [gradient_ops]
946  return gradient_ops, g_input
947 
948  @classmethod
949  def GetBackwardPass(cls, operators, ys):
950  """Gets the backward pass for the list of operators.
951 
952  Args:
953  operators: a list of operators constituting the forward pass.
954  ys: a list or a dictionary specifying what blobs we want to compute
955  derivatives of. If the input is a list, we will automatically
956  generate their gradients with all-one values; if the input is a
957  dictionary, for any dictionary entries that are not None, we'll
958  take the corresponding blobs as their gradients; for all those
959  that are None, we will auto-fill them with 1.
960  Returns:
961  gradient_ops: a list of gradient operators to run.
962  all_input_to_grads: a map from input to their corresponding
963  gradients.
964  """
965  ir = IR(operators)
966  return ir.GetBackwardPass(ys)
967 
968 
969 def get_ssa(net, blob_versions=None):
970  """
971  Given a net, return a structure containing the version of each input and
972  output blob used by each operator.
973 
974  Args:
975  net: either a Net or a NetDef
976  blob_versions: (optional) map with current version number for given
977  blob names. If not provided or blob not found, start
978  from version 0.
979  Returns:
980  Tuple (ssa, blob_versions)
981  ssa: list of tuples (versioned_inputs, versioned_outputs)
982  for each op in the net. A versioned input is a tuple
983  (blob_name, version).
984  blob_versions: updated map with latest version of each blob found in
985  the net.
986  """
987  proto = net.Proto() if isinstance(net, Net) else net
988  assert isinstance(proto, caffe2_pb2.NetDef)
989  if blob_versions is None:
990  blob_versions = {}
991  if isinstance(net, list):
992  return [get_ssa(n, blob_versions) for n in net], blob_versions
993  for i in proto.external_input:
994  if i not in blob_versions:
995  blob_versions[str(i)] = 0
996  ssa = []
997  for op in proto.op:
998  if not proto.external_input:
999  for i in op.input:
1000  if i not in blob_versions:
1001  blob_versions[i] = 0
1002  inputs = [(str(i), blob_versions.get(str(i), 0)) for i in op.input]
1003  for o in op.output:
1004  blob_versions[str(o)] = blob_versions.get(str(o), 0) + 1
1005  outputs = [(str(o), blob_versions[str(o)]) for o in op.output]
1006  ssa.append((inputs, outputs))
1007  return ssa, blob_versions
1008 
1009 
1011  """
1012  Given a ssa in the format produced by get_ssa(), return a set of blobs that
1013  are used before they are defined, which corresponds to inputs at version 0.
1014  """
1015  undef_blobs = set()
1016  for inputs, _outputs in ssa:
1017  undef_blobs |= set(name for (name, ver) in inputs if ver == 0)
1018  return undef_blobs
1019 
1020 
1022  """
1023  Given a ssa in the format produced by get_ssa(), returns a map from
1024  versioned blob into the operator index that produces that version of
1025  the blob. A versioned blob is a tuple (blob_name, version).
1026  """
1027  producers = {}
1028  for i, (_inputs, outputs) in enumerate(ssa):
1029  for o in outputs:
1030  producers[o] = i
1031  return producers
1032 
1033 
1034 def get_op_ids_in_path(ssa, blob_versions, inputs, outputs):
1035  """
1036  Given a ssa and blob_versions as produced by get_ssa(), returns the list
1037  of op indices that are necessary in order to generate the blobs in
1038  `outputs`, given blobs in `inputs`.
1039  Consider that the `inputs` are given in their latest version.
1040  """
1041  inputs_set = set((str(i), blob_versions[str(i)]) for i in inputs)
1042  producers = get_output_producers(ssa)
1043  queue = [(str(o), blob_versions[str(o)]) for o in outputs]
1044  used_op_ids = set()
1045  while len(queue) > 0:
1046  o = queue.pop()
1047  if (o not in inputs_set) and (o in producers):
1048  op_id = producers[o]
1049  if op_id not in used_op_ids:
1050  used_op_ids |= {op_id}
1051  inputs, _ = ssa[op_id]
1052  queue.extend(inputs)
1053  return sorted(used_op_ids)
1054 
1055 
1056 def clone_and_bind_net(net, name, prefix, blob_remap=None, inputs=None,
1057  keep_schema=True):
1058  """
1059  Clone the given Net, binding its input schema to the given `inputs` record.
1060  Blob names defined by the net are prepended with the given `prefix`.
1061 
1062  Args:
1063  net: the net to clone
1064  name: the name of the new net
1065  prefix: the prefix to append to local blobs
1066  blob_remap: (optional) dict with additional blob name remapping.
1067  inputs: (optional) input record that will provide actual input
1068  values for the cloned net. Must be compatible with the
1069  net's input schema or be a strict superset of it
1070  keep_schema: by default (True), the original schema will be kept and
1071  remapped accordingly. otherwise, the schema will be set as
1072  inputs or left empty if inputs is not given.
1073  Returns:
1074  Tuple (cloned_net, blob_remap)
1075  clone_net: the cloned Net
1076  blob_remap: a map from original blob names into remapped blob names
1077  """
1078  from caffe2.python import schema
1079  assert isinstance(net, Net)
1080  if blob_remap is None:
1081  blob_remap = {}
1082  if inputs is not None:
1083  assert isinstance(inputs, schema.Field)
1084  original = net.input_record()
1085  assert original is not None
1086  # TODO(azzolini): improve schema type checking
1087  diff = set(original.field_names()) - set(inputs.field_names())
1088  assert len(diff) == 0, \
1089  "Schemas don't match, extra fields {} found in the net".format(diff)
1090  original_mapping = dict(zip(original.field_names(),
1091  original.field_blobs()))
1092  for fn, fb in zip(inputs.field_names(), inputs.field_blobs()):
1093  if fn in original_mapping:
1094  blob_remap[str(original_mapping[fn])] = str(fb)
1095  proto = net.Proto()
1096  ssa, blob_versions = get_ssa(proto)
1097  undef_blobs = get_undefined_blobs(ssa)
1098 
1099  for blob in blob_versions.keys():
1100  if blob in blob_remap:
1101  continue
1102  elif blob in undef_blobs:
1103  blob_remap[blob] = blob
1104  else:
1105  blob_remap[blob] = prefix + blob
1106  cloned_net = net.Clone(name, blob_remap, keep_schema=keep_schema)
1107  if not keep_schema and inputs:
1108  cloned_net.set_input_record(inputs)
1109  return cloned_net, blob_remap
1110 
1111 
1112 def _get_blob_ref(blob_name_or_ref):
1113  return (
1114  blob_name_or_ref if isinstance(input, BlobReference)
1115  else BlobReference(blob_name_or_ref)
1116  )
1117 
1118 
1119 class Net(object):
1120  _net_names_used = set()
1121  operator_registry_ = {}
1122 
1123  @staticmethod
1124  def current_prefix():
1125  from caffe2.python.net_builder import NetBuilder
1126  builder = NetBuilder.current(required=False)
1127  return builder.name if builder else ''
1128 
1129  @staticmethod
1130  def _get_next_net_name(basename):
1131  name = basename = '/'.join(filter(
1132  lambda x: x, (Net.current_prefix(), basename)))
1133  next_idx = 1
1134  while name in Net._net_names_used:
1135  name = basename + '_' + str(next_idx)
1136  next_idx += 1
1137  Net._net_names_used |= set([name])
1138  return name
1139 
1140  def __init__(self, name_or_proto):
1141  """
1142  Create a Net.
1143  Args:
1144  name_or_proto: If a NetDef is provided, clone it. Otherwise,
1145  create an empty net with the given name.
1146  """
1147  self._input_record = None
1148  self._output_record = None
1149  # Register blobs so that it's guaranteed that different calls to
1150  # NextBlob/NextScopedBlob always return blobs with different names
1151  self._registered_blob_names = set()
1152  self._recreate_lookup_tables = False
1153  self._op_outputs = set()
1154  self._external_input_map = set()
1155  self._attr_dict = defaultdict(list)
1156  if type(name_or_proto) is caffe2_pb2.NetDef:
1157  proto = name_or_proto
1158  # We rae initializing a network by a NetDef. In this case, we will
1159  # initialize our network with the given netdef.
1160  self._net = caffe2_pb2.NetDef()
1161  self._net.CopyFrom(proto)
1162 
1163  existing_outputs = [list(op.output) for op in self._net.op]
1164 
1165  self._external_input_map.update(list(self._net.external_input))
1166 
1167  # Set the next name index properly.
1168  existing_names = set(
1169  sum(
1170  [list(op.input) for op in self._net.op], []
1171  ) + sum(
1172  existing_outputs, []
1173  )
1174  )
1175  for outs in existing_outputs:
1176  self._op_outputs.update(outs)
1177 
1178  prefix_len = len(self._net.name + '_blob_')
1179  autogen_indices = []
1180  for s in existing_names:
1181  if s.startswith(self._net.name + '_blob_'):
1182  try:
1183  autogen_indices.append(int(s[prefix_len]))
1184  except ValueError:
1185  pass
1186  if len(autogen_indices):
1187  self._next_name_index = max(autogen_indices) + 1
1188  else:
1189  self._next_name_index = 0
1190  name = self._net.name
1191  else:
1192  name = name_or_proto
1193  self._net = caffe2_pb2.NetDef()
1194  self._next_name_index = 0
1195 
1196  # make sure that this net name hasn't been used before
1197  self._net.name = Net._get_next_net_name(name)
1198 
1199  def AppendNet(self, net):
1200  assert isinstance(net, Net)
1201  self._ExtendOps(net.Proto().op)
1202  self.Proto().external_input.extend(
1203  [i for i in net.Proto().external_input
1204  if i not in self.Proto().external_input])
1205  self.Proto().external_output.extend(
1206  [o for o in net.Proto().external_output
1207  if o not in self.Proto().external_output])
1208  return self
1209 
1210  def LogInfo(self, *msg_or_blobs):
1211  for msg_or_blob in msg_or_blobs:
1212  if not isinstance(msg_or_blob, BlobReference):
1213  blob = self.GivenTensorStringFill(
1214  [], self.NextName('log'),
1215  shape=[], values=[msg_or_blob])
1216  else:
1217  blob = msg_or_blob
1218  self.Print(blob, [])
1219 
1220  def add_attribute(self, name, obj):
1221  """
1222  Add `obj` to the list of attributes in this net under the given `name`.
1223  Attributes are user-defined objects and have no pre-defined semantics.
1224  """
1225  self._attr_dict[name].append(obj)
1226 
1227  def get_attributes(self, name):
1228  """
1229  Returns the list of attributes in this net for a given `name`.
1230  Attributes are user-defined objects added with `add_attribute'.
1231  """
1232  return self._attr_dict.get(name, [])
1233 
1234  def set_rand_seed(self, seed=100, sequence_seed=True, seed_on_op_def=False):
1235  """
1236  Adds a random seed to each op in the net.
1237  If sequence_seed is set, the i-th op has rand_seed=`seed + i`
1238  If seed_on_op_def is set, the op rand_seed=hash(str(op))
1239  sequence_seed and seed_on_op_def cannot be both set to True.
1240  """
1241  assert not (sequence_seed and seed_on_op_def), (
1242  'sequence_seed and seed_on_op_def cannot be both set to True.')
1243  for i, op in enumerate(self.Proto().op):
1244  if sequence_seed:
1245  curr_seed = seed + i
1246  elif seed_on_op_def:
1247  curr_seed = hash(str(op) + str(seed)) % np.iinfo(np.uint32).max
1248  else:
1249  curr_seed = seed
1250  op.device_option.random_seed = curr_seed
1251 
1252  def Name(self):
1253  return self._net.name
1254 
1255  def __str__(self):
1256  return self.Name()
1257 
1258  def Const(self, array, blob_out=None, dtype=None):
1259  if isinstance(array, bool):
1260  return self.ConstantFill(
1261  [],
1262  blob_out or 1,
1263  dtype=DataType.BOOL,
1264  value=array)
1265 
1266  if dtype is None:
1267  array = np.array(array)
1268  else:
1269  array = np.array(array, dtype=dtype)
1270 
1271  def do_set(operator):
1272  return operator(
1273  [],
1274  blob_out or 1,
1275  shape=array.shape,
1276  values=array.flatten().tolist())
1277 
1278  if array.dtype == np.int32:
1279  return do_set(self.GivenTensorIntFill)
1280  elif array.dtype == np.int64:
1281  return do_set(self.GivenTensorInt64Fill)
1282  elif array.dtype == np.str:
1283  return do_set(self.GivenTensorStringFill)
1284  else:
1285  return do_set(self.GivenTensorFill)
1286 
1287  def BlobIsDefined(self, blob):
1288  """
1289  Returns true if the given BlobReference is produced as output of
1290  an operator in this net, or if it is provided as an external input.
1291  """
1292  if self._recreate_lookup_tables:
1293  self._RecreateLookupTables()
1294  name = str(blob)
1295  return (name in self._op_outputs) or (name in self._external_input_map)
1296 
1297  def UsesBlob(self, blob):
1298  """
1299  Returns true iff the given BlobReference is used by any operator
1300  or this net, or if it is one of the external inputs of the net.
1301  """
1302  blob_name = str(blob)
1303  for op in self._net.op:
1304  for input in op.input:
1305  if input == blob_name:
1306  return True
1307  return blob_name in self._external_input_map
1308 
1309  def GetBlobRef(self, blob_name):
1310  """
1311  Given the name of a blob produced by this net, return a BlobReference
1312  to it. If the blob is not produced by any op in this net,
1313  raises KeyError.
1314  """
1315  blob_name = str(blob_name)
1316  if not self.BlobIsDefined(blob_name):
1317  raise KeyError('Net does not define blob %s' % blob_name)
1318  return BlobReference(blob_name, self)
1319 
1320  def Clone(
1321  self,
1322  name,
1323  blob_remap=None,
1324  op_id_mask=None,
1325  remap_funcs=None,
1326  keep_schema=True
1327  ):
1328  """
1329  Clone this net.
1330  Args:
1331  name: name of the cloned net
1332  blob_remap: optional map with list of blob names to replace
1333  op_id_mask: optional list of operator indices to include in
1334  the cloned net. If not provided, all ops are included.
1335  """
1336  if remap_funcs is None:
1337  remap_funcs = {}
1338  proto = self._net
1339  new_proto = caffe2_pb2.NetDef()
1340  new_proto.CopyFrom(proto)
1341  new_proto.name = name
1342 
1343  if blob_remap is None:
1344  blob_remap = {}
1345  if op_id_mask is None:
1346  op_id_mask = range(0, len(proto.op))
1347 
1348  def get_remapped_str(blob):
1349  blob_str = str(blob)
1350  return str(blob_remap.get(blob_str, blob_str))
1351 
1352  def remap_list(proto_list):
1353  new_list = [get_remapped_str(b) for b in proto_list]
1354  del proto_list[:]
1355  proto_list.extend(new_list)
1356 
1357  def remap_op(op):
1358  new_op = caffe2_pb2.OperatorDef()
1359  new_op.CopyFrom(op)
1360  remap_list(new_op.input)
1361  remap_list(new_op.output)
1362  if new_op.type in remap_funcs:
1363  remap_funcs[new_op.type](new_op, (name + '/') if name else '')
1364  return new_op
1365 
1366  del new_proto.op[:]
1367  new_proto.op.extend([remap_op(proto.op[op_id]) for op_id in op_id_mask])
1368  remap_list(new_proto.external_input)
1369  remap_list(new_proto.external_output)
1370  new_net = Net(new_proto)
1371 
1372  if keep_schema:
1373  from caffe2.python import schema
1374  if self._input_record:
1375  new_net._input_record = schema.from_blob_list(
1376  self._input_record,
1377  [
1378  BlobReference(get_remapped_str(blob), net=new_net)
1379  for blob in self._input_record.field_blobs()
1380  ],
1381  )
1382  if self._output_record:
1383  new_net._output_record = schema.from_blob_list(
1384  self._output_record,
1385  [
1386  BlobReference(get_remapped_str(blob), net=new_net)
1387  for blob in self._output_record.field_blobs()
1388  ],
1389  )
1390 
1391  new_net._attr_dict.update(self._attr_dict)
1392  return new_net
1393 
1394  def ClonePartial(self, name, inputs, outputs, remap_funcs=None):
1395  """
1396  Clone this net, including only ops that are necessary in order to
1397  compute `outputs` given `inputs`. Return references to the cloned
1398  outputs. Internal blobs (blobs that are produced and consumed inside
1399  the net but not used as outputs) will be remapped to avoid name
1400  conflict.
1401 
1402  Args:
1403  name: the name of the cloned net
1404  inputs: map where the keys correspond to BlobReferences in the
1405  original net, and the values correspond to external inputs
1406  in the partially cloned net. If `inputs` is a list, don't
1407  remap input names.
1408  outputs: outputs to be produced by the cloned net.
1409 
1410  Returns:
1411  Tuple (new_net, new_outputs)
1412  new_net: a new Net object.
1413  new_outputs: list of BlobReferences corresponding to the
1414  outputs produced by new_net.
1415  """
1416  input_is_pair_list = isinstance(inputs, list) and all(
1417  isinstance(i, tuple) and len(i) == 2 for i in inputs)
1418  inputs = (
1419  inputs if isinstance(inputs, (dict, OrderedDict)) else
1420  OrderedDict(inputs) if input_is_pair_list else
1421  OrderedDict(zip(inputs, inputs)))
1422  for output in outputs:
1423  assert self.BlobIsDefined(output)
1424  input_names = {str(k): str(v) for k, v in inputs.items()}
1425  output_names = [str(o) for o in outputs]
1426  proto = self._net
1427  ssa, blob_versions = get_ssa(proto)
1428  used_op_ids = get_op_ids_in_path(ssa, blob_versions, inputs, outputs)
1429  disallowed_op_ids = get_op_ids_in_path(ssa, blob_versions, [], inputs)
1430  assert len(set(used_op_ids) & set(disallowed_op_ids)) == 0, (
1431  'Cannot partially clone net: some of the ops required would ' +
1432  'generate the given input.')
1433 
1434  sub_ssa = [op for i, op in enumerate(ssa) if i in used_op_ids]
1435  undef_blobs = get_undefined_blobs(sub_ssa) - set(input_names.keys())
1436  prefix = (name + '/') if name else ''
1437 
1438  def remap(blob_name):
1439  if blob_name in input_names:
1440  return input_names[blob_name]
1441  elif blob_name in undef_blobs:
1442  return blob_name
1443  else:
1444  return prefix + blob_name
1445 
1446  blob_mapping = {b: remap(b) for b in blob_versions.keys()}
1447  new_net = self.Clone(name, blob_mapping, used_op_ids, remap_funcs)
1448  new_in = [
1449  blob_mapping[i] for i in input_names.keys()] + list(undef_blobs)
1450  new_out = [blob_mapping[o] for o in output_names]
1451  del new_net.Proto().external_input[:]
1452  new_net.Proto().external_input.extend(new_in)
1453  new_net._external_input_map = set(list(new_in))
1454  del new_net.Proto().external_output[:]
1455  new_net.Proto().external_output.extend(new_out)
1456  return new_net, [new_net.GetBlobRef(o) for o in new_out]
1457 
1458  def Proto(self):
1460  return self._net
1461 
1462  def NextScopedBlob(self, prefix='unnamed'):
1463  """Return the blob that has not been defined or registered in the
1464  current net. It returns `ScopedBlobReference(prefix)`, if it's valid,
1465  otherwise `ScopedBlobReference(prefix) + '_auto_' + ?`. Different calls
1466  is guaranteed to return blob with different names.
1467  """
1468  output_blob_base = ScopedName(prefix)
1469  return self.NextBlob(output_blob_base)
1470 
1471  def NextBlob(self, prefix='unnamed'):
1472  """Return the blob that has not been defined or registered in the
1473  current net. It returns `BlobReference(prefix)`, if it's valid,
1474  otherwise `BlobReference(prefix) + '_auto_' + ?`. Different calls
1475  is guaranteed to return blob with different names."""
1476  output_blob_base = BlobReference(prefix)
1477  output_blob = output_blob_base
1478  index = 0
1479  while str(output_blob) in self._registered_blob_names or (
1480  self.BlobIsDefined(output_blob)):
1481  output_blob = output_blob_base + '_auto_' + str(index)
1482  index += 1
1483 
1484  self._registered_blob_names.add(str(output_blob))
1485  return output_blob
1486 
1487  def NextName(self, prefix=None, output_id=None):
1488  """Returns the next name to be used, if you do not want to explicitly
1489  name your blob. [Deprecated, use NextBlob, NextScopedBlob instead]"""
1490  if prefix:
1491  output_name_base = self._net.name + '/' + prefix
1492  output_name = output_name_base
1493  if output_id is not None:
1494  output_name += ':' + str(output_id)
1495  index = 2
1496  while self.BlobIsDefined(str(ScopedBlobReference(output_name))):
1497  output_name = output_name_base + '_' + str(index)
1498  if output_id is not None:
1499  output_name += ':' + str(output_id)
1500  index += 1
1501  else:
1502  output_name = self._net.name + '_blob_' + str(self._next_name_index)
1503  self._next_name_index += 1
1504  return str(output_name)
1505 
1506  def _ExtendOps(self, new_ops):
1507  self._net.op.extend(new_ops)
1508  for op in new_ops:
1509  self._op_outputs.update([str(o) for o in op.output])
1510 
1511  def _CheckLookupTables(self):
1512  '''
1513  Called from unit tests to validate the internal lookup tables
1514  match the protobuf contents.
1515  '''
1516  test_op_outputs = set()
1517  for op in self._net.op:
1518  for o in op.output:
1519  test_op_outputs.add(o)
1520 
1521  test_external_inp = set()
1522  for inp in self._net.external_input:
1523  test_external_inp.add(inp)
1524 
1525  assert test_op_outputs.difference(self._op_outputs) == set()
1526  assert test_external_inp.difference(self._external_input_map) == set()
1527 
1528  def _InvalidateLookupTables(self):
1529  self._recreate_lookup_tables = True
1530 
1531  def _RecreateLookupTables(self):
1532  self._op_outputs = set()
1533  for op in self._net.op:
1534  for o in op.output:
1535  self._op_outputs.add(o)
1536 
1537  self._external_input_map = set()
1538  for inp in self._net.external_input:
1539  self._external_input_map.add(inp)
1540 
1541  self._recreate_lookup_tables = False
1542 
1543  def AddGradientOperators(self, ys, skip=0):
1544  """Add the gradient for operators in the net.
1545 
1546  Inputs:
1547  ys: a list or a dictionary specifying what blobs we want to compute
1548  derivatives of. If the input is a list, we will automatically
1549  generate their gradients with all-one values; if the input is a
1550  dictionary, for any dictionary entries that are not None, we will
1551  take the corresponding blobs as their gradients; for all those
1552  that are None, we will auto-fill them with 1.
1553  skip: skips the first n operators. This is provided mainly because a
1554  lot of nets may use the first few operators for data generation
1555  like stuff which really do not need to have gradients.
1556 
1557  Outputs:
1558  returns a map from the blob name in the input network to a blob
1559  containing gradient or a GradientSlice in case of sparse gradient
1560 
1561  Currently, this is hard-coded for float operators if there are branches
1562  (i.e. a blob is used as input to multiple operators). This is because
1563  the gradient accumulation (Sum) is float only right now.
1564  """
1565 
1566  grad_ops, input_to_grad = GradientRegistry.GetBackwardPass(
1567  self._net.op[skip:], ys)
1568  # Check if in immediate mode: the grad_ops are actually being produced
1569  # by C++ and bypasses the CreateOperator() call, so in immediate mode
1570  # we will have to explicitly run them.
1571  if workspace.IsImmediate():
1572  for op in grad_ops:
1574  self._ExtendOps(grad_ops)
1575  return input_to_grad
1576 
1577  def AddExternalInput(self, *inputs):
1578  assert len(inputs) > 0
1579  refs = []
1580  for input in inputs:
1581  input_name = str(input)
1582  assert str(input) not in self._external_input_map, (
1583  'Net already contains an input named %s' % input_name)
1584  for input in inputs:
1585  input_name = str(input)
1586  self._net.external_input.extend([input_name])
1587  self._external_input_map.update([input_name])
1588  refs.append(_get_blob_ref(input_name))
1589 
1590  return refs[0] if len(refs) == 1 else refs
1591 
1592  def AddExternalOutput(self, *outputs):
1593  for output in outputs:
1594  assert isinstance(output, BlobReference)
1595  assert self.BlobIsDefined(output)
1596  for output in outputs:
1597  self.Proto().external_output.extend([str(output)])
1598 
1599  def AddScopedExternalInputs(self, *inputs):
1600  return self.AddExternalInput(
1601  * [ScopedBlobReference(str(b)) for b in inputs]
1602  )
1603 
1604  def AddScopedExternalOutputs(self, *outputs):
1605  return self.AddExternalOutput(
1606  * [ScopedBlobReference(str(b)) for b in outputs]
1607  )
1608 
1609  @property
1610  def external_inputs(self):
1611  return map(_get_blob_ref, self._net.external_input)
1612 
1613  @property
1614  def external_outputs(self):
1615  return map(_get_blob_ref, self._net.external_output)
1616 
1617  def set_input_record(self, input_record):
1618  from caffe2.python import schema
1619  assert self._input_record is None, (
1620  'Input schema cannot be reset')
1621  if not input_record.has_blobs():
1622  with NameScope(self.Name()):
1623  self._input_record = schema.NewRecord(self, input_record)
1624  else:
1625  self._input_record = input_record
1626  for blob in input_record.field_blobs():
1627  if blob not in self.external_inputs:
1628  self.AddExternalInput(blob)
1629  return self._input_record
1630 
1631  def set_output_record(self, record):
1632  assert self._output_record is None, (
1633  'Output record cannot be reset')
1634  for blob in record.field_blobs():
1635  assert self.BlobIsDefined(blob)
1636  for blob in record.field_blobs():
1637  self.AddExternalOutput(blob)
1638  self._output_record = record
1639 
1640  def AppendOutputRecordField(self, field_name, record):
1641  from caffe2.python import schema
1642  assert self._output_record is not None, (
1643  'Tried to append to missing output record'
1644  )
1645  for blob in record.field_blobs():
1646  assert self.BlobIsDefined(blob)
1647  for blob in record.field_blobs():
1648  self.AddExternalOutput(blob)
1650  (field_name, record)
1651  )
1652 
1653  def input_record(self):
1654  return self._input_record
1655 
1656  def output_record(self):
1657  return self._output_record
1658 
1659  def AddExternalInputs(self, *inputs):
1660  return self.AddExternalInput(*inputs)
1661 
1662  def AddExternalOutputs(self, *outputs):
1663  self.AddExternalOutput(*outputs)
1664 
1665  def DeduplicateGradientSlices(self, g, aggregator='sum'):
1666  assert isinstance(g, GradientSlice)
1667  unique, remapping = self.Unique([g.indices], 2, engine='SparseHash')
1668  if aggregator.lower() == 'sum':
1669  new_g = self.UnsortedSegmentSum([g.values, remapping], 1)
1670  elif aggregator.lower() == 'mean':
1671  new_g = self.UnsortedSegmentMean([g.values, remapping], 1)
1672  else:
1673  raise ValueError('{} is not supported'.format(aggregator))
1674  return GradientSlice(indices=unique, values=new_g)
1675 
1676  def RunAllOnGPU(self, gpu_id=0, use_cudnn=False):
1677  """A convenient function to run everything on the GPU."""
1678  device_option = caffe2_pb2.DeviceOption()
1679  device_option.device_type = caffe2_pb2.CUDA
1680  device_option.cuda_gpu_id = gpu_id
1681  self._net.device_option.CopyFrom(device_option)
1682  if use_cudnn:
1683  for op in self._net.op:
1684  op.engine = "CUDNN"
1685 
1686  def _CreateAndAddToSelf(self, op_type, inputs, outputs=None, **kwargs):
1687  """A helper function to create an operator and add it to self.
1688  """
1689  inputs = _RectifyInputOutput(inputs)
1690  for input in inputs:
1691  if not self.BlobIsDefined(input):
1692  assert input.Net() != self
1693  self.AddExternalInput(input)
1694  if outputs is None:
1695  # If we do not specify an output, we will assume that this op
1696  # produces one output in this case.
1697  outputs = self.NextName(prefix=op_type)
1698  elif type(outputs) is int:
1699  # In this case, we will auto-fill the given number of outputs
1700  # with auto-generated names.
1701  outputs = [
1702  self.NextName(prefix=op_type, output_id=i)
1703  for i in range(outputs)]
1704  outputs = _RectifyInputOutput(outputs, net=self)
1705  op = CreateOperator(op_type, inputs, outputs, **kwargs)
1706  self._ExtendOps([op])
1707  if len(op.output) == 0:
1708  return
1709  elif len(op.output) == 1:
1710  return BlobReference(str(op.output[0]), self)
1711  else:
1712  return tuple(BlobReference(str(o), self) for o in op.output)
1713 
1714  def __getattr__(self, op_type):
1715  if op_type.startswith('__'):
1716  raise AttributeError('Attribute {} not found.'.format(op_type))
1717  if not IsOperator(op_type) and not IsOperatorWithEngine(op_type, "CUDNN"):
1718  raise RuntimeError(
1719  'Method ' + op_type + ' is not a registered operator.' +
1720  ' Did you mean: [' +
1721  ",".join(workspace.C.nearby_opnames(op_type)) + ']'
1722  )
1723  return lambda *args, **kwargs: self._CreateAndAddToSelf(
1724  op_type, *args, **kwargs)
1725 
1726  def Python(self, f, grad_f=None, pass_workspace=False):
1727  """
1728  Registers and returns a python operator.
1729 
1730  `f` and `f_grad` can be one of the following:
1731  - a function with signature (inputs, outputs), where inputs and
1732  outputs are a list of CPUTensor objects. This function will be
1733  called from C++ everytime the operator is executed.
1734  - a tuple (func, args, kwargs), here `func` is a callable, args is
1735  an argument list, and kwargs is a dict list. The call:
1736  f = func(*args, kwargs)
1737  will be performed locally at node initialization time, on all of
1738  the nodes of the job, returning `f`, a callable that will be used
1739  as the python operator function to be called during Net execution.
1740  This is to be used when using python operator in a distributed
1741  context, and allows to create and keep local python state across
1742  calls to the operator.
1743 
1744  If `pass_workspace` is True, the signature is changed to
1745  (inputs, outputs, workspace) where `workspace` is the workspace the op
1746  is going to run on. This is potentially dangerous (as the op can
1747  manipulate the workspace directly), use on your own risk.
1748  """
1749  assert(IsOperator('Python'))
1750  if isinstance(f, tuple) or isinstance(grad_f, tuple):
1751  # if we got a tuple, we will make sure this tuple will be
1752  # registered to run at startup on each of the workers in a
1753  # distributed run.
1754  registry = worker_init_func(_RegisterPythonImpl)
1755  else:
1756  registry = _RegisterPythonImpl
1757  token = registry(f, grad_f, pass_workspace=pass_workspace)
1758  return lambda *args, **kwargs: self._CreateAndAddToSelf(
1759  'Python', token=token, *args, **kwargs)
1760 
1761 
1762 def get_net_name(netlike):
1763  if isinstance(netlike, Net):
1764  return netlike.Proto().name
1765  elif isinstance(netlike, caffe2_pb2.NetDef):
1766  return netlike.name
1767  else:
1768  return netlike
1769 
1770 
1771 def output_to_list(op_output):
1772  """
1773  Ensures that the output of an operator is a list.
1774  Use when an operator has a variable number of outputs, but a list of
1775  outputs is desired even when number of outputs is 1.
1776 
1777  Args:
1778  op_output: Either a BlobReferenece or an iterable of BlobReferences.
1779 
1780  Returns:
1781  A list of BlobReferences.
1782  """
1783  assert type(op_output) in (list, tuple, BlobReference)
1784  return (
1785  [op_output]
1786  if isinstance(op_output, BlobReference) else list(op_output))
1787 
1788 
1789 def _add_net_to_dict(net_dict, net):
1790  name = get_net_name(net)
1791  if name in net_dict:
1792  assert net_dict[name] is None or net == net_dict[name], (
1793  'Different nets with same name: ' + name)
1794  return False
1795  else:
1796  net_dict[name] = net if isinstance(net, Net) else None
1797  return True
1798 
1799 
1800 class ExecutionStep(object):
1801  _step_names_used = set()
1802 
1803  @staticmethod
1804  def _get_next_step_name(basename):
1805  name = basename
1806  next_idx = 1
1807  while name in ExecutionStep._step_names_used:
1808  name = basename + '_' + str(next_idx)
1809  next_idx += 1
1810  ExecutionStep._step_names_used |= set([name])
1811  return name
1812 
1813  def __init__(self, name, nets=None, num_iter=None):
1814  self._step = caffe2_pb2.ExecutionStep()
1815  self._step.name = name or ExecutionStep._get_next_step_name('step')
1816  self._net_dict = OrderedDict()
1817  self._is_used = False
1818  self._substeps = []
1819  if nets is not None:
1820  if type(nets) is Net:
1821  nets = [nets]
1822  for net in nets:
1823  if _add_net_to_dict(self._net_dict, net):
1824  self._step.network.extend([get_net_name(net)])
1825  if num_iter is not None:
1826  self._step.num_iter = num_iter
1827 
1828  def get_net(self, name):
1829  return self._net_dict[name]
1830 
1831  def Name(self):
1832  return self._step.name
1833 
1834  def __str__(self):
1835  return self._step.name
1836 
1837  def _assert_can_mutate(self):
1838  assert not self._is_used, (
1839  'Cannot mutate a step that has already been added to a plan/step.')
1840 
1841  def _notify_is_used(self):
1842  self._is_used = True
1843 
1844  def Proto(self):
1845  return self._step
1846 
1847  def HasNets(self):
1848  return self._step.network is not None and (
1849  len(self._step.network) > 0)
1850 
1851  def HasSubsteps(self):
1852  return self._step.substep is not None and (
1853  len(self._step.substep) > 0)
1854 
1855  def Nets(self):
1856  return self._net_dict.values()
1857 
1858  def Substeps(self):
1859  return self._substeps
1860 
1861  def SetIter(self, num_iter):
1862  self._assert_can_mutate()
1863  self._step.num_iter = num_iter
1864 
1865  def SetOnlyOnce(self, only_once):
1866  self._assert_can_mutate()
1867  self._step.only_once = only_once
1868 
1869  def SetShouldStopBlob(self, should_stop_blob):
1870  assert isinstance(should_stop_blob, BlobReference), (
1871  "expects BlobReference here, got {}".format(type(should_stop_blob)))
1872  self._assert_can_mutate()
1873  self._step.should_stop_blob = str(should_stop_blob)
1874 
1875  def RunEveryMillis(self, interval):
1876  """
1877  Run this step every interval millisecods, as long as its
1878  siblings are still running. It is guaranteed that, after all
1879  siblings finish, this step will run at least one.
1880 
1881  This property is ignored for top-level ExecutionSteps.
1882  """
1883  self._step.run_every_ms = interval
1884 
1885  def SetReportNet(self, report_net, report_interval):
1886  """ DEPRECATED. Use RunEveryMillis instead. """
1887  self._assert_can_mutate()
1888  _add_net_to_dict(self._net_dict, report_net)
1889  self._step.report_net = get_net_name(report_net)
1890  self._step.report_interval = report_interval
1891 
1892  def AddSubstep(self, substep):
1893  self._assert_can_mutate()
1894  assert not self.HasNets(), 'Cannot have both network and substeps.'
1895  if isinstance(substep, ExecutionStep):
1896  substep._notify_is_used()
1897  if not substep.HasNets() and not substep.HasSubsteps():
1898  return self
1899  for net in substep.Nets():
1900  _add_net_to_dict(self._net_dict, net)
1901  self._substeps.append(substep)
1902  proto = substep.Proto()
1903  else:
1904  proto = substep
1905  self._step.substep.add().CopyFrom(proto)
1906  return self
1907 
1908  def SetConcurrentSubsteps(self, concurrent_substeps):
1909  self._assert_can_mutate()
1910  assert not self.HasNets(), 'Cannot have both network and substeps.'
1911  self._step.concurrent_substeps = concurrent_substeps
1912 
1913  def AddNet(self, net):
1914  self._assert_can_mutate()
1915  assert not self.HasSubsteps(), 'Cannot have both network and substeps.'
1916  assert isinstance(net, Net)
1917  _add_net_to_dict(self._net_dict, net)
1918  self._step.network.extend([get_net_name(net)])
1919  return self
1920 
1921  def get_all_attributes(self, name):
1922  """
1923  Return the list of all attributes under the given `name`, present in
1924  all of the nets used in this execution step and its children.
1925  """
1926  objs = []
1927  for net in self._net_dict.values():
1928  objs += net.get_attributes(name)
1929  return objs
1930 
1931 
1932 def add_nets_in_order(step, net_list):
1933  proto = step.Proto()
1934  for substep in step.Substeps():
1935  add_nets_in_order(substep, net_list)
1936  for net in proto.network:
1937  if net not in net_list:
1938  net_list.append(net)
1939  # FIXME(azzolini): This is actually wrong. Report nets should be
1940  # instantiated first since they may run before any substep is run.
1941  # However, curerntly, Reporter depends on this behavior.
1942  if proto.report_net and proto.report_net not in net_list:
1943  net_list.append(proto.report_net)
1944 
1945 
1946 class Plan(object):
1947 
1948  def __init__(self, name_or_step):
1949  self._plan = caffe2_pb2.PlanDef()
1950  self._net_dict = OrderedDict()
1951  if isinstance(name_or_step, ExecutionStep):
1952  self._plan.name = name_or_step.Name()
1953  self.AddStep(name_or_step)
1954  elif isinstance(name_or_step, basestring):
1955  self._plan.name = name_or_step
1956  else:
1957  raise ValueError('name_or_step must be a string or ExecutionStep')
1958 
1959  def __str__(self):
1960  return self._plan.name
1961 
1962  def Proto(self):
1963  return self._plan
1964 
1965  def AddNets(self, nets):
1966  for net in nets:
1967  if _add_net_to_dict(self._net_dict, net):
1968  assert isinstance(net, Net)
1969  self._plan.network.add().CopyFrom(net.Proto())
1970 
1971  def Nets(self):
1972  return self._net_dict.values()
1973 
1974  def AddStep(self, step):
1975  assert isinstance(step, ExecutionStep)
1976  step._notify_is_used()
1977  if not step.HasNets() and not step.HasSubsteps():
1978  return
1979  self._plan.execution_step.add().CopyFrom(step.Proto())
1980  # nets need to be added to the plan in order of usage
1981  net_list = []
1982  add_nets_in_order(step, net_list)
1983  self.AddNets([step.get_net(n) for n in net_list])
1984 
1985  def get_all_attributes(self, name):
1986  """
1987  Return the list of all attributes under the given `name`, present in
1988  all of the nets used in this plan.
1989  """
1990  objs = []
1991  for net in self._net_dict.values():
1992  objs += net.get_attributes(name)
1993  return objs
1994 
1995 
1996 def to_execution_step(step_or_nets, default_name=None):
1997  from caffe2.python.net_builder import NetBuilder
1998  if isinstance(step_or_nets, ExecutionStep):
1999  return step_or_nets
2000 
2001  stop_blob = None
2002  if not default_name and hasattr(step_or_nets, 'name'):
2003  default_name = step_or_nets.name
2004  if isinstance(step_or_nets, NetBuilder):
2005  stop_blob = step_or_nets._stop_blob
2006  step_or_nets = step_or_nets.get()
2007  return execution_step(
2008  default_name, step_or_nets, should_stop_blob=stop_blob)
2009 
2010 
2011 def execution_step(default_name,
2012  steps_or_nets,
2013  num_iter=None,
2014  report_net=None,
2015  report_interval=None,
2016  concurrent_substeps=None,
2017  should_stop_blob=None,
2018  only_once=None):
2019  """
2020  Helper for creating an ExecutionStep.
2021  - steps_or_nets can be:
2022  - None
2023  - Net
2024  - ExecutionStep
2025  - list<Net>
2026  - list<ExecutionStep>
2027  - should_stop_blob is either None or a scalar boolean blob.
2028  - This blob is checked AFTER every substeps/subnets.
2029  - If specified and true, then this step will return immediately.
2030  - Be sure to handle race conditions if setting from concurrent threads.
2031  - if no should_stop_blob or num_iter is provided, defaults to num_iter=1
2032  """
2033  assert should_stop_blob is None or num_iter is None, (
2034  'Cannot set both should_stop_blob and num_iter.')
2035  if should_stop_blob is None and num_iter is None:
2036  num_iter = 1
2037 
2038  step = ExecutionStep(default_name)
2039  if should_stop_blob is not None:
2040  step.SetShouldStopBlob(should_stop_blob)
2041  if num_iter is not None:
2042  step.SetIter(num_iter)
2043  if only_once is not None:
2044  step.SetOnlyOnce(only_once)
2045  if concurrent_substeps is not None:
2046  step.SetConcurrentSubsteps(concurrent_substeps)
2047  if report_net is not None:
2048  assert report_interval is not None
2049  step.SetReportNet(report_net, report_interval)
2050 
2051  if isinstance(steps_or_nets, ExecutionStep):
2052  step.AddSubstep(steps_or_nets)
2053  elif isinstance(steps_or_nets, Net):
2054  step.AddNet(steps_or_nets)
2055  elif isinstance(steps_or_nets, list):
2056  if all(isinstance(x, Net) for x in steps_or_nets):
2057  map(step.AddNet, steps_or_nets)
2058  else:
2059  map(step.AddSubstep, map(to_execution_step, steps_or_nets))
2060  elif steps_or_nets:
2061  raise ValueError(
2062  'steps_or_nets must be a step, a net, or a list of nets or steps.')
2063  return step
2064 
2065 
2066 def scoped_execution_step(name, *args, **kwargs):
2067  """Same as execution_step() except that the step name is scoped."""
2068  default_name = ScopedName(name) if name else name
2069  return execution_step(default_name, *args, **kwargs)
dictionary gradient_registry_
Definition: core.py:874
def GetBackwardPass(cls, operators, ys)
Definition: core.py:949
def execution_step(default_name, steps_or_nets, num_iter=None, report_net=None, report_interval=None, concurrent_substeps=None, should_stop_blob=None, only_once=None)
Definition: core.py:2018
def AddExternalInput(self, inputs)
Definition: core.py:1577
def output_to_list(op_output)
Definition: core.py:1771
def _CreateAndAddToNet(self, op_type, inputs=None, args, kwargs)
Definition: core.py:171
def HasNets(self)
Definition: core.py:1847
def CreatePythonOperator(f, inputs, outputs, grad_f=None, pass_workspace=False, args, kwargs)
Definition: core.py:322
input_usages
Definition: core.py:382
def _GetGradientForOpCC(cls, op_def, g_output)
Definition: core.py:887
_net_dict
Definition: core.py:1950
def NewRecord(net, schema)
Definition: schema.py:908
def CurrentDeviceScope()
Definition: scope.py:33
_registered_blob_names
Definition: core.py:1151
_output_record
Definition: core.py:1148
def CurrentNameScope()
Definition: scope.py:26
RegisteredOperators
Definition: workspace.py:28
def external_inputs(self)
Definition: core.py:1610
def ScopedBlobReference(name, args, kwargs)
Definition: core.py:212
def DoGradientAccumulation(self, fwd_op_idx)
Definition: core.py:725
def get_all_attributes(self, name)
Definition: core.py:1921
def _RecreateLookupTables(self)
Definition: core.py:1531
def IsImmediate()
Definition: workspace.py:339
gradient_generators
Definition: core.py:385
Definition: core.py:358
def _DisambiguateGradOpOutput(self, grad_op, idx, cnt)
Definition: core.py:588
def scoped_execution_step(name, args, kwargs)
Definition: core.py:2066
def _MakeSparseSumOps(self, generators, out_base_name)
Definition: core.py:621
def _MakeSumOps(self, input_name, input_version)
Definition: core.py:672
def MakeArgument(key, value)
Definition: utils.py:66
def NextScopedBlob(self, prefix='unnamed')
Definition: core.py:1462
ssa
Definition: core.py:381
def _SetSumOpsDeviceOption(self, sum_ops, generators)
Definition: core.py:576
def BuildGradientGenerators(self, fwd_op_idx, gradient_ops, g_output, g_input)
Definition: core.py:483
def AppendSparseGenerators(self, sparse_generators)
Definition: core.py:459
def _assert_can_mutate(self)
Definition: core.py:1837
_attr_dict
Definition: core.py:1155
_op_outputs
Definition: core.py:1153
def _GenerateGradientsForForwardOp(self, forward_op_idx, input_to_grad)
Definition: core.py:790
def Name(self)
Definition: core.py:1252
def AddExternalOutput(self, outputs)
Definition: core.py:1592
def _CreateAndAddToSelf(self, op_type, inputs, outputs=None, kwargs)
Definition: core.py:1686
def CheckGradientOperatorInput(self, grad_op_input, g_output, fwd_op_idx, locally_generated_blobs)
Definition: core.py:410
def get_undefined_blobs(ssa)
Definition: core.py:1010
_input_record
Definition: core.py:1147
_next_name_index
Definition: core.py:1187
def RunAllOnGPU(self, gpu_id=0, use_cudnn=False)
Definition: core.py:1676
def get_ssa(net, blob_versions=None)
Definition: core.py:969
def Clone(self, name, blob_remap=None, op_id_mask=None, remap_funcs=None, keep_schema=True)
Definition: core.py:1327
def NextBlob(self, prefix='unnamed')
Definition: core.py:1471
def GetBlobRef(self, blob_name)
Definition: core.py:1309
def _VerifyGradientGenerators(self, generator)
Definition: core.py:685
def GetIndexFromGradientList(g_list, name)
Definition: core.py:337
def __getattr__(self, op_type)
Definition: core.py:182
def _ExtendOps(self, new_ops)
Definition: core.py:1506
def _MakeDenseSumOps(self, generators, out_base_name)
Definition: core.py:601
def UsesBlob(self, blob)
Definition: core.py:1297
def NextName(self, prefix=None, output_id=None)
Definition: core.py:1487
def ScopedName(name)
Definition: core.py:207
def AddNets(self, nets)
Definition: core.py:1965
def set_rand_seed(self, seed=100, sequence_seed=True, seed_on_op_def=False)
Definition: core.py:1234
def clone_and_bind_net(net, name, prefix, blob_remap=None, inputs=None, keep_schema=True)
Definition: core.py:1057
def RunEveryMillis(self, interval)
Definition: core.py:1875
def BlobIsDefined(self, blob)
Definition: core.py:1287
def _GetInitGradients(self, ys)
Definition: core.py:771
def Proto(self)
Definition: core.py:1458
def CreateOperator(operator_type, inputs, outputs, name='', control_input=None, device_option=None, arg=None, engine=None, kwargs)
Definition: core.py:259
frontier
Definition: core.py:383
def SetReportNet(self, report_net, report_interval)
Definition: core.py:1885
def get_all_attributes(self, name)
Definition: core.py:1985
def RegisterGradient(cls, op_type)
Definition: core.py:877
_external_input_map
Definition: core.py:1154
def GetBackwardPass(self, ys)
Definition: core.py:816
def ClonePartial(self, name, inputs, outputs, remap_funcs=None)
Definition: core.py:1394
def __init__(self, name, net=None)
Definition: core.py:123
def from_blob_list(schema, values)
Definition: schema.py:826
def AddGradientOperators(self, ys, skip=0)
Definition: core.py:1543
def add_attribute(self, name, obj)
Definition: core.py:1220
def Play(self, op)
Definition: core.py:390
def get_output_producers(ssa)
Definition: core.py:1021
def RunOperatorImmediate(op)
Definition: workspace.py:420
def Python(self, f, grad_f=None, pass_workspace=False)
Definition: core.py:1726
def _InvalidateLookupTables(self)
Definition: core.py:1528
def get_op_ids_in_path(ssa, blob_versions, inputs, outputs)
Definition: core.py:1034
def _GetSumOpOutputName(self, generator, input_name)
Definition: core.py:555
_recreate_lookup_tables
Definition: core.py:1152
def worker_init_func(func)
Definition: core.py:78
def HasSubsteps(self)
Definition: core.py:1851
def _CheckSumOpsConflict(self, out_base_name, g)
Definition: core.py:593
def get_attributes(self, name)
Definition: core.py:1227
def __init__(self, name_or_proto)
Definition: core.py:1140
gradient_frontier
Definition: core.py:384
def AddStep(self, step)
Definition: core.py:1974