# Module caffe2.python.net_printer
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.proto.caffe2_pb2 import OperatorDef
from caffe2.python.checkpoint import Job
from caffe2.python.core import Net, ExecutionStep, Plan
from caffe2.python.task import Task, TaskGroup, WorkspaceType, TaskOutput
from collections import defaultdict
from contextlib import contextmanager
from copy import copy


class Visitor(object):
    @classmethod
    def register(cls, Type):
        # Each Visitor subclass keeps its own list of (type, function) pairs.
        if not hasattr(cls, 'visitors'):
            cls.visitors = []

        def _register(func):
            cls.visitors.append((Type, func))
            return func

        return _register

    def __call__(self, obj, *args, **kwargs):
        for Type, func in self.__class__.visitors:
            if isinstance(obj, Type):
                return func(self, obj, *args, **kwargs)
        raise TypeError('%s: unsupported object type: %s' % (
            self.__class__.__name__, type(obj)))


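# Usage sketch (hypothetical names, not part of the original module): a
# subclass gains a dispatch table via `register`, and calling an instance
# dispatches on the argument's type:
#
#     class Summarizer(Visitor):
#         pass
#
#     @Summarizer.register(Net)
#     def summarize_net(visitor, net):
#         print('visiting net: %s' % net)
#
#     Summarizer()(some_net)  # isinstance check routes to summarize_net

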
class Analyzer(Visitor):
    PREFIXES_TO_IGNORE = {'distributed_ctx_init'}

    def __init__(self):
        # blob usage counts, kept per node name
        self.workspaces = defaultdict(lambda: defaultdict(lambda: 0))
        self.workspace_ctx = []

    @property
    def workspace(self):
        return self.workspace_ctx[-1]

    @contextmanager
    def set_workspace(self, node=None, ws=None, do_copy=False):
        if ws is not None:
            pass  # use the workspace passed in explicitly
        elif node is not None:
            ws = self.workspaces[str(node)]
        else:
            ws = self.workspace
        if do_copy:
            ws = copy(ws)
        self.workspace_ctx.append(ws)
        yield ws
        del self.workspace_ctx[-1]

    def define_blob(self, blob):
        self.workspace[blob] += 1

    def need_blob(self, blob):
        if any(blob.startswith(p) for p in Analyzer.PREFIXES_TO_IGNORE):
            return
        assert blob in self.workspace, 'Blob undefined: %s' % blob


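# Illustrative usage (hypothetical node name): blobs must be defined before
# they are consumed within the active workspace:
#
#     a = Analyzer()
#     with a.set_workspace(node='trainer'):
#         a.define_blob('x')
#         a.need_blob('x')  # ok
#         a.need_blob('y')  # AssertionError: Blob undefined: y

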
@Analyzer.register(OperatorDef)
def analyze_op(analyzer, op):
    # Use explicit loops: map() is lazy in Python 3 and would never run.
    for blob in op.input:
        analyzer.need_blob(blob)
    for blob in op.output:
        analyzer.define_blob(blob)


@Analyzer.register(Net)
def analyze_net(analyzer, net):
    for op in net.Proto().op:
        analyzer(op)


@Analyzer.register(ExecutionStep)
def analyze_step(analyzer, step):
    proto = step.Proto()
    if proto.report_net:
        with analyzer.set_workspace(do_copy=True):
            analyzer(step.get_net(proto.report_net))
    all_new_blobs = set()
    substeps = step.Substeps() + [step.get_net(n) for n in proto.network]
    for substep in substeps:
        with analyzer.set_workspace(do_copy=proto.concurrent_substeps) as ws_in:
            analyzer(substep)
            if proto.should_stop_blob:
                analyzer.need_blob(proto.should_stop_blob)
        if proto.concurrent_substeps:
            new_blobs = set(ws_in.keys()) - set(analyzer.workspace.keys())
            assert len(all_new_blobs & new_blobs) == 0, (
                'Error: Blobs created by multiple parallel steps: %s' %
                ', '.join(all_new_blobs & new_blobs))
            all_new_blobs |= new_blobs
    for blob in all_new_blobs:
        analyzer.define_blob(blob)


@Analyzer.register(Task)
def analyze_task(analyzer, task):
    # Check that the serialized plan stays under the protobuf size limit.
    step = task.get_step()
    plan = Plan(task.node)
    plan.AddStep(step)
    proto_len = len(plan.Proto().SerializeToString())
    assert proto_len < 2 ** 26, (
        'Due to a protobuf limitation, serialized tasks must be smaller '
        'than 64MB, but this task has {} bytes.'.format(proto_len))

    is_private = task.workspace_type() != WorkspaceType.GLOBAL
    with analyzer.set_workspace(do_copy=is_private):
        analyzer(step)


@Analyzer.register(TaskGroup)
def analyze_task_group(analyzer, tg):
    for task in tg.tasks_by_node().tasks():
        with analyzer.set_workspace(node=task.node):
            analyzer(task)


@Analyzer.register(Job)
def analyze_job(analyzer, job):
    analyzer(job.init_group)
    analyzer(job.epoch_group)


def analyze(obj):
    """
    Given a Job, visits all the execution steps making sure that:
      - no undefined blobs will be found during execution
      - no blob with the same name is defined in concurrent steps
    """
    Analyzer()(obj)


class Text(object):
    def __init__(self):
        self._indent = 0
        self._lines_in_context = [0]
        self.lines = []

    @contextmanager
    def context(self, text):
        if text is not None:
            self.add('with %s:' % text)
            self._indent += 4
            self._lines_in_context.append(0)
        yield
        if text is not None:
            # An empty `with` block still needs a body to read as valid Python.
            if self._lines_in_context[-1] == 0:
                self.add('pass')
            self._indent -= 4
            del self._lines_in_context[-1]

    def add(self, text):
        self._lines_in_context[-1] += 1
        self.lines.append((' ' * self._indent) + text)

    def __str__(self):
        return '\n'.join(self.lines)


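# Illustrative sketch: Text accumulates indented lines, and context() nests
# them under a Python-like `with` header:
#
#     t = Text()
#     with t.context('loop(10)'):
#         t.add('Copy(x, y)')
#     str(t)  # -> 'with loop(10):\n    Copy(x, y)'

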
class Printer(Visitor, Text):
    def __init__(self, factor_prefixes=False):
        super(Printer, self).__init__()  # resolves to Text.__init__ via MRO
        self.factor_prefixes = factor_prefixes


def _sanitize_str(s):
    s = str(s)
    return s if len(s) < 64 else (s[:64] + '...<+len=%d>' % (len(s) - 64))


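# For example: _sanitize_str('x' * 100) keeps the first 64 characters and
# appends '...<+len=36>' to note how much was truncated.

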
def _arg_val(arg):
    if arg.HasField('f'):
        return str(arg.f)
    if arg.HasField('i'):
        return str(arg.i)
    if arg.HasField('s'):
        return _sanitize_str(arg.s)
    if arg.floats:
        return str(list(arg.floats))
    if arg.ints:
        return str(list(arg.ints))
    if arg.strings:
        return str([_sanitize_str(s) for s in arg.strings])
    return '[]'


200 "Given a list of strings, returns the longest common prefix" 205 for i, c
in enumerate(s1):
def factor_prefix(vals, do_it):
    # Materialize the list first: map() has no len() in Python 3.
    vals = [str(v) for v in vals]
    prefix = commonprefix(vals) if len(vals) > 1 and do_it else ''
    joined = ', '.join(v[len(prefix):] for v in vals)
    return '%s[%s]' % (prefix, joined) if prefix else joined


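# Illustrative: with factoring enabled, a shared prefix is pulled out:
#
#     factor_prefix(['worker/x', 'worker/y'], True)   # -> 'worker/[x, y]'
#     factor_prefix(['worker/x', 'worker/y'], False)  # -> 'worker/x, worker/y'

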
def call(op, inputs=None, outputs=None, factor_prefixes=False):
    if not inputs:
        inputs = ''
    else:
        inputs_v = [a for a in inputs if not isinstance(a, tuple)]
        inputs_kv = [a for a in inputs if isinstance(a, tuple)]
        inputs = ', '.join(filter(
            bool,
            [factor_prefix(inputs_v, factor_prefixes)] +
            ['%s=%s' % kv for kv in inputs_kv]))
    call = '%s(%s)' % (op, inputs)
    return call if not outputs else '%s = %s' % (
        factor_prefix(outputs, factor_prefixes), call)


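# Illustrative: call() renders an operator invocation as pseudo-Python,
# with tuples treated as keyword arguments:
#
#     call('Copy', ['x'], ['y'])        # -> 'y = Copy(x)'
#     call('loop', [('num_iter', 10)])  # -> 'loop(num_iter=10)'

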
@Printer.register(OperatorDef)
def print_op(text, op):
    text.add(call(
        op.type,
        list(op.input) + [(a.name, _arg_val(a)) for a in op.arg],
        op.output,
        factor_prefixes=text.factor_prefixes))


@Printer.register(Net)
def print_net(text, net):
    text.add('# net: %s' % str(net))
    for op in net.Proto().op:
        text(op)


def _get_step_context(step):
    proto = step.Proto()
    if proto.should_stop_blob:
        return call('loop'), False
    if proto.num_iter and proto.num_iter != 1:
        return call('loop', [proto.num_iter]), False
    concurrent = proto.concurrent_substeps and len(step.Substeps()) > 1
    if concurrent:
        return call('parallel'), True
    if proto.report_net:
        return call('run_once'), False
    return None, False


@Printer.register(ExecutionStep)
def print_step(text, step):
    proto = step.Proto()
    step_ctx, do_substep = _get_step_context(step)
    with text.context(step_ctx):
        if proto.report_net:
            with text.context(call('report_net', [proto.report_interval])):
                text(step.get_net(proto.report_net))
        substeps = step.Substeps() + [step.get_net(n) for n in proto.network]
        for substep in substeps:
            if (isinstance(substep, ExecutionStep) and
                    substep.Proto().run_every_ms):
                substep_ctx = call(
                    'reporter',
                    [str(substep),
                     ('interval_ms', substep.Proto().run_every_ms)])
            elif do_substep:
                substep_ctx = call('step', [str(substep)])
            else:
                substep_ctx = None
            with text.context(substep_ctx):
                text(substep)
                if proto.should_stop_blob:
                    text.add(call('yield stop_if', [proto.should_stop_blob]))


def _print_task_output(x):
    assert isinstance(x, TaskOutput)
    return 'Output[' + ', '.join(map(str, x.names)) + ']'


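# Illustrative: a TaskOutput whose names are ['loss', 'step'] renders as
# 'Output[loss, step]'.

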
@Printer.register(Task)
def print_task(text, task):
    outs = ', '.join(map(_print_task_output, task.outputs()))
    context = [('node', task.node), ('name', task.name), ('outputs', outs)]
    with text.context(call('Task', context)):
        text(task.get_step())


@Printer.register(TaskGroup)
def print_task_group(text, tg, header=None):
    with text.context(header or call('TaskGroup')):
        for task in tg.tasks_by_node().tasks():
            text(task)


@Printer.register(Job)
def print_job(text, job):
    text(job.init_group, 'Job.current().init_group')
    text(job.epoch_group, 'Job.current().epoch_group')
    with text.context('Job.current().stop_signals'):
        for out in job.stop_signals:
            text.add(_print_task_output(out))
    text(job.exit_group, 'Job.current().exit_group')


def to_string(obj, **kwargs):
    """
    Given a Net, ExecutionStep, Task, TaskGroup or Job, produces a string
    with a detailed description of the execution steps.
    """
    printer = Printer(**kwargs)
    printer(obj)
    return str(printer)


def debug_net(net):
    """
    Given a Net, produce another net that logs info about the operator call
    before each operator execution. Use for debugging purposes.
    """
    assert isinstance(net, Net)
    debug_net = Net(str(net))
    for op in net.Proto().op:
        # Printer doubles as the Text buffer here so that print_op can read
        # its factor_prefixes attribute.
        text = Printer()
        print_op(text, op)
        debug_net.LogInfo(str(text))
        debug_net.Proto().op.extend([op])
    return debug_net
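

# Illustrative end-to-end usage (assumes existing Job/Net instances named
# `job` and `net`; names are hypothetical):
#
#     analyze(job)                   # assert every blob is defined before use
#     print(to_string(job))          # pretty-print the execution plan
#     instrumented = debug_net(net)  # copy of `net` that logs each op first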