Caffe2 - Python API
A deep learning, cross platform ML framework
net_drawer.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 import argparse
8 import json
9 import logging
10 from collections import defaultdict
11 from caffe2.python import utils
12 
13 logger = logging.getLogger(__name__)
14 logger.setLevel(logging.INFO)
15 
16 try:
17  import pydot
18 except ImportError:
19  logger.info(
20  'Cannot import pydot, which is required for drawing a network. This '
21  'can usually be installed in python with "pip install pydot". Also, '
22  'pydot requires graphviz to convert dot files to pdf: in ubuntu, this '
23  'can usually be installed with "sudo apt-get install graphviz".'
24  )
25  print(
26  'net_drawer will not run correctly. Please install the correct '
27  'dependencies.'
28  )
29  pydot = None
30 
31 from caffe2.proto import caffe2_pb2
32 
# Node attributes passed as keyword arguments to pydot.Node for operator
# nodes (and for execution-step nodes in _draw_steps).
33 OP_STYLE = {
34  'shape': 'box',
35  'color': '#0F9D58',
36  'style': 'filled',
37  'fontcolor': '#FFFFFF'
38 }
# Node attributes passed as keyword arguments to pydot.Node for blob nodes.
39 BLOB_STYLE = {'shape': 'octagon'}
40 
41 
def _rectify_operator_and_name(operators_or_net, name):
    """Normalize the input into an (operators, name) pair for graph drawing.

    Accepts a NetDef proto, any object exposing a ``Proto()`` method that
    returns a NetDef, or a plain sequence of operators.  When ``name`` is
    None it falls back to the net's own name, or to "unnamed" for a plain
    operator sequence.

    Raises:
        RuntimeError: if ``Proto()`` returns something other than a NetDef.
    """
    # First work out which NetDef (if any) backs the argument.
    if isinstance(operators_or_net, caffe2_pb2.NetDef):
        net_def = operators_or_net
    elif hasattr(operators_or_net, 'Proto'):
        net_def = operators_or_net.Proto()
        if not isinstance(net_def, caffe2_pb2.NetDef):
            raise RuntimeError(
                "Expecting NetDef, but got {}".format(type(net_def)))
    else:
        net_def = None

    # Without a NetDef the argument itself is taken as the operator list.
    operators = operators_or_net if net_def is None else net_def.op
    if name is None:
        name = "unnamed" if net_def is None else net_def.name
    return operators, name
61 
62 
63 def _escape_label(name):
64  # json.dumps is poor man's escaping
65  return json.dumps(name)
66 
67 
def GetOpNodeProducer(append_output, **kwargs):
    """Build a callable that turns an operator into a pydot node.

    Args:
        append_output: if truthy, each output blob name of the operator is
            appended to the node name on its own line.
        **kwargs: forwarded unchanged to ``pydot.Node``.

    Returns:
        A function ``(op, op_id) -> pydot.Node``.
    """
    def ReallyGetOpNode(op, op_id):
        # Base title is "<type> (op#<id>)", prefixed by "<name>/" when the
        # operator has a name.
        title = '%s (op#%d)' % (op.type, op_id)
        if op.name:
            title = '%s/%s' % (op.name, title)
        label_lines = [title]
        if append_output:
            label_lines.extend(op.output)
        return pydot.Node('\n'.join(label_lines), **kwargs)
    return ReallyGetOpNode
79 
80 
def GetPydotGraph(
    operators_or_net,
    name=None,
    rankdir='LR',
    node_producer=None
):
    """Draw operators and blobs as a pydot graph.

    Every operator becomes a node produced by ``node_producer`` and every
    blob becomes a BLOB_STYLE node.  A blob that is written again after it
    already exists gets a fresh node whose name carries an incremented
    version counter, so each write is drawn separately.

    Args:
        operators_or_net: a NetDef, an object with ``Proto()``, or a sequence
            of operators.
        name: graph name; derived from the net when None.
        rankdir: value passed to ``pydot.Dot`` as ``rankdir``.
        node_producer: callable ``(op, op_id) -> pydot.Node``; defaults to
            ``GetOpNodeProducer(False, **OP_STYLE)``.

    Returns:
        The populated ``pydot.Dot`` graph.
    """
    if node_producer is None:
        node_producer = GetOpNodeProducer(False, **OP_STYLE)
    operators, name = _rectify_operator_and_name(operators_or_net, name)
    graph = pydot.Dot(name, rankdir=rankdir)
    # Most recent node drawn for each blob name.
    latest_blob_node = {}
    # Number of times each blob name has been overwritten so far.
    blob_versions = defaultdict(int)

    def make_blob_node(blob_name):
        # The node name embeds the version; the visible label does not.
        versioned_name = blob_name + str(blob_versions[blob_name])
        return pydot.Node(
            _escape_label(versioned_name),
            label=_escape_label(blob_name),
            **BLOB_STYLE
        )

    for op_id, op in enumerate(operators):
        op_node = node_producer(op, op_id)
        graph.add_node(op_node)

        for input_name in op.input:
            if input_name in latest_blob_node:
                input_node = latest_blob_node[input_name]
            else:
                input_node = make_blob_node(input_name)
                latest_blob_node[input_name] = input_node
            graph.add_node(input_node)
            graph.add_edge(pydot.Edge(input_node, op_node))

        for output_name in op.output:
            if output_name in latest_blob_node:
                # Overwriting an existing blob: bump its version first.
                blob_versions[output_name] += 1
            output_node = make_blob_node(output_name)
            latest_blob_node[output_name] = output_node
            graph.add_node(output_node)
            graph.add_edge(pydot.Edge(op_node, output_node))
    return graph
