Caffe2 - Python API
A deep learning, cross-platform ML framework
generator.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 import os
8 from caffe2.python import core, workspace
9 from caffe2.python.docs.formatter import Markdown
10 
# Short module-level alias for the OpSchema binding exposed through
# `workspace.C` (presumably the native extension module — confirm).
# Used below for OpSchema.get(name), get_cpu_impl(name) and get_cuda_impl(name).
OpSchema = workspace.C.OpSchema
12 
13 
class DocUploader(object):
    """Uploader whose operations are all no-ops.

    A DocGenerator hands its accumulated ``content_body`` to ``upload``;
    this implementation simply discards it.
    """

    def __init__(self):
        """Nothing to set up."""

    def upload(self, text):
        """Accept the generated documentation ``text`` and ignore it."""
        return None
20 
21 
class DocGenerator(object):
    """Base documentation generator.

    Holds a formatter and an uploader, accumulates generated text in
    ``content_body`` and hands it to the uploader in :meth:`update`.
    """

    def __init__(self, formatter, uploader):
        self.formatter = formatter
        self.uploader = uploader
        self.content_body = ""

    def create_body(self):
        """Hook left empty in the base class."""
        pass

    def update(self):
        """Pass the accumulated ``content_body`` to the uploader."""
        self.uploader.upload(self.content_body)


class OpDocGenerator(DocGenerator):
    """Generates documentation for every registered operator."""

    def getOperatorDoc(self, name, schema, priority):
        """Factory for the per-operator doc object (overridable)."""
        return OperatorDoc(name, schema, priority)

    def getOperatorEngine(self, name):
        """Factory for the per-engine doc object (overridable)."""
        return OperatorEngine(name)

    def getOperators(self):
        """Collect, classify and sort all registered operators.

        Populates ``self.operators`` (op name -> operator doc) and
        ``self.engines`` (base op name -> [engine, ...]). Each engine list is
        attached to its base operator when that operator is known.

        Returns:
            list of operator docs sorted by (priority, name).
        """
        # map: op_name -> operator
        self.operators = {}
        # map: op_name -> [engine, engine]
        self.engines = {}

        def filePriority(x):
            # Lower values sort first: the core operators directory, then any
            # other schema-bearing op, then contrib, then experiments.
            if x == "caffe2/caffe2/operators":
                return 0
            parts = x.split('/')
            if 'contrib' in parts:
                return 2
            if 'experiments' in parts:
                return 3
            return 1

        for name in core._GetRegisteredOperators():
            schema = OpSchema.get(name)
            if schema:
                priority = filePriority(os.path.dirname(schema.file))
                self.operators[name] = self.getOperatorDoc(
                    name, schema, priority)
            elif "_ENGINE_" in name:
                # No schema, but the name marks an engine variant of a base
                # operator: group it under the base name, attached below.
                engine = self.getOperatorEngine(name)
                self.engines.setdefault(engine.base_op_name, []).append(engine)
            else:
                # No schema and not an engine variant: sort these last.
                self.operators[name] = self.getOperatorDoc(name, schema, 4)

        for name, engines in self.engines.items():
            if name in self.operators:
                self.operators[name].addEngines(engines)

        # sorted() lost its ``cmp`` argument in Python 3; a key function
        # yields the same order because op names are unique dict keys.
        return sorted(self.operators.values(),
                      key=lambda op: (op.priority, op.name))

    def createBody(self):
        """Render every operator through the formatter and append the
        formatter's dump to ``content_body``."""
        for operator in self.getOperators():
            operator.generateSchema(self.formatter)

        self.content_body += self.formatter.dump()
101 
102 
class OperatorEngine(object):
    """Documentation for one engine-specific registration of an operator.

    The registered name is split on its first ``_ENGINE_`` occurrence into
    the base operator name and the engine name.
    """

    def __init__(self, name):
        self.op_name = name
        # Raises ValueError if ``name`` does not contain "_ENGINE_".
        self.base_op_name, self.engine = name.split("_ENGINE_", 1)

    def getDeviceImpl(self):
        """Return ``[(device, impl), ...]`` for devices with an implementation.

        Devices are checked in a fixed CPU-then-CUDA order so the generated
        docs are deterministic (iterating a dict literal is unordered before
        Python 3.7).
        """
        candidates = (
            ('CPU', OpSchema.get_cpu_impl(self.op_name)),
            ('CUDA', OpSchema.get_cuda_impl(self.op_name)),
        )
        return [(device, impl) for device, impl in candidates if impl]

    def generateDoc(self, formatter):
        """Add one "<engine> on <device>: <impl>" line per implementation."""
        for device, impl in self.getDeviceImpl():
            formatter.addLine(
                '{engine} on {device}: {impl}'.format(engine=self.engine,
                                                      device=device,
                                                      impl=impl))
