Caffe2 - C++ API
A deep learning, cross-platform ML framework
operator_schema.cc
#include "caffe2/core/operator_schema.h"

#include "caffe2/core/logging.h"

namespace caffe2 {

bool OpSchema::Verify(const OperatorDef& def) const {
  // Check the number of inputs.
  if (def.input_size() < min_input_ || def.input_size() > max_input_) {
    LOG(ERROR) << "Input size " << def.input_size()
               << " not in range [min=" << min_input_ << ", max="
               << max_input_ << "].";
    return false;
  }
  if (!num_inputs_allowed_(def.input_size())) {
    LOG(ERROR) << "Input size " << def.input_size()
               << " not in allowed input sizes.";
    return false;
  }
  // Check the number of outputs.
  if (def.output_size() < min_output_ || def.output_size() > max_output_) {
    LOG(ERROR) << "Output size " << def.output_size()
               << " not in range [min=" << min_output_ << ", max="
               << max_output_ << "].";
    return false;
  }
  if (!num_outputs_allowed_(def.output_size())) {
    LOG(ERROR) << "Output size " << def.output_size()
               << " not in allowed output sizes.";
    return false;
  }
  if (!num_inputs_outputs_allowed_(def.input_size(), def.output_size())) {
    LOG(ERROR) << "Combination of input size " << def.input_size()
               << " and output size " << def.output_size() << " is not allowed.";
    return false;
  }
  // If the number of outputs can be calculated, check if the number matches.
  if (calculate_output_) {
    int expected_nout = calculate_output_(def.input_size());
    if (expected_nout != kCannotComputeNumOutputs &&
        def.output_size() != expected_nout) {
      LOG(ERROR) << "Output size " << def.output_size()
                 << " not matching expected output size, which is "
                 << expected_nout;
      return false;
    }
  }

  // Check in-place settings.
  for (int in_idx = 0; in_idx < def.input_size(); ++in_idx) {
    for (int out_idx = 0; out_idx < def.output_size(); ++out_idx) {
      // If an input is the same as an output but in-place is not opted in,
      // either as allowed or enforced, we will fail the verification.
      if (def.input(in_idx) == def.output(out_idx) &&
          (!inplace_allowed_(in_idx, out_idx) &&
           !inplace_enforced_(in_idx, out_idx))) {
        LOG(ERROR) << "Input index " << in_idx << " and output idx " << out_idx
                   << " (" << def.input(in_idx) << ")"
                   << " are set to be in-place but this is actually not "
                   << "supported by op " << def.type();
        return false;
      }
      if (def.input(in_idx) != def.output(out_idx) &&
          inplace_enforced_(in_idx, out_idx)) {
        LOG(ERROR) << "Input index " << in_idx << " (" << def.input(in_idx) << ")"
                   << " and output idx " << out_idx
                   << " (" << def.output(out_idx) << ")"
                   << " are not in-place but should be as required by op "
                   << def.type();
        return false;
      }
    }
  }

  // Phew. All verifications passed.
  return true;
}

OpSchema& OpSchema::NumInputs(int min, int max) {
  min_input_ = min;
  max_input_ = max;
  return *this;
}

OpSchema& OpSchema::NumInputs(int n) {
  return NumInputs(n, n);
}

OpSchema& OpSchema::NumInputs(std::function<bool(int)> func) {
  num_inputs_allowed_ = func;
  return *this;
}

OpSchema& OpSchema::NumInputs(set<int> allowed_input_nums) {
  return NumInputs(
      [allowed_input_nums](int n)->bool {
        return allowed_input_nums.count(n);
      });
}

OpSchema& OpSchema::NumOutputs(int min, int max) {
  min_output_ = min;
  max_output_ = max;
  return *this;
}

OpSchema& OpSchema::NumOutputs(int n) {
  return NumOutputs(n, n);
}

OpSchema& OpSchema::NumOutputs(std::function<bool(int)> func) {
  num_outputs_allowed_ = func;
  return *this;
}

OpSchema& OpSchema::NumOutputs(set<int> allowed_output_nums) {
  return NumOutputs(
      [allowed_output_nums](int n)->bool {
        return allowed_output_nums.count(n);
      });
}

OpSchema& OpSchema::NumInputsOutputs(std::function<bool(int, int)> func) {
  num_inputs_outputs_allowed_ = func;
  return *this;
}

OpSchema& OpSchema::OutputCalculator(std::function<int(int)> calc) {
  calculate_output_ = calc;
  return *this;
}

OpSchema& OpSchema::SameNumberOfOutput() {
  return OutputCalculator([](int n)->int { return n; });
}

OpSchema& OpSchema::AllowInplace(std::function<bool(int, int)> inplace) {
  inplace_allowed_ = inplace;
  return *this;
}

OpSchema& OpSchema::AllowInplace(set<std::pair<int, int>> inplace) {
  return AllowInplace(
      [inplace](int in, int out)->bool {
        return inplace.count(std::make_pair(in, out));
      });
}

OpSchema& OpSchema::AllowOneToOneInplace() {
  return AllowInplace([](int in, int out) { return in == out; });
}

OpSchema& OpSchema::EnforceInplace(std::function<bool(int, int)> inplace) {
  inplace_enforced_ = inplace;
  return *this;
}

OpSchema& OpSchema::EnforceInplace(set<std::pair<int, int>> inplace) {
  return EnforceInplace(
      [inplace](int in, int out)->bool {
        return inplace.count(std::make_pair(in, out));
      });
}

OpSchema& OpSchema::EnforceOneToOneInplace() {
  return EnforceInplace([](int in, int out) { return in == out; });
}

OpSchema& OpSchema::TensorInferenceFunction(
    TensorInferenceFunctionType function) {
  tensor_inference_function_ = function;
  return *this;
}

OpSchema& OpSchema::IdenticalTypeAndShape() {
  return TensorInferenceFunction(
      [](const OperatorDef&, const vector<TensorShape>& input_types) {
        return vector<TensorShape>(input_types);
      });
}

OpSchema& OpSchema::IdenticalTypeAndShapeOfInput(int idx) {
  return TensorInferenceFunction(
      [idx](const OperatorDef&, const vector<TensorShape>& input_types) {
        vector<TensorShape> out(1);
        out[0] = input_types[idx];
        return out;
      });
}

OpSchema& OpSchema::IdenticalTypeAndShapeOfInputDim(int idx, int dim) {
  return TensorInferenceFunction(
      [idx, dim](const OperatorDef&, const vector<TensorShape>& input_types) {
        vector<TensorShape> out(1);
        out[0].add_dims(input_types[idx].dims(dim));
        out[0].set_data_type(input_types[idx].data_type());
        return out;
      });
}

OpSchema& OpSchema::ScalarType(::caffe2::TensorProto_DataType dt) {
  return TensorInferenceFunction(
      [dt](const OperatorDef&, const vector<TensorShape>& input_types) {
        vector<TensorShape> out(1);
        out[0].set_data_type(dt);
        return out;
      });
}

OpSchema& OpSchema::SetDoc(const string& doc) {
  doc_ = doc;
  return *this;
}

OpSchema& OpSchema::Arg(const char* name, const char* description) {
  arg_desc_.emplace_back(name, description);
  return *this;
}

OpSchema& OpSchema::Input(const int n, const char* name, const char* description) {
  if (input_desc_.size() <= n) {
    input_desc_.resize(n + 1);
  }
  input_desc_[n] = std::make_pair(name, description);
  return *this;
}

OpSchema& OpSchema::Output(const int n, const char* name, const char* description) {
  if (output_desc_.size() <= n) {
    output_desc_.resize(n + 1);
  }
  output_desc_[n] = std::make_pair(name, description);
  return *this;
}

OpSchema& OpSchema::FillUsing(std::function<void(OpSchema&)> populator) {
  if (populator) {
    populator(*this);
  }
  return *this;
}

int OpSchema::CalculateOutput(int num_input) const {
  if (min_output_ == max_output_) {
    return min_output_;
  } else if (calculate_output_) {
    return calculate_output_(num_input);
  } else {
    return kCannotComputeNumOutputs;
  }
}

std::ostream& operator<<(std::ostream& out, const OpSchema& schema) {
  if (!schema.arg_desc_.empty()) {
    out << "Arguments:" << std::endl;
    for (const auto& it : schema.arg_desc_) {
      out << "  " << it.first << " : " << it.second << std::endl;
    }
  }
  if (schema.max_input_ > 0) {
    out << "Inputs:" << std::endl;
    if (!schema.input_desc_.empty()) {
      for (int i = 0; i < schema.input_desc_.size(); ++i) {
        const auto& p = schema.input_desc_[i];
        out << "  " << i << ", " << (p.first ? p.first : "(unnamed)") << " : "
            << (p.second ? p.second : "(no doc)") << std::endl;
      }
    } else {
      out << "  (no explicit description available)" << std::endl;
    }
  }
  if (schema.max_output_ > 0) {
    out << "Outputs:" << std::endl;
    if (!schema.output_desc_.empty()) {
      for (int i = 0; i < schema.output_desc_.size(); ++i) {
        const auto& p = schema.output_desc_[i];
        out << "  " << i << ", " << (p.first ? p.first : "(unnamed)") << " : "
            << (p.second ? p.second : "(no doc)") << std::endl;
      }
    } else {
      out << "  (no explicit description available)" << std::endl;
    }
  }
  out << std::endl;
  if (schema.doc()) {
    out << schema.doc();
  } else {
    out << "(no documentation yet)" << std::endl;
  }
  out << std::endl;
  if (schema.line_) {
    out << "Defined at " << schema.file_ << ":" << schema.line_ << std::endl;
  }
  return out;
}

CaffeMap<string, OpSchema>& OpSchemaRegistry::map() {
  static CaffeMap<string, OpSchema> map;
  return map;
}

}  // namespace caffe2
Referenced OpSchema members (declared in operator_schema.h):

OpSchema& NumInputs(int n)
A single input.

OpSchema& NumInputsOutputs(std::function<bool(int, int)> func)
The relationship between the number of inputs and outputs is checked with a user-specified function.

OpSchema& OutputCalculator(std::function<int(int)> calc)
Set the output calculator to a user-defined function.

bool Verify(const OperatorDef& def) const
Verifies whether an operator definition protobuf matches the pattern specified in the schema.
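As an illustration of how Verify() applies these settings, here is a minimal sketch. The operator type "MyOp", the blob names, and the helper function are made up for the example; only the OpSchema and OperatorDef APIs shown in this file are assumed.

#include "caffe2/core/operator_schema.h"
#include "caffe2/proto/caffe2.pb.h"

namespace caffe2 {

bool SketchVerify() {
  // A schema that accepts one or two inputs, exactly one output, and
  // requires the output count not to exceed the input count.
  OpSchema schema;
  schema.NumInputs(1, 2)
        .NumOutputs(1)
        .NumInputsOutputs([](int in, int out) { return out <= in; });

  // A hypothetical operator definition with one input and one output.
  OperatorDef def;
  def.set_type("MyOp");
  def.add_input("X");
  def.add_output("Y");

  // Verify() checks the input/output counts, the combination predicate,
  // the output calculator (if set), and the in-place constraints.
  return schema.Verify(def);
}

}  // namespace caffe2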
int CalculateOutput(int num_input) const
A function that allows one to get the number of outputs based on the number of inputs, if this schema supports it.

OpSchema& SameNumberOfOutput()
Set the number of outputs to be the same as the number of inputs.
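The sketch below shows how an output calculator interacts with CalculateOutput(). The schema names and the "two outputs per input" rule are invented for illustration; the behavior follows the CalculateOutput() implementation above (a fixed output count wins, otherwise the calculator is consulted).

#include "caffe2/core/operator_schema.h"

namespace caffe2 {

void SketchOutputCalculator() {
  // Hypothetical op producing two outputs for every input.
  OpSchema split_like;
  split_like.NumInputs(1, 4).NumOutputs(1, 8).OutputCalculator(
      [](int num_input) { return 2 * num_input; });
  // min_output_ != max_output_, so the calculator is used: n == 6.
  int n = split_like.CalculateOutput(3);

  // Hypothetical elementwise op: same number of outputs as inputs.
  OpSchema elementwise;
  elementwise.NumInputs(1, 4).NumOutputs(1, 4).SameNumberOfOutput();
  int m = elementwise.CalculateOutput(2);  // m == 2
  (void)n;
  (void)m;
}

}  // namespace caffe2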
OpSchema& TensorInferenceFunction(TensorInferenceFunctionType function)
Sets the tensor inference function, which is a std::function object defined in operator_schema.h.
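A custom inference function uses the same lambda signature as IdenticalTypeAndShape() and the other helpers in this file. The sketch below registers a rule for a hypothetical "Transpose2D"-style op whose output swaps the two dimensions of a 2-D input; the op and the helper function are assumptions for the example.

#include "caffe2/core/operator_schema.h"

namespace caffe2 {

void SketchTensorInference(OpSchema& schema) {
  schema.TensorInferenceFunction(
      [](const OperatorDef& /*def*/, const vector<TensorShape>& in) {
        vector<TensorShape> out(1);
        // Same data type as the input, dimensions swapped (2-D input assumed).
        out[0].set_data_type(in[0].data_type());
        out[0].add_dims(in[0].dims(1));
        out[0].add_dims(in[0].dims(0));
        return out;
      });
}

}  // namespace caffe2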
OpSchema& NumOutputs(int n)
A single output.

const char* doc() const
Returns the docstring of the op schema.

class OpSchema
A class to record the schema of an op.

OpSchema& IdenticalTypeAndShape()
Sets the tensor inference function to produce the same output as the input.
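Putting it together, an operator's .cc file typically registers its schema with the OPERATOR_SCHEMA macro from operator_schema.h and chains the setters shown above. The sketch below is modeled on a simple elementwise op such as Relu; the doc strings are abridged and illustrative rather than copied from the real operator.

#include "caffe2/core/operator_schema.h"

namespace caffe2 {

OPERATOR_SCHEMA(Relu)
    .NumInputs(1)
    .NumOutputs(1)
    .AllowInplace({{0, 0}})
    .IdenticalTypeAndShape()
    .SetDoc("Applies the rectified linear unit elementwise.")
    .Input(0, "X", "Input tensor.")
    .Output(0, "Y", "Output tensor of the same shape as X.");

// The registered schema can later be looked up and pretty-printed via the
// ostream operator defined in this file:
//   const OpSchema* schema = OpSchemaRegistry::Schema("Relu");
//   if (schema) { std::cout << *schema; }

}  // namespace caffe2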