Caffe2 - Python API
A deep learning, cross-platform ML framework
net_printer.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.proto.caffe2_pb2 import OperatorDef
from caffe2.python.checkpoint import Job
from caffe2.python.core import Net, ExecutionStep, Plan
from caffe2.python.task import Task, TaskGroup, WorkspaceType, TaskOutput
from collections import defaultdict
from contextlib import contextmanager
from copy import copy

class Visitor(object):
    @classmethod
    def register(cls, Type):
        if not hasattr(cls, 'visitors'):
            cls.visitors = []

        def _register(func):
            cls.visitors.append((Type, func))
            return func

        return _register

    def __call__(self, obj, *args, **kwargs):
        if obj is None:
            return
        for Type, func in self.__class__.visitors:
            if isinstance(obj, Type):
                return func(self, obj, *args, **kwargs)
        raise TypeError('%s: unsupported object type: %s' % (
            self.__class__.__name__, type(obj)))

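# The register/__call__ pair above implements single dispatch on runtime
# type: each Visitor subclass accumulates (Type, func) pairs, and calling an
# instance scans them for the first isinstance match. A minimal sketch of
# the pattern (hypothetical names, not part of this module):
#
#   class Lister(Visitor):
#       pass
#
#   @Lister.register(list)
#   def visit_list(lister, xs):
#       return len(xs)
#
#   Lister()([1, 2, 3])  # -> 3; unregistered types raise TypeError
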

class Analyzer(Visitor):
    PREFIXES_TO_IGNORE = {'distributed_ctx_init'}

    def __init__(self):
        self.workspaces = defaultdict(lambda: defaultdict(lambda: 0))
        self.workspace_ctx = []

    @property
    def workspace(self):
        return self.workspace_ctx[-1]

    @contextmanager
    def set_workspace(self, node=None, ws=None, do_copy=False):
        if ws is None:
            ws = self.workspaces[str(node)] if node is not None else self.workspace
        if do_copy:
            ws = copy(ws)
        self.workspace_ctx.append(ws)
        yield ws
        del self.workspace_ctx[-1]

    def define_blob(self, blob):
        self.workspace[blob] += 1

    def need_blob(self, blob):
        if any(blob.startswith(p) for p in Analyzer.PREFIXES_TO_IGNORE):
            return
        assert blob in self.workspace, 'Blob undefined: %s' % blob

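# The workspace tracked here is just a mapping from blob name to definition
# count; set_workspace stacks one workspace per execution scope. Sketch of
# direct use (hypothetical node name):
#
#   a = Analyzer()
#   with a.set_workspace(node='trainer_0'):
#       a.define_blob('x')
#       a.need_blob('x')   # ok
#       a.need_blob('y')   # AssertionError: Blob undefined: y
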
@Analyzer.register(OperatorDef)
def analyze_op(analyzer, op):
    # explicit loops rather than map(): under Python 3 a bare map() is lazy,
    # so the visitor calls would never actually run
    for blob in op.input:
        analyzer.need_blob(blob)
    for blob in op.output:
        analyzer.define_blob(blob)


@Analyzer.register(Net)
def analyze_net(analyzer, net):
    for op in net.Proto().op:
        analyzer(op)


@Analyzer.register(ExecutionStep)
def analyze_step(analyzer, step):
    proto = step.Proto()
    if proto.report_net:
        with analyzer.set_workspace(do_copy=True):
            analyzer(step.get_net(proto.report_net))
    all_new_blobs = set()
    substeps = step.Substeps() + [step.get_net(n) for n in proto.network]
    for substep in substeps:
        with analyzer.set_workspace(do_copy=proto.concurrent_substeps) as ws_in:
            analyzer(substep)
            if proto.should_stop_blob:
                analyzer.need_blob(proto.should_stop_blob)
        if proto.concurrent_substeps:
            # blobs created inside the substep's copied workspace are new
            new_blobs = set(ws_in.keys()) - set(analyzer.workspace.keys())
            assert len(all_new_blobs & new_blobs) == 0, (
                'Error: Blobs created by multiple parallel steps: %s' % (
                    ', '.join(all_new_blobs & new_blobs)))
            all_new_blobs |= new_blobs
    for blob in all_new_blobs:
        analyzer.define_blob(blob)

@Analyzer.register(Task)
def analyze_task(analyzer, task):
    # check that our plan protobuf is not too large (limit of 64Mb)
    step = task.get_step()
    plan = Plan(task.node)
    plan.AddStep(step)
    proto_len = len(plan.Proto().SerializeToString())
    assert proto_len < 2 ** 26, (
        'Due to a protobuf limitation, serialized tasks must be smaller '
        'than 64Mb, but this task has %d bytes.' % proto_len)

    is_private = task.workspace_type() != WorkspaceType.GLOBAL
    with analyzer.set_workspace(do_copy=is_private):
        analyzer(step)

@Analyzer.register(TaskGroup)
def analyze_task_group(analyzer, tg):
    for task in tg.tasks_by_node().tasks():
        with analyzer.set_workspace(node=task.node):
            analyzer(task)


@Analyzer.register(Job)
def analyze_job(analyzer, job):
    analyzer(job.init_group)
    analyzer(job.epoch_group)


def analyze(obj):
    """
    Given a Job, visits all the execution steps, making sure that:
    - no undefined blobs will be found during execution
    - no blob with the same name is defined in concurrent steps
    """
    Analyzer()(obj)

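# Typical use (sketch; job is a hypothetical checkpoint.Job built elsewhere):
#
#   analyze(job)   # raises AssertionError such as 'Blob undefined: x'
#                  # if any step reads a blob no earlier step defined
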
class Text(object):
    def __init__(self):
        self._indent = 0
        self._lines_in_context = [0]
        self.lines = []

    @contextmanager
    def context(self, text):
        if text is not None:
            self.add('with %s:' % text)
            self._indent += 4
            self._lines_in_context.append(0)
        yield
        if text is not None:
            if self._lines_in_context[-1] == 0:
                self.add('pass')
            self._indent -= 4
            del self._lines_in_context[-1]

    def add(self, text):
        self._lines_in_context[-1] += 1
        self.lines.append((' ' * self._indent) + text)

    def __str__(self):
        return '\n'.join(self.lines)

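# Text accumulates lines at the current indentation; context() opens a
# Python-like `with ...:` block and emits `pass` if the block stays empty.
# Sketch of the output it produces (names are illustrative):
#
#   t = Text()
#   with t.context('loop(10)'):
#       t.add('net_1()')
#   str(t)  # -> "with loop(10):\n    net_1()"
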

class Printer(Visitor, Text):
    def __init__(self, factor_prefixes=False):
        # note: in this MRO (Printer, Visitor, Text, object),
        # super(Visitor, self).__init__() resolves to Text.__init__
        super(Visitor, self).__init__()
        super(Text, self).__init__()
        self.factor_prefixes = factor_prefixes

def _sanitize_str(s):
    s = str(s)
    return s if len(s) < 64 else (s[:64] + '...<+len=%d>' % (len(s) - 64))


def _arg_val(arg):
    if arg.HasField('f'):
        return str(arg.f)
    if arg.HasField('i'):
        return str(arg.i)
    if arg.HasField('s'):
        return _sanitize_str(arg.s)
    if arg.floats:
        return str(list(arg.floats))
    if arg.ints:
        return str(list(arg.ints))
    if arg.strings:
        return str([_sanitize_str(s) for s in arg.strings])
    return '[]'


def commonprefix(m):
    "Given a list of strings, returns the longest common prefix"
    if not m:
        return ''
    s1 = min(m)
    s2 = max(m)
    for i, c in enumerate(s1):
        if c != s2[i]:
            return s1[:i]
    return s1


def factor_prefix(vals, do_it):
    # materialize the map so len() works under Python 3
    vals = list(map(str, vals))
    prefix = commonprefix(vals) if len(vals) > 1 and do_it else ''
    joined = ', '.join(v[len(prefix):] for v in vals)
    return '%s[%s]' % (prefix, joined) if prefix else joined

def call(op, inputs=None, outputs=None, factor_prefixes=False):
    if not inputs:
        inputs = ''
    else:
        inputs_v = [a for a in inputs if not isinstance(a, tuple)]
        inputs_kv = [a for a in inputs if isinstance(a, tuple)]
        inputs = ', '.join(filter(
            bool,
            [factor_prefix(inputs_v, factor_prefixes)] +
            ['%s=%s' % kv for kv in inputs_kv]))
    call = '%s(%s)' % (op, inputs)
    return call if not outputs else '%s = %s' % (
        factor_prefix(outputs, factor_prefixes), call)

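# call() renders an operator invocation as pseudo-Python, and
# factor_prefixes pulls a shared prefix out of blob lists to shorten the
# output. Illustrative values (not from the module itself):
#
#   call('Sum', ['a_1', 'a_2'], ['out'])
#       -> "out = Sum(a_1, a_2)"
#   call('Sum', ['gpu_0/x', 'gpu_0/y'], ['gpu_0/z'], factor_prefixes=True)
#       -> "gpu_0/z = Sum(gpu_0/[x, y])"
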
@Printer.register(OperatorDef)
def print_op(text, op):
    text.add(call(
        op.type,
        list(op.input) + [(a.name, _arg_val(a)) for a in op.arg],
        op.output,
        factor_prefixes=text.factor_prefixes))


@Printer.register(Net)
def print_net(text, net):
    text.add('# net: %s' % str(net))
    for op in net.Proto().op:
        text(op)

def _get_step_context(step):
    proto = step.Proto()
    if proto.should_stop_blob:
        return call('loop'), False
    if proto.num_iter and proto.num_iter != 1:
        return call('loop', [proto.num_iter]), False
    concurrent = proto.concurrent_substeps and len(step.Substeps()) > 1
    if concurrent:
        return call('parallel'), True
    if proto.report_net:
        return call('run_once'), False
    return None, False

@Printer.register(ExecutionStep)
def print_step(text, step):
    proto = step.Proto()
    step_ctx, do_substep = _get_step_context(step)
    with text.context(step_ctx):
        if proto.report_net:
            with text.context(call('report_net', [proto.report_interval])):
                text(step.get_net(proto.report_net))
        substeps = step.Substeps() + [step.get_net(n) for n in proto.network]
        for substep in substeps:
            if (isinstance(substep, ExecutionStep) and
                    substep.Proto().run_every_ms):
                substep_ctx = call(
                    'reporter',
                    [str(substep), ('interval_ms', substep.Proto().run_every_ms)])
            elif do_substep:
                substep_ctx = call('step', [str(substep)])
            else:
                substep_ctx = None
            with text.context(substep_ctx):
                text(substep)
        if proto.should_stop_blob:
            text.add(call('yield stop_if', [proto.should_stop_blob]))

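# Rendered form (sketch): a step with num_iter=10 over a single net prints as
#
#   with loop(10):
#       # net: my_net
#       out = MyOp(in)
#
# (names here are illustrative; actual output depends on the step's proto)
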
def _print_task_output(x):
    assert isinstance(x, TaskOutput)
    return 'Output[' + ', '.join(map(str, x.names)) + ']'

@Printer.register(Task)
def print_task(text, task):
    outs = ', '.join(map(_print_task_output, task.outputs()))
    context = [('node', task.node), ('name', task.name), ('outputs', outs)]
    with text.context(call('Task', context)):
        text(task.get_step())


@Printer.register(TaskGroup)
def print_task_group(text, tg, header=None):
    with text.context(header or call('TaskGroup')):
        for task in tg.tasks_by_node().tasks():
            text(task)

@Printer.register(Job)
def print_job(text, job):
    text(job.init_group, 'Job.current().init_group')
    text(job.epoch_group, 'Job.current().epoch_group')
    with text.context('Job.current().stop_signals'):
        for out in job.stop_signals:
            text.add(_print_task_output(out))
    text(job.exit_group, 'Job.current().exit_group')

def to_string(obj):
    """
    Given a Net, ExecutionStep, Task, TaskGroup or Job, produces a string
    with a detailed description of the execution steps.
    """
    printer = Printer()
    printer(obj)
    return str(printer)

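# Sketch of use (job is a hypothetical checkpoint.Job built elsewhere):
#
#   print(to_string(job))
#
# which renders nested pseudo-code along the lines of:
#
#   with Job.current().init_group:
#       with Task(node=local, name=..., outputs=...):
#           ...
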
def debug_net(net):
    """
    Given a Net, produce another net that logs info about the operator call
    before each operator execution. Use for debugging purposes.
    """
    assert isinstance(net, Net)
    debug_net = Net(str(net))
    for op in net.Proto().op:
        # use a Printer (a Text that also carries factor_prefixes) so that
        # print_op has the state it expects; arguments are (text, op)
        text = Printer()
        print_op(text, op)
        debug_net.LogInfo(str(text))
        debug_net.Proto().op.extend([op])
    return debug_net