126 
127 
def GetPydotGraphMinimal(
    operators_or_net,
    name=None,
    rankdir='LR',
    minimal_dependency=False,
    node_producer=None,
):
    """Different from GetPydotGraph, hide all blob nodes and only show op nodes.

    If minimal_dependency is set as well, for each op, we will only draw the
    edges to the minimal necessary ancestors. For example, if op c depends on
    op a and b, and op b depends on a, then only the edge b->c will be drawn
    because a->c will be implied.

    Args:
        operators_or_net: a NetDef, an object with ``Proto()``, or a sequence
            of operators.
        name: graph name; derived from the net when None.
        rankdir: value passed to ``pydot.Dot`` as ``rankdir``.
        minimal_dependency: when truthy, skip edges from parents that are
            already ancestors of another parent of the same op.
        node_producer: callable ``(op, op_id) -> pydot.Node``; defaults to
            ``GetOpNodeProducer(False, **OP_STYLE)``.

    Returns:
        The populated ``pydot.Dot`` graph.
    """
    if node_producer is None:
        node_producer = GetOpNodeProducer(False, **OP_STYLE)
    operators, name = _rectify_operator_and_name(operators_or_net, name)
    graph = pydot.Dot(name, rankdir=rankdir)
    # blob_parents maps each blob name to its generating op.
    blob_parents = {}
    # op_ancestry records the ancestors of each op.
    op_ancestry = defaultdict(set)
    for op_id, op in enumerate(operators):
        op_node = node_producer(op, op_id)
        graph.add_node(op_node)
        # Get parents, and set up op ancestry.
        parents = [
            blob_parents[input_name] for input_name in op.input
            if input_name in blob_parents
        ]
        op_ancestry[op_node].update(parents)
        for node in parents:
            op_ancestry[op_node].update(op_ancestry[node])
        if minimal_dependency:
            # Only add parents that are not already reachable through
            # another parent (i.e. have no transitive ancestry).
            for node in parents:
                if all(
                    node not in op_ancestry[other_node]
                    for other_node in parents
                ):
                    graph.add_edge(pydot.Edge(node, op_node))
        else:
            # Add all parents to the graph.
            for node in parents:
                graph.add_edge(pydot.Edge(node, op_node))
        # Update blob_parents to reflect that this op created the blobs.
        for output_name in op.output:
            blob_parents[output_name] = op_node
    return graph
177 
178 
def GetOperatorMapForPlan(plan_def):
    """Map a graph name to the operator list of each network in a plan.

    Keys are "<plan name>_<net name>" for nets that have their ``name`` field
    set, and "<plan name>_network_<index>" otherwise.
    """
    def net_key(net_id, net):
        suffix = net.name if net.HasField('name') else "network_%d" % net_id
        return plan_def.name + "_" + suffix

    return {
        net_key(net_id, net): net.op
        for net_id, net in enumerate(plan_def.network)
    }
187 
188 
def _draw_nets(nets, g):
    """Add one node per net to ``g``, chained by edges in list order.

    Returns the list of created nodes, in the same order as ``nets``.
    """
    nodes = []
    previous = None
    for net in nets:
        current = pydot.Node(_escape_label(net))
        nodes.append(current)
        g.add_node(current)
        if previous is not None:
            g.add_edge(pydot.Edge(previous, current))
        previous = current
    return nodes