122  impl=impl))
123 
124 
class OperatorDoc(object):
    """Documentation model for a single operator.

    Holds the operator's name, its schema (falsy when none is registered),
    a sort priority and the engines attached to it. It renders itself
    through a formatter object.
    """

    def __init__(self, name, schema, priority):
        self.name = name
        self.schema = schema
        self.priority = priority
        self.engines = []

    def addEngines(self, engines):
        """Replace the engine list with ``engines``."""
        self.engines = engines

    def generateDoc(self, formatter):
        """Emit the schema's doc text, or a placeholder when it is empty."""
        doc_text = self.schema.doc
        if not doc_text:
            formatter.addLine("No documentation yet.")
            return
        formatter.parseAndAdd(doc_text)

    def generateTable(self, formatter, tuples, title_row, title):
        """Render ``(name, doc)`` pairs as a table; do nothing when empty."""
        if not tuples:
            return
        if title:
            formatter.addHeader(title, 3)
        rows = [title_row] if title_row else []
        rows.extend([label, text or ''] for label, text in tuples)
        # NOTE(review): this flag is computed after the rows are added, so it
        # is always False here (tuples is non-empty). If it is meant to mean
        # "no title row", `not title_row` may be what was intended; confirm
        # against the formatter's addTable.
        formatter.addTable(rows, rows == [])

    def generateInterface(self, formatter):
        """Render arguments, inputs and outputs as one 'Interface' table."""
        def describe(heading, entries):
            # One emphasised heading row, then one inline-code row per entry.
            head = formatter.clone()
            head.addEmphasis(heading, 1)
            section = [(head.dump(), '')]
            for entry_name, entry_doc in entries:
                cell = formatter.clone()
                cell.addCode(entry_name, inline=True)
                section.append((cell.dump(), entry_doc or ''))
            return section

        rows = []

        arguments = self.schema.arg_desc
        if arguments:
            rows.extend(describe('Arguments', arguments))

        inputs = self.schema.input_desc
        if inputs:
            rows.extend(describe('Inputs', inputs))

        outputs = self.schema.output_desc
        if outputs:
            rows.extend(describe('Outputs', outputs))

        self.generateTable(formatter, rows, None, 'Interface')

    def generateCodeLink(self, formatter):
        """Emit a 'Code' header followed by a link to the schema's file."""
        formatter.addHeader("Code", 3)
        formatter.addCodeLink(self.schema.file)

    def getInfo(self, formatter, name, impl):
        """Per-device entry for generateDevices.

        This version returns None, so generateDevices lists no devices.
        """
        return None

    def generateDevices(self, formatter):
        """Emit a 'Devices' header and the truthy per-device entries."""
        formatter.addHeader("Devices", 3)
        cpu_info = self.getInfo(formatter,
                                'CPU', OpSchema.get_cpu_impl(self.name))
        gpu_info = self.getInfo(formatter,
                                'GPU', OpSchema.get_cuda_impl(self.name))
        formatter.addList([info for info in (cpu_info, gpu_info) if info])

    def generateEngines(self, formatter):
        """Emit an 'Engines' section when any engines are attached."""
        if len(self.engines) > 0:
            formatter.addHeader("Engines", 3)
            for engine in self.engines:
                engine.generateDoc(formatter)

    def generateSchema(self, formatter):
        """Emit the full documentation section for this operator."""
        formatter.addHeader(self.name, 2)
        if not self.schema:
            formatter.addLine("No schema documented yet.")
            self.generateDevices(formatter)
            return
        sections = (
            self.generateDoc,
            self.generateInterface,
            self.generateCodeLink,
            self.generateDevices,
            self.generateEngines,
        )
        for emit in sections:
            emit(formatter)
        formatter.addBreak()
211  self.generateDevices(formatter)
212 
213 
if __name__ == "__main__":
    # Build the operator documentation as Markdown and print it to stdout.
    generator = OpDocGenerator(Markdown(), DocUploader())
    generator.createBody()
    print(generator.content_body)
217  print(ops.content_body)
def generateTable(self, formatter, tuples, title_row, title)
Definition: generator.py:141
def getOperatorEngine(self, name)
Definition: generator.py:39
def getOperatorDoc(self, name, schema, priority)
Definition: generator.py:36
def _GetRegisteredOperators()
Definition: core.py:51