from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import sys

from collections import namedtuple
from collections import OrderedDict
from caffe2.proto import caffe2_pb2
from collections import defaultdict
from caffe2.python import scope, utils, workspace
import caffe2.python._import_c_extension as C

import numpy as np
if (sys.platform == 'darwin' and 'leveldb' in C.registered_dbs()):
    print('If you are using homebrew leveldb on a Mac OS, you might see an '
          'error warning you that malloc_zone_unregister() failed. This is '
          'not a caffe2 issue but is due to the homebrew leveldb having an '
          'incompatible memory allocator. It does not affect usage.')
DeviceScope = scope.DeviceScope
NameScope = scope.NameScope

for name, value in caffe2_pb2.TensorProto.DataType.items():
    setattr(DataType, name, value)

basestring = basestring
def _GetRegisteredOperators():
    ...


_REGISTERED_OPERATORS = _GetRegisteredOperators()


def RefreshRegisteredOperators():
    global _REGISTERED_OPERATORS
    _REGISTERED_OPERATORS = _GetRegisteredOperators()
_GLOBAL_INIT_ARGS = []
...
    _GLOBAL_INIT_ARGS.extend(args[1:])


def GetGlobalInitArgs():
    return _GLOBAL_INIT_ARGS[:]


_WORKER_INIT_CALLS = []


def worker_init_func(func):
    """By decorating a function with this, each call to the function will be
    recorded at workflow time and replayed in each of the workers at startup.
    Used for example for registering caffe python operators.
    """
    def call(*args, **kwargs):
        _WORKER_INIT_CALLS.append((func, args, kwargs))
        return func(*args, **kwargs)


def GetWorkerInitCalls():
    return _WORKER_INIT_CALLS[:]
def IsOperator(op_type):
    return (op_type in _REGISTERED_OPERATORS)


def IsOperatorWithEngine(op_type, engine):
    return (op_type + "_ENGINE_" + engine in _REGISTERED_OPERATORS)
def DeviceOption(device_type, cuda_gpu_id=0, random_seed=None):
    option = caffe2_pb2.DeviceOption()
    option.device_type = device_type
    option.cuda_gpu_id = cuda_gpu_id
    if random_seed is not None:
        option.random_seed = random_seed
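Illustrative usage (not part of the original file; it assumes the helper ends with the elided `return option` and uses the device-type enums defined in caffe2_pb2):

    # Pin work to the CPU, or to GPU 0 with a fixed random seed.
    cpu_option = DeviceOption(caffe2_pb2.CPU)
    gpu_option = DeviceOption(caffe2_pb2.CUDA, cuda_gpu_id=0, random_seed=42)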
GradientSlice = namedtuple('GradientSlice', ['indices', 'values'])
116 """A wrapper around a blob in a net. 118 BlobReference gives us a way to refer to the network that the blob is 119 generated from. Note that blobs are, essentially, just strings in the 124 """Initializes a blob reference. 126 Note that this does not prepends the namescope. If needed, use 127 ScopedBlobReference() to prepend the existing namespace. 136 return hash(self.
_name)
138 def __eq__(self, other):
139 if isinstance(other, basestring):
140 return self.
_name == other
141 elif isinstance(other, BlobReference):
142 return self.
_name == other._name
146 def __ne__(self, other):
147 return not(self == other)
153 return 'BlobReference("{}")'.format(self.
_name)
155 def __add__(self, other):
156 if not isinstance(other, basestring):
157 raise RuntimeError(
'Cannot add BlobReference to a non-string.')
160 def __radd__(self, other):
161 if not isinstance(other, basestring):
162 raise RuntimeError(
'Cannot add a non-string to BlobReference.')
168 def GetNameScope(self):
169 return self.
_name[:self.
_name.rfind(scope._NAMESCOPE_SEPARATOR) + 1]
    def _CreateAndAddToNet(self, op_type, inputs=None, *args, **kwargs):
        """Internal function that routes the operator generation to the
        network's __getattr__ function.
        """
        inputs = [] if inputs is None else inputs
        if isinstance(inputs, BlobReference) or isinstance(inputs, str):
            inputs = [inputs]
        inputs.insert(0, self)
        ...
183 """A wrapper allowing one to initiate operators from a blob reference. 185 Example: for a blob reference b that comes from network n, doing 187 is equivalent to doing 190 if op_type.startswith(
'__'):
191 raise AttributeError(
'Attribute {} not found.'.format(op_type))
194 'You cannot use a blob reference that does not have a net ' 195 'source to create operators. Create the operator from an ' 196 'explicit net object.')
197 if not IsOperator(op_type):
199 'Method ' + op_type +
' is not a registered operator.' +
201 ",".join(workspace.C.nearby_opnames(op_type)) +
']' 204 op_type, *args, **kwargs)
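Illustrative sketch of the equivalence this docstring describes (not from the original file; names are made up, and `Net` is the class defined later in this module):

    net = Net("example")
    data = net.AddExternalInput("data")
    h1 = data.Relu()        # routed through BlobReference.__getattr__
    h2 = net.Relu([data])   # equivalent call made directly on the net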
208 """prefix the name with the current scope.""" 213 """Returns a blob reference with scope prefixed.""" 217 def _RectifyInputOutput(blobs, net=None):
218 """A helper function to rectify the input or output of the CreateOperator 221 if isinstance(blobs, basestring):
226 elif type(blobs)
is BlobReference:
229 elif type(blobs)
in (list, tuple):
233 if isinstance(blob, basestring):
235 elif type(blob)
is BlobReference:
236 rectified.append(blob)
239 "I/O blob #{} of unsupported type: {} of type {}" 240 .format(len(rectified), str(blob), type(blob)))
244 "Unknown input/output type: %s of type %s." %
245 (str(blobs), type(blobs))
260 """A function wrapper that allows one to create operators based on the 261 operator type. The type should be a string corresponding to an operator 262 registered with Caffe2. 264 operator = caffe2_pb2.OperatorDef()
265 operator.type = operator_type
268 inputs = _RectifyInputOutput(inputs)
269 outputs = _RectifyInputOutput(outputs)
270 operator.input.extend([str(i)
for i
in inputs])
271 operator.output.extend([str(o)
for o
in outputs])
273 control_input = _RectifyInputOutput(control_input)
274 operator.control_input.extend([str(i)
for i
in control_input])
280 if device_option
is not None:
281 operator.device_option.CopyFrom(device_option)
284 if engine
is not None:
285 operator.engine = engine
288 if 'random_seed' in kwargs:
289 operator.device_option.random_seed = kwargs[
'random_seed']
290 del kwargs[
'random_seed']
293 operator.arg.extend(arg)
295 for key, value
in kwargs.items():
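Illustrative usage (not part of the original file; the feed/run/fetch calls assume the standard caffe2.python.workspace API):

    import numpy as np
    from caffe2.python import workspace

    workspace.FeedBlob("X", np.random.randn(2, 3).astype(np.float32))
    relu_op = CreateOperator("Relu", ["X"], ["Y"])
    workspace.RunOperatorOnce(relu_op)
    print(workspace.FetchBlob("Y"))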
def _RegisterPythonImpl(f, grad_f=None, pass_workspace=False):
    if isinstance(f, tuple):
        f = f[0](*f[1], **f[2])
    if isinstance(grad_f, tuple):
        grad_f = grad_f[0](*grad_f[1], **grad_f[2])

    token = C.register_python_op(f, pass_workspace)
    if grad_f:
        C.register_python_gradient_op(token, grad_f)
    return token
def CreatePythonOperator(
        f, inputs, outputs,
        grad_f=None,
        pass_workspace=False,
        *args, **kwargs):
    """
    `f` should have a signature (inputs, outputs)

    If `pass_workspace` is True, the signature is changed to
    (inputs, outputs, workspace) where `workspace` is the workspace the op
    is going to run on. This is potentially dangerous (as the op can manipulate
    the workspace directly), so use at your own risk.
    """
    kwargs["token"] = _RegisterPythonImpl(
        f, grad_f, pass_workspace=pass_workspace
    )
    return CreateOperator("Python", inputs, outputs, *args, **kwargs)
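A sketch of an `f` with the (inputs, outputs) signature described above (not from the original file; the tensor wrapper methods .reshape/.data are assumptions based on the Caffe2 python-op tutorials):

    def double_op(inputs, outputs):
        # inputs/outputs are lists of CPU tensor wrappers
        outputs[0].reshape(inputs[0].shape)
        outputs[0].data[...] = inputs[0].data * 2

    op = CreatePythonOperator(double_op, ["X"], ["Y"])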
338 """A helper function to get the index from a gradient list, None if not 340 for i, g
in enumerate(g_list):
343 elif type(g)
is GradientSlice:
344 if (g.indices == name
or g.values == name):
OpSSA = namedtuple('OpSSA', ['op', 'in_versions', 'out_versions'])
GradGenMeta = namedtuple('GradGenMeta', ['grad_op', 'idx', 'gradient'])
SparseGradGenMeta = namedtuple('SparseGradGenMeta', [
    'grad_op_indices', 'idx_indices',
    'grad_op_values', 'idx_values',
    'gradient',
])


class IR(object):
    """A simple IR class to keep track of all intermediate representations used
    in the gradient computation.
    """

    def __init__(self, operators):
        self.input_usages = defaultdict(lambda: defaultdict(list))
        ...

    def Play(self, op):
        """Adds an op to the current IR, and updates the internal states to
        reflect the blobs and versions after the execution of the op.
        """
        ...
        self.ssa.append(OpSSA(op, in_versions, out_versions))
    def CheckGradientOperatorInput(
            self, grad_op_input, g_output, fwd_op_idx, locally_generated_blobs):
        """Checks if the gradient operators can be correctly carried out."""
        forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
        ...
        if original_index is not None:
            original_name = forward_op.output[original_index]
            if (out_versions[original_name] != ...):
                raise RuntimeError(
                    'Gradient name "%s" is expected to correspond '
                    'to version %d of "%s", but currently we have '
                    ... % (
                        grad_op_input, out_versions[original_name],
                        ...))
        elif grad_op_input in out_versions:
            if self.frontier[grad_op_input] != out_versions[grad_op_input]:
                raise RuntimeError(
                    'Gradient operator needs output "%s" at version'
                    ' %d, but currently we have version %d.' % (
                        grad_op_input, out_versions[grad_op_input],
                        ...))
        elif grad_op_input in in_versions:
            if (self.frontier[grad_op_input] != in_versions[grad_op_input]):
                raise RuntimeError(
                    'Gradient operator needs input "%s" at version '
                    '%d, but currently we have version %d.' % (
                        grad_op_input, in_versions[grad_op_input],
                        ...))
        else:
            if grad_op_input not in locally_generated_blobs:
                raise RuntimeError(
                    'Blob name "%s" not in the scope of operator: '
                    '%s\nand is not generated by any of the local '
                    'gradient operators.' % (grad_op_input, str(forward_op)))
    def AppendSparseGenerators(self, sparse_generators):
        for name, input_generators in sparse_generators.items():
            for version, generators in input_generators.items():
                if len(generators) == 1:
                    generator = generators[0]
                else:
                    assert(len(generators) == 2)
                    op1_i, idx1_i, op1_v, idx1_v, g1 = generators[0]
                    op2_i, idx2_i, op2_v, idx2_v, g2 = generators[1]
                    ...
                    assert(op1_i is None or op2_i is None)
                    assert(op1_v is None or op2_v is None)
                    assert(idx1_i == 0 or idx2_i == 0)
                    assert(idx1_v == 0 or idx2_v == 0)
                    generator = SparseGradGenMeta(
                        op1_i or op2_i, idx1_i + idx2_i,
                        op1_v or op2_v, idx1_v + idx2_v,
                        ...)
                ...
    def BuildGradientGenerators(
            self, fwd_op_idx, gradient_ops, g_output, g_input):
        """Updates gradient_generators and gradient_frontier"""
        forward_op, in_versions, out_versions = self.ssa[fwd_op_idx]
        locally_generated_blobs = []
        sparse_generators = defaultdict(lambda: defaultdict(list))

        for grad_op in gradient_ops:
            for s in grad_op.input:
                self.CheckGradientOperatorInput(
                    s, g_output, fwd_op_idx, locally_generated_blobs)
            locally_generated_blobs.extend([str(s) for s in grad_op.output])
            for i, output in enumerate(grad_op.output):
                ...
                if input_index is not None:
                    input_name = forward_op.input[input_index]
                    input_version = in_versions[input_name]
                    g = g_input[input_index]
                    if type(g) is GradientSlice:
                        ...
                        if g.indices == output:
                            m = SparseGradGenMeta(grad_op, i, None, 0, g)
                        else:
                            assert(g.values == output)
                            m = SparseGradGenMeta(None, 0, grad_op, i, g)
                        sparse_generators[input_name][input_version].append(m)
        for input_index, g in enumerate(g_input):
            input_name = forward_op.input[input_index]
            input_version = in_versions[input_name]
            ...
            if type(g) is GradientSlice:
                if str(g.indices) not in locally_generated_blobs and \
                        str(g.values) not in locally_generated_blobs:
                    self.gradient_generators[input_name][input_version].append(
                        SparseGradGenMeta(None, 0, None, 0, g))
            else:
                if str(g) not in locally_generated_blobs:
                    self.gradient_generators[input_name][input_version].append(
                        GradGenMeta(None, 0, g))

        for i, g in enumerate(g_input):
            ...
            input_name = forward_op.input[i]
            input_version = in_versions[input_name]
            ...
    def _GetSumOpOutputName(self, generator, input_name):
        def remove_suffix(s, suffix):
            if s.endswith(suffix):
                return s[:-len(suffix)]
            return s

        for g in generator:
            if type(g) is GradGenMeta:
                ...
                return grad_op.output[idx]
            else:
                assert(type(g) is SparseGradGenMeta)
                op_i, idx_i, op_v, idx_v, _ = g
                if op_i:
                    return remove_suffix(op_i.output[idx_i], '_indices')
                if op_v:
                    return remove_suffix(op_v.output[idx_v], '_values')

        return input_name + '_grad'

    def _SetSumOpsDeviceOption(self, sum_ops, generators):
        for generator in generators:
            grad_op = generator.grad_op if type(generator) is GradGenMeta \
                else generator.grad_op_values or generator.grad_op_indices
            ...
            if grad_op.HasField('device_option'):
                for op in sum_ops:
                    op.device_option.CopyFrom(grad_op.device_option)
    def _DisambiguateGradOpOutput(self, grad_op, idx, cnt):
        grad_op.output[idx] = (
            '_' + grad_op.output[idx] + '_autosplit_{}'.format(cnt))
        return grad_op.output[idx], cnt + 1

    def _CheckSumOpsConflict(self, out_base_name, g):
        if str(out_base_name) == str(g):
            raise RuntimeError(
                'The gradient output of empty gradient op can not '
                'be the same as the normal name of the current '
                'input gradient.')

    def _MakeDenseSumOps(self, generators, out_base_name):
        sum_op_input = []
        ...
        for generator in generators:
            grad_op, idx, g = generator
            assert(type(g) is not GradientSlice)
            ...
                sum_op_input.append(out)
            ...
                sum_op_input.append(str(g))
        ...
            map(BlobReference, sum_op_input),
            ...
        return sum_ops, out_base_name
    def _MakeSparseSumOps(self, generators, out_base_name):
        indices_concat_input = []
        values_concat_input = []
        ...
        for generator in generators:
            assert(type(generator) is SparseGradGenMeta)
            op_i, idx_i, op_v, idx_v, g = generator
            ...
                indices_concat_input.append(out)
            ...
                indices_concat_input.append(g.indices)
            ...
                values_concat_input.append(out)
            ...
                values_concat_input.append(g.values)

        indices_concat_output = out_base_name + '_indices_concat'
        indices_concat_split = out_base_name + '_indices_concat_split'
        values_concat_output = out_base_name + '_values_concat'
        values_concat_split = out_base_name + '_values_concat_split'
        ...
            map(BlobReference, indices_concat_input),
            ...
            [indices_concat_output, indices_concat_split]),
        ...
            map(BlobReference, values_concat_input),
            map(BlobReference, [values_concat_output, values_concat_split]),
        ...
        sum_op_output = GradientSlice(
            indices=indices_concat_output,
            values=values_concat_output,
        )
        return sum_ops, sum_op_output
    def _MakeSumOps(self, input_name, input_version):
        generators = self.gradient_generators[input_name][input_version]
        ...
        types = list(set(type(x) for x in generators))
        assert(len(types) == 1)
        if types[0] is GradGenMeta:
            sum_ops, g = self._MakeDenseSumOps(generators, out_base_name)
        else:
            assert(types[0] is SparseGradGenMeta)
            sum_ops, g = self._MakeSparseSumOps(generators, out_base_name)
        ...
        return sum_ops, g
    def _VerifyGradientGenerators(self, generator):
        if len({type(g) for g in generator}) > 1:
            raise RuntimeError(
                'Automatic aggregation of a mix of sparse and dense gradients '
                'is not supported yet')
        ...
        if len(generator) < 2:
            return False

        all_gradient_names = []
        all_device_options = []
        for g in generator:
            if type(g) is GradGenMeta:
                if g.grad_op:
                    all_gradient_names.append(g.gradient)
                    all_device_options.append(g.grad_op.device_option)
            else:
                assert(type(g) is SparseGradGenMeta)
                if g.grad_op_indices:
                    all_device_options.append(g.grad_op_indices.device_option)
                if g.grad_op_values:
                    all_device_options.append(g.grad_op_values.device_option)
                all_gradient_names.append(g.gradient.values)

        if len(set(all_gradient_names)) > 1:
            raise RuntimeError('Unexpected behavior: not all grad output '
                               'names are the same.')
        if len(all_device_options) >= 2 and not all(
                d == all_device_options[0] for d in all_device_options[1:]):
            raise RuntimeError('Unexpected behavior: not all grad ops '
                               'have the same device option.')
        return True
726 """For each input name in the forward op, check if we will need to 727 add gradient accumulation. If so, do gradient accumulation and return 728 the list of gradient operators. 730 The criteria for doing gradient accumulation is: 731 (1) the specific input version has been used by multiple operators. 732 (2) the current fwd_op_idx is the first to use that input, i.e. in the 733 backward pass, is the last to optionally generate the gradient for 735 (3) For the operators that used the input, their gradient operators 736 have generated more than 1 gradient. 738 When accumulating operators, our current solution is to rename all the 739 created gradients with an internal intermediate name, and then add a 740 Sum() operator that adds up all the gradients. This may use more memory 741 due to intermediate storage, but is usually the fastest approach as one 742 can do one single sum for multiple intermediate gradients. 744 forward_op, in_versions, out_versions = self.
ssa[fwd_op_idx]
745 additional_sum_ops = []
747 for _i, input_name
in enumerate(set(forward_op.input)):
748 input_version = in_versions[input_name]
749 input_usage = self.
input_usages[input_name][input_version]
750 if (len(input_usage) <= 1
or fwd_op_idx != input_usage[0]):
757 except RuntimeError
as err:
759 "Gradients for param ''{}'' failed to verify: {}".format(
766 sum_ops, g = self.
_MakeSumOps(input_name, input_version)
767 additional_sum_ops.extend(sum_ops)
768 grad_map[input_name] = g
769 return additional_sum_ops, grad_map
    def _GetInitGradients(self, ys):
        input_to_grad = {}
        gradient_ops = []
        for y, g in ys.items():
            if g is None:
                autograd_op = CreateOperator(
                    "ConstantFill", [y], [str(y) + "_autogen_grad"],
                    ...)
                gradient_ops.append(autograd_op)
                g = autograd_op.output[0]
            ...
            input_to_grad[str(y)] = (
                GradientSlice(str(g[0]), str(g[1]))
                if isinstance(g, GradientSlice) else str(g))

        return input_to_grad, gradient_ops
    def _GenerateGradientsForForwardOp(
            self, forward_op_idx, input_to_grad):
        new_input_to_grad = {}
        gradient_ops = []
        forward_op, in_versions, out_versions = self.ssa[forward_op_idx]
        g_output = list(
            input_to_grad.get(name, None) for name in forward_op.output)
        if not all(g is None for g in g_output):
            gradient_ops, g_input = GradientRegistry.GetGradientForOp(
                forward_op, g_output)
            ...
            self.BuildGradientGenerators(
                forward_op_idx, gradient_ops, g_output, g_input)
            ...
            for name, grad in zip(forward_op.input, g_input):
                ...
                if grad is not None or \
                        name not in input_to_grad or \
                        name in list(forward_op.output):
                    new_input_to_grad[name] = grad

        return new_input_to_grad, gradient_ops
817 """Gets the backward pass that computes the derivatives of given blobs. 820 ys: a list or a dictionary specifying what blobs we want to compute 821 derivatives of. If the input is a list, we will automatically 822 generate their gradients with all-one values; if the input is a 823 dictionary, for any dictionary entries that are not None, we will 824 take the corresponding blobs as their gradients; for all those 825 that are None, we will auto-fill them with 1. 827 if isinstance(ys, list):
828 ys = dict((y,
None)
for y
in ys)
829 elif not isinstance(ys, dict):
830 raise TypeError(
"ys should either be a list or a dict.")
834 for y, _
in ys.items():
844 for forward_op_idx
in reversed(range(len(self.
ssa))):
846 forward_op_idx, all_input_to_grad)
847 all_input_to_grad.update(input_to_grad)
848 all_gradient_ops += gradient_ops
856 all_input_to_grad.update(grad_map)
857 all_gradient_ops += additional_sum_ops
863 all_input_to_grad_out = {}
864 for key, val
in all_input_to_grad.items():
869 return all_gradient_ops, all_input_to_grad_out
873 """GradientRegistry holds the mapping from operators to their gradients.""" 874 gradient_registry_ = {}
878 """A decorator for registering gradient mappings.""" 887 def _GetGradientForOpCC(cls, op_def, g_output):
889 def from_untyped(grad):
891 w = C.GradientWrapper()
895 (indices, values) = grad
896 w = C.GradientWrapper()
902 w = C.GradientWrapper()
907 g_output = [from_untyped(grad)
for grad
in g_output]
908 grad_defs_str, g_input = C.get_gradient_defs(
909 op_def.SerializeToString(), g_output)
911 def to_untyped(grad_wrapper):
912 if grad_wrapper.is_empty():
914 if grad_wrapper.is_sparse():
915 return GradientSlice(grad_wrapper.indices, grad_wrapper.values)
916 assert grad_wrapper.is_dense()
917 return grad_wrapper.dense
919 g_input = [to_untyped(grad_wrapper)
for grad_wrapper
in g_input]
921 for grad_def_str
in grad_defs_str:
922 grad_def = caffe2_pb2.OperatorDef()
923 grad_def.ParseFromString(grad_def_str)
924 grad_defs.append(grad_def)
925 return grad_defs, g_input
    @classmethod
    def GetGradientForOp(cls, op, g_output):
        try:
            gradient_ops, g_input = cls._GetGradientForOpCC(op, g_output)
        except Exception as e:
            ...
                raise Exception(
                    "No gradient registered for {}. ".format(op.type) +
                    "Exception from creating the gradient op: {}.".format(e))
        if gradient_ops is None:
            ...
        if type(gradient_ops) is not list:
            gradient_ops = [gradient_ops]
        return gradient_ops, g_input
950 """Gets the backward pass for the list of operators. 953 operators: a list of operators constituting the forward pass. 954 ys: a list or a dictionary specifying what blobs we want to compute 955 derivatives of. If the input is a list, we will automatically 956 generate their gradients with all-one values; if the input is a 957 dictionary, for any dictionary entries that are not None, we'll 958 take the corresponding blobs as their gradients; for all those 959 that are None, we will auto-fill them with 1. 961 gradient_ops: a list of gradient operators to run. 962 all_input_to_grads: a map from input to their corresponding 966 return ir.GetBackwardPass(ys)
971 Given a net, return a structure containing the version of each input and 972 output blob used by each operator. 975 net: either a Net or a NetDef 976 blob_versions: (optional) map with current version number for given 977 blob names. If not provided or blob not found, start 980 Tuple (ssa, blob_versions) 981 ssa: list of tuples (versioned_inputs, versioned_outputs) 982 for each op in the net. A versioned input is a tuple 983 (blob_name, version). 984 blob_versions: updated map with latest version of each blob found in 987 proto = net.Proto()
if isinstance(net, Net)
else net
988 assert isinstance(proto, caffe2_pb2.NetDef)
989 if blob_versions
is None:
991 if isinstance(net, list):
992 return [
get_ssa(n, blob_versions)
for n
in net], blob_versions
993 for i
in proto.external_input:
994 if i
not in blob_versions:
995 blob_versions[str(i)] = 0
998 if not proto.external_input:
1000 if i
not in blob_versions:
1001 blob_versions[i] = 0
1002 inputs = [(str(i), blob_versions.get(str(i), 0))
for i
in op.input]
1004 blob_versions[str(o)] = blob_versions.get(str(o), 0) + 1
1005 outputs = [(str(o), blob_versions[str(o)])
for o
in op.output]
1006 ssa.append((inputs, outputs))
1007 return ssa, blob_versions
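Illustrative sketch of get_ssa() on a two-op NetDef (not from the original file); blob "Y" is written twice, so its version goes 1 -> 2:

    netdef = caffe2_pb2.NetDef()
    netdef.op.extend([
        CreateOperator("Relu", ["X"], ["Y"]),
        CreateOperator("Relu", ["Y"], ["Y"]),
    ])
    ssa, blob_versions = get_ssa(netdef)
    # ssa == [([("X", 0)], [("Y", 1)]), ([("Y", 1)], [("Y", 2)])]
    # blob_versions == {"X": 0, "Y": 2}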
def get_undefined_blobs(ssa):
    """
    Given a ssa in the format produced by get_ssa(), return a set of blobs that
    are used before they are defined, which corresponds to inputs at version 0.
    """
    undef_blobs = set()
    for inputs, _outputs in ssa:
        undef_blobs |= set(name for (name, ver) in inputs if ver == 0)
    return undef_blobs
def get_output_producers(ssa):
    """
    Given a ssa in the format produced by get_ssa(), returns a map from
    versioned blob into the operator index that produces that version of
    the blob. A versioned blob is a tuple (blob_name, version).
    """
    producers = {}
    for i, (_inputs, outputs) in enumerate(ssa):
        ...
    return producers
def get_op_ids_in_path(ssa, blob_versions, inputs, outputs):
    """
    Given a ssa and blob_versions as produced by get_ssa(), returns the list
    of op indices that are necessary in order to generate the blobs in
    `outputs`, given blobs in `inputs`.
    Consider that the `inputs` are given in their latest version.
    """
    inputs_set = set((str(i), blob_versions[str(i)]) for i in inputs)
    producers = get_output_producers(ssa)
    queue = [(str(o), blob_versions[str(o)]) for o in outputs]
    used_op_ids = set()
    while len(queue) > 0:
        o = queue.pop()
        if (o not in inputs_set) and (o in producers):
            op_id = producers[o]
            if op_id not in used_op_ids:
                used_op_ids |= {op_id}
                inputs, _ = ssa[op_id]
                queue.extend(inputs)
    return sorted(used_op_ids)
def clone_and_bind_net(net, name, prefix, blob_remap=None, inputs=None,
                       keep_schema=True):
    """
    Clone the given Net, binding its input schema to the given `inputs` record.
    Blob names defined by the net are prepended with the given `prefix`.

    Args:
        net:         the net to clone
        name:        the name of the new net
        prefix:      the prefix to prepend to local blobs
        blob_remap:  (optional) dict with additional blob name remapping.
        inputs:      (optional) input record that will provide actual input
                     values for the cloned net. Must be compatible with the
                     net's input schema or be a strict superset of it
        keep_schema: by default (True), the original schema will be kept and
                     remapped accordingly. otherwise, the schema will be set as
                     inputs or left empty if inputs is not given.
    Returns:
        Tuple (cloned_net, blob_remap)
        clone_net:  the cloned Net
        blob_remap: a map from original blob names into remapped blob names
    """
    from caffe2.python import schema
    assert isinstance(net, Net)
    if blob_remap is None:
        blob_remap = {}
    if inputs is not None:
        ...
        original = net.input_record()
        assert original is not None
        diff = set(original.field_names()) - set(inputs.field_names())
        assert len(diff) == 0, \
            "Schemas don't match, extra fields {} found in the net".format(diff)
        original_mapping = dict(zip(original.field_names(),
                                    original.field_blobs()))
        for fn, fb in zip(inputs.field_names(), inputs.field_blobs()):
            if fn in original_mapping:
                blob_remap[str(original_mapping[fn])] = str(fb)
    proto = net.Proto()
    ssa, blob_versions = get_ssa(proto)
    undef_blobs = get_undefined_blobs(ssa)

    for blob in blob_versions.keys():
        if blob in blob_remap:
            continue
        elif blob in undef_blobs:
            blob_remap[blob] = blob
        else:
            blob_remap[blob] = prefix + blob
    cloned_net = net.Clone(name, blob_remap, keep_schema=keep_schema)
    if not keep_schema and inputs:
        cloned_net.set_input_record(inputs)
    return cloned_net, blob_remap
def _get_blob_ref(blob_name_or_ref):
    return (
        blob_name_or_ref if isinstance(blob_name_or_ref, BlobReference)
        else BlobReference(blob_name_or_ref)
    )
class Net(object):
    _net_names_used = set()
    operator_registry_ = {}

    @staticmethod
    def current_prefix():
        from caffe2.python.net_builder import NetBuilder
        builder = NetBuilder.current(required=False)
        return builder.name if builder else ''

    @staticmethod
    def _get_next_net_name(basename):
        name = basename = '/'.join(filter(
            lambda x: x, (Net.current_prefix(), basename)))
        next_idx = 1
        while name in Net._net_names_used:
            name = basename + '_' + str(next_idx)
            next_idx += 1
        Net._net_names_used |= set([name])
        return name
    def __init__(self, name_or_proto):
        """
        Args:
            name_or_proto: If a NetDef is provided, clone it. Otherwise,
                           create an empty net with the given name.
        """
        ...
        if type(name_or_proto) is caffe2_pb2.NetDef:
            proto = name_or_proto
            ...
            self._net = caffe2_pb2.NetDef()
            self._net.CopyFrom(proto)

            existing_outputs = [list(op.output) for op in self._net.op]
            ...
            existing_names = set(
                sum(
                    [list(op.input) for op in self._net.op], []
                ) + sum(
                    existing_outputs, []
                )
            )

            for outs in existing_outputs:
                ...

            prefix_len = len(self._net.name + '_blob_')
            autogen_indices = []
            for s in existing_names:
                if s.startswith(self._net.name + '_blob_'):
                    autogen_indices.append(int(s[prefix_len:]))
            if len(autogen_indices):
                ...
            name = self._net.name
        else:
            name = name_or_proto
            self._net = caffe2_pb2.NetDef()
            ...
        self._net.name = Net._get_next_net_name(name)
    def AppendNet(self, net):
        assert isinstance(net, Net)
        ...
        self.Proto().external_input.extend(
            [i for i in net.Proto().external_input
                if i not in self.Proto().external_input])
        self.Proto().external_output.extend(
            [o for o in net.Proto().external_output
                if o not in self.Proto().external_output])
        ...
    def LogInfo(self, *msg_or_blobs):
        for msg_or_blob in msg_or_blobs:
            if not isinstance(msg_or_blob, BlobReference):
                blob = self.GivenTensorStringFill(
                    ...
                    shape=[], values=[msg_or_blob])
            else:
                blob = msg_or_blob
            self.Print(blob, [])
    def add_attribute(self, name, obj):
        """
        Add `obj` to the list of attributes in this net under the given `name`.
        Attributes are user-defined objects and have no pre-defined semantics.
        """
        ...

    def get_attributes(self, name):
        """
        Returns the list of attributes in this net for a given `name`.
        Attributes are user-defined objects added with `add_attribute`.
        """
        ...

    def set_rand_seed(self, seed=100, sequence_seed=True, seed_on_op_def=False):
        """
        Adds a random seed to each op in the net.
        If sequence_seed is set, the i-th op has rand_seed=`seed + i`
        If seed_on_op_def is set, the op rand_seed=hash(str(op))
        sequence_seed and seed_on_op_def cannot both be set to True.
        """
        assert not (sequence_seed and seed_on_op_def), (
            'sequence_seed and seed_on_op_def cannot be both set to True.')
        for i, op in enumerate(self.Proto().op):
            if sequence_seed:
                curr_seed = seed + i
            elif seed_on_op_def:
                curr_seed = hash(str(op) + str(seed)) % np.iinfo(np.uint32).max
            ...
            op.device_option.random_seed = curr_seed
    def Name(self):
        return self._net.name

    def Const(self, array, blob_out=None, dtype=None):
        if isinstance(array, bool):
            return self.ConstantFill(
                ...
                dtype=DataType.BOOL,
                ...)

        if dtype is None:
            array = np.array(array)
        else:
            array = np.array(array, dtype=dtype)

        def do_set(operator):
            return operator(
                ...
                values=array.flatten().tolist())

        if array.dtype == np.int32:
            return do_set(self.GivenTensorIntFill)
        elif array.dtype == np.int64:
            return do_set(self.GivenTensorInt64Fill)
        elif array.dtype == np.str:
            return do_set(self.GivenTensorStringFill)
        else:
            return do_set(self.GivenTensorFill)
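Illustrative usage (not from the original file; blob names are made up). Const() picks the GivenTensor*Fill variant that matches the array dtype:

    import numpy as np

    net = Net("const_example")
    w = net.Const(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),
                  blob_out="w")                                    # GivenTensorFill
    idx = net.Const(np.array([0, 1], dtype=np.int32), blob_out="idx")  # GivenTensorIntFill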
    def BlobIsDefined(self, blob):
        """
        Returns true if the given BlobReference is produced as output of
        an operator in this net, or if it is provided as an external input.
        """
        ...

    def UsesBlob(self, blob):
        """
        Returns true iff the given BlobReference is used by any operator
        in this net, or if it is one of the external inputs of the net.
        """
        blob_name = str(blob)
        for op in self._net.op:
            for input in op.input:
                if input == blob_name:
                    return True
        ...

    def GetBlobRef(self, blob_name):
        """
        Given the name of a blob produced by this net, return a BlobReference
        to it. If the blob is not produced by any op in this net,
        raises a KeyError.
        """
        blob_name = str(blob_name)
        if not self.BlobIsDefined(blob_name):
            raise KeyError('Net does not define blob %s' % blob_name)
        return BlobReference(blob_name, self)
    def Clone(self, name, blob_remap=None, op_id_mask=None, remap_funcs=None,
              keep_schema=True):
        """
        Args:
            name:        name of the cloned net
            blob_remap:  optional map with list of blob names to replace
            op_id_mask:  optional list of operator indices to include in
                         the cloned net. If not provided, all ops are included.
        """
        if remap_funcs is None:
            remap_funcs = {}
        proto = self._net
        new_proto = caffe2_pb2.NetDef()
        new_proto.CopyFrom(proto)
        new_proto.name = name

        if blob_remap is None:
            blob_remap = {}
        if op_id_mask is None:
            op_id_mask = range(0, len(proto.op))

        def get_remapped_str(blob):
            blob_str = str(blob)
            return str(blob_remap.get(blob_str, blob_str))

        def remap_list(proto_list):
            new_list = [get_remapped_str(b) for b in proto_list]
            del proto_list[:]
            proto_list.extend(new_list)

        def remap_op(op):
            new_op = caffe2_pb2.OperatorDef()
            new_op.CopyFrom(op)
            remap_list(new_op.input)
            remap_list(new_op.output)
            if new_op.type in remap_funcs:
                remap_funcs[new_op.type](new_op, (name + '/') if name else '')
            return new_op

        new_proto.op.extend([remap_op(proto.op[op_id]) for op_id in op_id_mask])
        remap_list(new_proto.external_input)
        remap_list(new_proto.external_output)
        new_net = Net(new_proto)
        ...
        from caffe2.python import schema
        ...
    def ClonePartial(self, name, inputs, outputs, remap_funcs=None):
        """
        Clone this net, including only ops that are necessary in order to
        compute `outputs` given `inputs`. Return references to the cloned
        outputs. Internal blobs (blobs that are produced and consumed inside
        the net but not used as outputs) will be remapped to avoid name
        conflict.

        Args:
            name:    the name of the cloned net
            inputs:  map where the keys correspond to BlobReferences in the
                     original net, and the values correspond to external inputs
                     in the partially cloned net. If `inputs` is a list, don't
                     remap input names.
            outputs: outputs to be produced by the cloned net.

        Returns:
            Tuple (new_net, new_outputs)
                new_net: a new Net object.
                new_outputs: list of BlobReferences corresponding to the
                             outputs produced by new_net.
        """
        input_is_pair_list = isinstance(inputs, list) and all(
            isinstance(i, tuple) and len(i) == 2 for i in inputs)
        inputs = (
            inputs if isinstance(inputs, (dict, OrderedDict)) else
            OrderedDict(inputs) if input_is_pair_list else
            OrderedDict(zip(inputs, inputs)))
        for output in outputs:
            ...
        input_names = {str(k): str(v) for k, v in inputs.items()}
        output_names = [str(o) for o in outputs]
        ...
        ssa, blob_versions = get_ssa(proto)
        used_op_ids = get_op_ids_in_path(ssa, blob_versions, inputs, outputs)
        disallowed_op_ids = get_op_ids_in_path(ssa, blob_versions, [], inputs)
        assert len(set(used_op_ids) & set(disallowed_op_ids)) == 0, (
            'Cannot partially clone net: some of the ops required would ' +
            'generate the given input.')

        sub_ssa = [op for i, op in enumerate(ssa) if i in used_op_ids]
        ...
        prefix = (name + '/') if name else ''

        def remap(blob_name):
            if blob_name in input_names:
                return input_names[blob_name]
            elif blob_name in undef_blobs:
                return blob_name
            else:
                return prefix + blob_name

        blob_mapping = {b: remap(b) for b in blob_versions.keys()}
        new_net = self.Clone(name, blob_mapping, used_op_ids, remap_funcs)
        new_in = [
            blob_mapping[i] for i in input_names.keys()] + list(undef_blobs)
        new_out = [blob_mapping[o] for o in output_names]
        del new_net.Proto().external_input[:]
        new_net.Proto().external_input.extend(new_in)
        new_net._external_input_map = set(list(new_in))
        del new_net.Proto().external_output[:]
        new_net.Proto().external_output.extend(new_out)
        return new_net, [new_net.GetBlobRef(o) for o in new_out]
1463 """Return the blob that has not been defined or registered in the 1464 current net. It returns `ScopedBlobReference(prefix)`, if it's valid, 1465 otherwise `ScopedBlobReference(prefix) + '_auto_' + ?`. Different calls 1466 is guaranteed to return blob with different names. 1469 return self.
NextBlob(output_blob_base)
1472 """Return the blob that has not been defined or registered in the 1473 current net. It returns `BlobReference(prefix)`, if it's valid, 1474 otherwise `BlobReference(prefix) + '_auto_' + ?`. Different calls 1475 is guaranteed to return blob with different names.""" 1477 output_blob = output_blob_base
1481 output_blob = output_blob_base +
'_auto_' + str(index)
1488 """Returns the next name to be used, if you do not want to explicitly 1489 name your blob. [Deprecated, use NextBlob, NextScopedBlob instead]""" 1491 output_name_base = self.
_net.name +
'/' + prefix
1492 output_name = output_name_base
1493 if output_id
is not None:
1494 output_name +=
':' + str(output_id)
1497 output_name = output_name_base +
'_' + str(index)
1498 if output_id
is not None:
1499 output_name +=
':' + str(output_id)
1504 return str(output_name)
    def _ExtendOps(self, new_ops):
        self._net.op.extend(new_ops)
        for op in new_ops:
            self._op_outputs.update([str(o) for o in op.output])

    def _CheckLookupTables(self):
        """
        Called from unit tests to validate the internal lookup tables
        match the protobuf contents.
        """
        test_op_outputs = set()
        for op in self._net.op:
            for o in op.output:
                test_op_outputs.add(o)

        test_external_inp = set()
        for inp in self._net.external_input:
            test_external_inp.add(inp)

        assert test_op_outputs.difference(self._op_outputs) == set()
        ...
    def _InvalidateLookupTables(self):
        ...

    def _RecreateLookupTables(self):
        ...
        for op in self._net.op:
            ...
        for inp in self._net.external_input:
            ...
1544 """Add the gradient for operators in the net. 1547 ys: a list or a dictionary specifying what blobs we want to compute 1548 derivatives of. If the input is a list, we will automatically 1549 generate their gradients with all-one values; if the input is a 1550 dictionary, for any dictionary entries that are not None, we will 1551 take the corresponding blobs as their gradients; for all those 1552 that are None, we will auto-fill them with 1. 1553 skip: skips the first n operators. This is provided mainly because a 1554 lot of nets may use the first few operators for data generation 1555 like stuff which really do not need to have gradients. 1558 returns a map from the blob name in the input network to a blob 1559 containing gradient or a GradientSlice in case of sparse gradient 1561 Currently, this is hard-coded for float operators if there are branches 1562 (i.e. a blob is used as input to multiple operators). This is because 1563 the gradient accumulation (Sum) is float only right now. 1566 grad_ops, input_to_grad = GradientRegistry.GetBackwardPass(
1567 self.
_net.op[skip:], ys)
1575 return input_to_grad
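Illustrative end-to-end sketch (not from the original file; it follows the toy-regression pattern from the Caffe2 tutorials and assumes "X", "W", "b" and "label" are fed beforehand):

    train = Net("train")
    Y = train.FC(["X", "W", "b"], "Y")
    dist = train.SquaredL2Distance([Y, "label"], "dist")
    loss = dist.AveragedLoss([], "loss")
    input_to_grad = train.AddGradientOperators([loss])
    # e.g. input_to_grad["W"] names the blob holding dloss/dW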
    def AddExternalInput(self, *inputs):
        assert len(inputs) > 0
        refs = []
        for input in inputs:
            input_name = str(input)
            assert input_name not in self._external_input_map, (
                'Net already contains an input named %s' % input_name)
        for input in inputs:
            input_name = str(input)
            self._net.external_input.extend([input_name])
            ...
            refs.append(_get_blob_ref(input_name))

        return refs[0] if len(refs) == 1 else refs
    def AddExternalOutput(self, *outputs):
        for output in outputs:
            assert isinstance(output, BlobReference)
            ...
        for output in outputs:
            self.Proto().external_output.extend([str(output)])

    def AddScopedExternalInputs(self, *inputs):
        ...

    def AddScopedExternalOutputs(self, *outputs):
        ...
    def external_inputs(self):
        return map(_get_blob_ref, self._net.external_input)

    def external_outputs(self):
        return map(_get_blob_ref, self._net.external_output)
    def set_input_record(self, input_record):
        from caffe2.python import schema
        assert self._input_record is None, (
            'Input schema cannot be reset')
        if not input_record.has_blobs():
            with NameScope(self.Name()):
                ...
        else:
            ...
            for blob in input_record.field_blobs():
                ...

    def set_output_record(self, record):
        assert self._output_record is None, (
            'Output record cannot be reset')
        for blob in record.field_blobs():
            ...
        for blob in record.field_blobs():
            ...

    def AppendOutputRecordField(self, field_name, record):
        from caffe2.python import schema
        assert self._output_record is not None, (
            'Tried to append to missing output record'
        )
        for blob in record.field_blobs():
            ...
        for blob in record.field_blobs():
            ...
        self._output_record = self._output_record + schema.Struct(
            (field_name, record)
        )
    def input_record(self):
        ...

    def output_record(self):
        ...

    def AddExternalInputs(self, *inputs):
        ...

    def AddExternalOutputs(self, *outputs):
        ...
    def DeduplicateGradientSlices(self, g, aggregator='sum'):
        assert isinstance(g, GradientSlice)
        unique, remapping = self.Unique([g.indices], 2, engine='SparseHash')
        if aggregator.lower() == 'sum':
            new_g = self.UnsortedSegmentSum([g.values, remapping], 1)
        elif aggregator.lower() == 'mean':
            new_g = self.UnsortedSegmentMean([g.values, remapping], 1)
        else:
            raise ValueError('{} is not supported'.format(aggregator))
        return GradientSlice(indices=unique, values=new_g)
1677 """A convenient function to run everything on the GPU.""" 1678 device_option = caffe2_pb2.DeviceOption()
1679 device_option.device_type = caffe2_pb2.CUDA
1680 device_option.cuda_gpu_id = gpu_id
1681 self.
_net.device_option.CopyFrom(device_option)
1683 for op
in self.
_net.op:
    def _CreateAndAddToSelf(self, op_type, inputs, outputs=None, **kwargs):
        """A helper function to create an operator and add it to self.
        """
        inputs = _RectifyInputOutput(inputs)
        for input in inputs:
            if not self.BlobIsDefined(input):
                assert input.Net() != self
                ...
        if outputs is None:
            outputs = self.NextName(prefix=op_type)
        elif type(outputs) is int:
            outputs = [
                self.NextName(prefix=op_type, output_id=i)
                for i in range(outputs)]
        outputs = _RectifyInputOutput(outputs, net=self)
        ...
        if len(op.output) == 0:
            ...
        elif len(op.output) == 1:
            ...
        else:
            return tuple(
                BlobReference(str(o), self) for o in op.output)
    def __getattr__(self, op_type):
        if op_type.startswith('__'):
            raise AttributeError('Attribute {} not found.'.format(op_type))
        if not IsOperator(op_type) and not IsOperatorWithEngine(op_type, "CUDNN"):
            raise RuntimeError(
                'Method ' + op_type + ' is not a registered operator.' +
                ' Did you mean: [' +
                ",".join(workspace.C.nearby_opnames(op_type)) + ']')
        return lambda *args, **kwargs: self._CreateAndAddToSelf(
            op_type, *args, **kwargs)
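Illustrative sketch (not from the original file): every net.<OpType>(...) call above goes through __getattr__ and appends an OperatorDef; the net is only executed when handed to the workspace (API names assumed from caffe2.python.workspace):

    import numpy as np
    from caffe2.python import workspace

    net = Net("relu_net")
    net.AddExternalInput("X")
    net.Relu("X", "Y")

    workspace.FeedBlob("X", np.random.randn(4, 4).astype(np.float32))
    workspace.RunNetOnce(net)
    print(workspace.FetchBlob("Y"))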
    def Python(self, f, grad_f=None, pass_workspace=False):
        """
        Registers and returns a python operator.

        `f` and `f_grad` can be one of the following:
            - a function with signature (inputs, outputs), where inputs and
              outputs are a list of CPUTensor objects. This function will be
              called from C++ every time the operator is executed.
            - a tuple (func, args, kwargs), where `func` is a callable, args is
              an argument list, and kwargs is a dict. The call:
                  f = func(*args, **kwargs)
              will be performed locally at node initialization time, on all of
              the nodes of the job, returning `f`, a callable that will be used
              as the python operator function to be called during Net execution.
              This is to be used when using python operator in a distributed
              context, and allows one to create and keep local python state
              across calls to the operator.

        If `pass_workspace` is True, the signature is changed to
        (inputs, outputs, workspace) where `workspace` is the workspace the op
        is going to run on. This is potentially dangerous (as the op can
        manipulate the workspace directly), so use at your own risk.
        """
        assert(IsOperator('Python'))
        if isinstance(f, tuple) or isinstance(grad_f, tuple):
            ...
        else:
            registry = _RegisterPythonImpl
        token = registry(f, grad_f, pass_workspace=pass_workspace)
        return lambda *args, **kwargs: self._CreateAndAddToSelf(
            'Python', token=token, *args, **kwargs)
def get_net_name(netlike):
    if isinstance(netlike, Net):
        return netlike.Proto().name
    elif isinstance(netlike, caffe2_pb2.NetDef):
        return netlike.name
    else:
        return netlike


def output_to_list(op_output):
    """
    Ensures that the output of an operator is a list.
    Use when an operator has a variable number of outputs, but a list of
    outputs is desired even when number of outputs is 1.

    Args:
        op_output: Either a BlobReference or an iterable of BlobReferences.

    Returns:
        A list of BlobReferences.
    """
    assert type(op_output) in (list, tuple, BlobReference)
    return (
        [op_output]
        if isinstance(op_output, BlobReference) else list(op_output))
def _add_net_to_dict(net_dict, net):
    name = get_net_name(net)
    if name in net_dict:
        assert net_dict[name] is None or net == net_dict[name], (
            'Different nets with same name: ' + name)
        return False
    else:
        net_dict[name] = net if isinstance(net, Net) else None
        return True


class ExecutionStep(object):
    _step_names_used = set()
    @staticmethod
    def _get_next_step_name(basename):
        name = basename
        next_idx = 1
        while name in ExecutionStep._step_names_used:
            name = basename + '_' + str(next_idx)
            next_idx += 1
        ExecutionStep._step_names_used |= set([name])
        return name
    def __init__(self, name, nets=None, num_iter=None):
        self._step = caffe2_pb2.ExecutionStep()
        self._step.name = name or ExecutionStep._get_next_step_name('step')
        ...
        if nets is not None:
            if type(nets) is Net:
                nets = [nets]
            for net in nets:
                if _add_net_to_dict(self._net_dict, net):
                    self._step.network.extend([get_net_name(net)])
        if num_iter is not None:
            self._step.num_iter = num_iter
    def get_net(self, name):
        return self._net_dict[name]

    def Name(self):
        return self._step.name

    def __str__(self):
        return self._step.name

    def _assert_can_mutate(self):
        assert not self._is_used, (
            'Cannot mutate a step that has already been added to a plan/step.')

    def _notify_is_used(self):
        ...

    def HasNets(self):
        return self._step.network is not None and (
            len(self._step.network) > 0)

    def HasSubsteps(self):
        return self._step.substep is not None and (
            len(self._step.substep) > 0)
    def SetIter(self, num_iter):
        self._assert_can_mutate()
        self._step.num_iter = num_iter

    def SetOnlyOnce(self, only_once):
        self._assert_can_mutate()
        self._step.only_once = only_once

    def SetShouldStopBlob(self, should_stop_blob):
        assert isinstance(should_stop_blob, BlobReference), (
            "expects BlobReference here, got {}".format(type(should_stop_blob)))
        self._assert_can_mutate()
        self._step.should_stop_blob = str(should_stop_blob)

    def RunEveryMillis(self, interval):
        """
        Run this step every interval milliseconds, as long as its
        siblings are still running. It is guaranteed that, after all
        siblings finish, this step will run at least once.

        This property is ignored for top-level ExecutionSteps.
        """
        self._step.run_every_ms = interval

    def SetReportNet(self, report_net, report_interval):
        """ DEPRECATED. Use RunEveryMillis instead. """
        _add_net_to_dict(self._net_dict, report_net)
        self._step.report_net = get_net_name(report_net)
        self._step.report_interval = report_interval
    def AddSubstep(self, substep):
        self._assert_can_mutate()
        assert not self.HasNets(), 'Cannot have both network and substeps.'
        if isinstance(substep, ExecutionStep):
            substep._notify_is_used()
            if not substep.HasNets() and not substep.HasSubsteps():
                return self
            for net in substep.Nets():
                ...
            ...
            proto = substep.Proto()
        else:
            proto = substep
        self._step.substep.add().CopyFrom(proto)
        ...
1910 assert not self.
HasNets(),
'Cannot have both network and substeps.' 1911 self.
_step.concurrent_substeps = concurrent_substeps
1913 def AddNet(self, net):
1915 assert not self.
HasSubsteps(),
'Cannot have both network and substeps.' 1916 assert isinstance(net, Net)
1918 self.
_step.network.extend([get_net_name(net)])
1923 Return the list of all attributes under the given `name`, present in 1924 all of the nets used in this execution step and its children. 1928 objs += net.get_attributes(name)
1932 def add_nets_in_order(step, net_list):
1933 proto = step.Proto()
1934 for substep
in step.Substeps():
1935 add_nets_in_order(substep, net_list)
1936 for net
in proto.network:
1937 if net
not in net_list:
1938 net_list.append(net)
1942 if proto.report_net
and proto.report_net
not in net_list:
1943 net_list.append(proto.report_net)
class Plan(object):

    def __init__(self, name_or_step):
        self._plan = caffe2_pb2.PlanDef()
        ...
        if isinstance(name_or_step, ExecutionStep):
            self._plan.name = name_or_step.Name()
            ...
        elif isinstance(name_or_step, basestring):
            self._plan.name = name_or_step
        else:
            raise ValueError('name_or_step must be a string or ExecutionStep')

    def __str__(self):
        return self._plan.name
    def AddNets(self, nets):
        for net in nets:
            if _add_net_to_dict(self._net_dict, net):
                assert isinstance(net, Net)
                self._plan.network.add().CopyFrom(net.Proto())

    def AddStep(self, step):
        assert isinstance(step, ExecutionStep)
        step._notify_is_used()
        if not step.HasNets() and not step.HasSubsteps():
            return
        self._plan.execution_step.add().CopyFrom(step.Proto())
        net_list = []
        add_nets_in_order(step, net_list)
        self.AddNets([step.get_net(n) for n in net_list])

    def get_all_attributes(self, name):
        """
        Return the list of all attributes under the given `name`, present in
        all of the nets used in this plan.
        """
        ...
            objs += net.get_attributes(name)
        ...
def to_execution_step(step_or_nets, default_name=None):
    from caffe2.python.net_builder import NetBuilder
    if isinstance(step_or_nets, ExecutionStep):
        ...

    stop_blob = None
    if not default_name and hasattr(step_or_nets, 'name'):
        default_name = step_or_nets.name
    if isinstance(step_or_nets, NetBuilder):
        stop_blob = step_or_nets._stop_blob
        step_or_nets = step_or_nets.get()
    return execution_step(
        default_name, step_or_nets, should_stop_blob=stop_blob)
def execution_step(default_name,
                   steps_or_nets,
                   num_iter=None,
                   report_net=None,
                   report_interval=None,
                   concurrent_substeps=None,
                   should_stop_blob=None,
                   only_once=None):
    """
    Helper for creating an ExecutionStep.
    - steps_or_nets can be:
        ...
      - list<ExecutionStep>
    - should_stop_blob is either None or a scalar boolean blob.
      - This blob is checked AFTER every substeps/subnets.
      - If specified and true, then this step will return immediately.
      - Be sure to handle race conditions if setting from concurrent threads.
    - if no should_stop_blob or num_iter is provided, defaults to num_iter=1
    """
    assert should_stop_blob is None or num_iter is None, (
        'Cannot set both should_stop_blob and num_iter.')
    if should_stop_blob is None and num_iter is None:
        num_iter = 1

    step = ExecutionStep(default_name)
    if should_stop_blob is not None:
        step.SetShouldStopBlob(should_stop_blob)
    if num_iter is not None:
        step.SetIter(num_iter)
    if only_once is not None:
        step.SetOnlyOnce(only_once)
    if concurrent_substeps is not None:
        step.SetConcurrentSubsteps(concurrent_substeps)
    if report_net is not None:
        assert report_interval is not None
        step.SetReportNet(report_net, report_interval)

    if isinstance(steps_or_nets, ExecutionStep):
        step.AddSubstep(steps_or_nets)
    elif isinstance(steps_or_nets, Net):
        step.AddNet(steps_or_nets)
    elif isinstance(steps_or_nets, list):
        if all(isinstance(x, Net) for x in steps_or_nets):
            map(step.AddNet, steps_or_nets)
        else:
            map(step.AddSubstep, map(to_execution_step, steps_or_nets))
    elif steps_or_nets:
        raise ValueError(
            'steps_or_nets must be a step, a net, or a list of nets or steps.')
    return step
2067 """Same as execution_step() except that the step name is scoped.""" 2068 default_name =
ScopedName(name)
if name
else name
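Illustrative sketch (not from the original file; workspace.RunPlan is assumed from caffe2.python.workspace): wrap a net in an ExecutionStep that runs 10 iterations, add it to a Plan, and run the plan.

    from caffe2.python import workspace

    train = Net("train_loop")
    train.ConstantFill([], ["ones"], shape=[1], value=1.0)

    step = execution_step("train_step", train, num_iter=10)
    plan = Plan("train_plan")
    plan.AddStep(step)
    workspace.RunPlan(plan)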
gradient_registry_ (dictionary)
def GetBackwardPass(cls, operators, ys)
def execution_step(default_name, steps_or_nets, num_iter=None, report_net=None, report_interval=None, concurrent_substeps=None, should_stop_blob=None, only_once=None)
def AddExternalInput(self, *inputs)
def output_to_list(op_output)
def _CreateAndAddToNet(self, op_type, inputs=None, *args, **kwargs)
def CreatePythonOperator(f, inputs, outputs, grad_f=None, pass_workspace=False, *args, **kwargs)
def _GetGradientForOpCC(cls, op_def, g_output)
def NewRecord(net, schema)
def external_inputs(self)
def ScopedBlobReference(name, *args, **kwargs)
def DoGradientAccumulation(self, fwd_op_idx)
def get_all_attributes(self, name)
def _RecreateLookupTables(self)
def _DisambiguateGradOpOutput(self, grad_op, idx, cnt)
def scoped_execution_step(name, *args, **kwargs)
def _MakeSparseSumOps(self, generators, out_base_name)
def _MakeSumOps(self, input_name, input_version)
def MakeArgument(key, value)
def NextScopedBlob(self, prefix='unnamed')
def _SetSumOpsDeviceOption(self, sum_ops, generators)
def BuildGradientGenerators(self, fwd_op_idx, gradient_ops, g_output, g_input)
def AppendSparseGenerators(self, sparse_generators)
def _assert_can_mutate(self)
def _GenerateGradientsForForwardOp(self, forward_op_idx, input_to_grad)
def AddExternalOutput(self, *outputs)
def _CreateAndAddToSelf(self, op_type, inputs, outputs=None, **kwargs)
def CheckGradientOperatorInput(self, grad_op_input, g_output, fwd_op_idx, locally_generated_blobs)
def get_undefined_blobs(ssa)
def RunAllOnGPU(self, gpu_id=0, use_cudnn=False)
def get_ssa(net, blob_versions=None)
def Clone(self, name, blob_remap=None, op_id_mask=None, remap_funcs=None, keep_schema=True)
def NextBlob(self, prefix='unnamed')
def GetBlobRef(self, blob_name)
def _VerifyGradientGenerators(self, generator)
def GetIndexFromGradientList(g_list, name)
def __getattr__(self, op_type)
def _ExtendOps(self, new_ops)
def _MakeDenseSumOps(self, generators, out_base_name)
def NextName(self, prefix=None, output_id=None)
def set_rand_seed(self, seed=100, sequence_seed=True, seed_on_op_def=False)
def clone_and_bind_net(net, name, prefix, blob_remap=None, inputs=None, keep_schema=True)
def RunEveryMillis(self, interval)
def BlobIsDefined(self, blob)
def _GetInitGradients(self, ys)
def CreateOperator(operator_type, inputs, outputs, name='', control_input=None, device_option=None, arg=None, engine=None, **kwargs)
def SetReportNet(self, report_net, report_interval)
def get_all_attributes(self, name)
def RegisterGradient(cls, op_type)
def GetBackwardPass(self, ys)
def ClonePartial(self, name, inputs, outputs, remap_funcs=None)
def __init__(self, name, net=None)
def from_blob_list(schema, values)
def AddGradientOperators(self, ys, skip=0)
def add_attribute(self, name, obj)
def get_output_producers(ssa)
def RunOperatorImmediate(op)
def Python(self, f, grad_f=None, pass_workspace=False)
def _InvalidateLookupTables(self)
def get_op_ids_in_path(ssa, blob_versions, inputs, outputs)
def _GetSumOpOutputName(self, generator, input_name)
def worker_init_func(func)
def _CheckSumOpsConflict(self, out_base_name, g)
def get_attributes(self, name)
def __init__(self, name_or_proto)