197 
198 
def _draw_steps(steps, g, skip_step_edges=False):  # noqa
    """Recursively add execution steps (and their nets/substeps) to ``g``.

    Consecutive steps are chained by plain edges unless ``skip_step_edges``
    is set.  Each step is linked to its children with dashed edges: to every
    drawn child when the step runs its substeps concurrently, otherwise only
    to the first child.  At most ``kMaxParallelSteps`` concurrent substeps
    are drawn; the remainder is summarized by a single "N more steps" node.

    Returns:
        The list of nodes created for ``steps`` (children not included).

    Raises:
        ValueError: if a step has neither networks nor substeps.
    """
    kMaxParallelSteps = 3

    def get_label(step):
        parts = [step.name + '\n']
        if step.report_net:
            parts.append('Reporter: {}'.format(step.report_net))
        if step.should_stop_blob:
            parts.append('Stopper: {}'.format(step.should_stop_blob))
        if step.concurrent_substeps:
            parts.append('Concurrent')
        if step.only_once:
            parts.append('Once')
        return '\n'.join(parts)

    def substep_edge(start, end):
        return pydot.Edge(start, end, arrowhead='dot', style='dashed')

    nodes = []
    for step in steps:
        parallel = step.concurrent_substeps

        step_node = pydot.Node(_escape_label(get_label(step)), **OP_STYLE)
        g.add_node(step_node)

        # Chain to the previously drawn sibling step, if there is one.
        if nodes and not skip_step_edges:
            g.add_edge(pydot.Edge(nodes[-1], step_node))
        nodes.append(step_node)

        if step.network:
            sub_nodes = _draw_nets(step.network, g)
        elif step.substep:
            if parallel:
                sub_nodes = _draw_steps(
                    step.substep[:kMaxParallelSteps], g, skip_step_edges=True)
            else:
                sub_nodes = _draw_steps(step.substep, g)
        else:
            raise ValueError('invalid step')

        if not parallel:
            # Sequential children are already chained; link to the first.
            g.add_edge(substep_edge(step_node, sub_nodes[0]))
            continue

        for sn in sub_nodes:
            g.add_edge(substep_edge(step_node, sn))
        hidden_count = len(step.substep) - kMaxParallelSteps
        if hidden_count > 0:
            more_node = pydot.Node(
                '{} more steps'.format(hidden_count), **OP_STYLE)
            g.add_node(more_node)
            g.add_edge(substep_edge(step_node, more_node))

    return nodes
250 
251 
def GetPlanGraph(plan_def, name=None, rankdir='TB'):
    """Return a pydot graph of the execution steps of ``plan_def``."""
    plan_graph = pydot.Dot(name, rankdir=rankdir)
    _draw_steps(plan_def.execution_step, plan_graph)
    return plan_graph
256 
257 
def GetGraphInJson(operators_or_net, output_filepath):
    """Write the op/blob graph to ``output_filepath`` as JSON.

    The file holds ``{"nodes": [...], "edges": [...]}``.  Op nodes carry
    ``id``, ``label``, ``op_id`` and ``type == "op"``; blob nodes carry
    ``id``, ``label`` and ``type == "blob"``.  Edges reference node ids via
    ``source`` and ``target``.  A blob written again after it already exists
    is given a new node, tracked through a per-name version counter.
    """
    operators, _ = _rectify_operator_and_name(operators_or_net, None)
    blob_strid_to_node_id = {}
    node_name_counts = defaultdict(int)
    nodes = []
    edges = []

    def current_strid(blob_name):
        # Key identifying the current version of a blob.
        return _escape_label(blob_name + str(node_name_counts[blob_name]))

    def blob_node_id(strid, blob_name):
        # Create the blob node on first sight; always return its node id.
        if strid not in blob_strid_to_node_id:
            new_id = len(nodes)
            blob_strid_to_node_id[strid] = new_id
            nodes.append({
                'id': new_id,
                'label': blob_name,
                'type': 'blob'
            })
        return blob_strid_to_node_id[strid]

    for op_id, op in enumerate(operators):
        if op.name:
            op_label = op.name + '/' + op.type
        else:
            op_label = op.type
        op_node_id = len(nodes)
        nodes.append({
            'id': op_node_id,
            'label': op_label,
            'op_id': op_id,
            'type': 'op'
        })

        for input_name in op.input:
            strid = current_strid(input_name)
            edges.append({
                'source': blob_node_id(strid, input_name),
                'target': op_node_id
            })

        for output_name in op.output:
            strid = current_strid(output_name)
            if strid in blob_strid_to_node_id:
                # Overwriting an existing blob: move on to the next version.
                node_name_counts[output_name] += 1
                strid = current_strid(output_name)
            edges.append({
                'source': op_node_id,
                'target': blob_node_id(strid, output_name)
            })

    with open(output_filepath, 'w') as f:
        json.dump({'nodes': nodes, 'edges': edges}, f)
314 
315 
316 # A dummy minimal PNG image returned by GetGraphPngSafe as a
317 # placeholder when drawing or rendering the graph fails.
318 _DummyPngImage = (
319  b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00'
320  b'\x01\x01\x00\x00\x00\x007n\xf9$\x00\x00\x00\nIDATx\x9cc`\x00\x00'
321  b'\x00\x02\x00\x01H\xaf\xa4q\x00\x00\x00\x00IEND\xaeB`\x82')
322 
323 
def GetGraphPngSafe(func, *args, **kwargs):
    """
    Call `func` (e.g. GetPydotGraph) with the given arguments and return the
    resulting graph rendered as PNG bytes. If anything fails, the error is
    logged and a placeholder image is returned instead of raising.
    """
    try:
        graph = func(*args, **kwargs)
        if isinstance(graph, pydot.Dot):
            return graph.create_png()
        raise ValueError("func is expected to return pydot.Dot")
    except Exception as e:
        logger.error("Failed to draw graph: {}".format(e))
        return _DummyPngImage
337 
338 
def main():
    """Command-line entry point: draw the net(s) stored in a protobuf file.

    Reads the file given by --input, interprets it as a PlanDef or NetDef
    via utils.GetContentFromProtoString, and for every resulting operator
    list writes "<output_prefix><graph name>.dot" plus, when graphviz is
    usable, the matching ".pdf".
    """
    parser = argparse.ArgumentParser(description="Caffe2 net drawer.")
    parser.add_argument(
        "--input",
        type=str,
        # Without an input file there is nothing to draw; fail with a usage
        # message instead of a TypeError from open(None).
        required=True,
        help="The input protobuf file."
    )
    parser.add_argument(
        "--output_prefix",
        type=str, default="",
        help="The prefix to be added to the output filename."
    )
    parser.add_argument(
        "--minimal", action="store_true",
        help="If set, produce a minimal visualization."
    )
    parser.add_argument(
        "--minimal_dependency", action="store_true",
        help="If set, only draw minimal dependency."
    )
    parser.add_argument(
        "--append_output", action="store_true",
        help="If set, append the output blobs to the operator names.")
    parser.add_argument(
        "--rankdir", type=str, default="LR",
        help="The rank direction of the pydot graph."
    )
    args = parser.parse_args()
    with open(args.input, 'r') as fid:
        content = fid.read()
    # Map each supported proto type to a function producing
    # {graph name: operators}.
    graphs = utils.GetContentFromProtoString(
        content, {
            caffe2_pb2.PlanDef: GetOperatorMapForPlan,
            caffe2_pb2.NetDef: lambda x: {x.name: x.op},
        }
    )
    # The producer depends only on the command-line flags, so build it once.
    node_producer = GetOpNodeProducer(args.append_output, **OP_STYLE)
    for key, operators in graphs.items():
        if args.minimal:
            graph = GetPydotGraphMinimal(
                operators,
                name=key,
                rankdir=args.rankdir,
                node_producer=node_producer,
                minimal_dependency=args.minimal_dependency)
        else:
            graph = GetPydotGraph(
                operators,
                name=key,
                rankdir=args.rankdir,
                node_producer=node_producer)
        filename = args.output_prefix + graph.get_name() + '.dot'
        graph.write(filename, format='raw')
        # Same path with the "dot" extension swapped for "pdf".
        pdf_filename = filename[:-3] + 'pdf'
        try:
            graph.write_pdf(pdf_filename)
        except Exception:
            # Best effort: the .dot file has already been written.
            print(
                'Error when writing out the pdf file. Pydot requires graphviz '
                'to convert dot files to pdf, and you may not have installed '
                'graphviz. On ubuntu this can usually be installed with "sudo '
                'apt-get install graphviz". We have generated the .dot file '
                'but will not be able to generate pdf file for now.'
            )
402 
403 
404 if __name__ == '__main__':
405  main()
def GetPydotGraphMinimal(operators_or_net, name=None, rankdir='LR', minimal_dependency=False, node_producer=None)
Definition: net_drawer.py:134
def GetContentFromProtoString(s, function_map)
Definition: utils.py:136
def GetGraphPngSafe(func, args, kwargs)
Definition: net_drawer.py:324