// NOTE(review): this text is a lossy listing of the operator-schema source:
// original line numbers are fused into the code (e.g. "1", "3", "9", "10")
// and many lines are absent (the Verify() signature, every `return`, the
// closing braces), so it does not compile as shown.
//
// Verify(def) checks an OperatorDef against this schema, in this order:
//   1. input count lies in [min_input_, max_input_] and is accepted by the
//      num_inputs_allowed_ predicate;
//   2. output count lies in [min_output_, max_output_] and is accepted by
//      the num_outputs_allowed_ predicate;
//   3. the (input count, output count) pair is accepted by
//      num_inputs_outputs_allowed_;
//   4. when calculate_output_ is set and yields something other than
//      kCannotComputeNumOutputs, the output count equals that value;
//   5. in-place rules hold for every (input index, output index) pair.
// Every failed check logs a message via LOG(ERROR). Presumably each is
// followed by `return false` and the function ends with `return true`;
// those lines are not in this listing -- TODO confirm against the real file.
1 #include "caffe2/core/operator_schema.h" 3 #include "caffe2/core/logging.h" 9 if (def.input_size() < min_input_ || def.input_size() > max_input_) {
10 LOG(ERROR) <<
"Input size " << def.input_size()
11 <<
" not in range [min=" << min_input_ <<
", max=" 12 << max_input_ <<
"].";
// Beyond the range check, the input count must satisfy num_inputs_allowed_
// (called unconditionally here, so it is expected to always be set).
15 if (!num_inputs_allowed_(def.input_size())) {
16 LOG(ERROR) <<
"Input size " << def.input_size()
17 <<
" not in allowed input sizes.";
// Same two-stage check (range, then predicate) for the output count.
21 if (def.output_size() < min_output_ || def.output_size() > max_output_) {
22 LOG(ERROR) <<
"Output size " << def.output_size()
23 <<
" not in range [min=" << min_output_ <<
", max=" 24 << max_output_ <<
"].";
27 if (!num_outputs_allowed_(def.output_size())) {
28 LOG(ERROR) <<
"Output size " << def.output_size()
29 <<
" not in allowed output sizes.";
// Joint constraint on the (input count, output count) pair.
32 if (!num_inputs_outputs_allowed_(def.input_size(), def.output_size())) {
33 LOG(ERROR) <<
"Combination of input size " << def.input_size()
34 <<
// NOTE(review): the next literal has no leading space, so the message
// renders as e.g. "...input size 3and output size 2 not in allowed." --
// should probably be " and output size " and end with " is not allowed.".
"and output size " << def.output_size() <<
" not in allowed.";
// If an output calculator is registered and it can compute a count for this
// many inputs, the actual output count must match it exactly.
38 if (calculate_output_) {
39 int expected_nout = calculate_output_(def.input_size());
40 if (expected_nout != kCannotComputeNumOutputs &&
41 def.output_size() != expected_nout) {
42 LOG(ERROR) <<
"Output size " << def.output_size()
43 <<
// NOTE(review): the listing jumps from source line 43 to 50 on the next
// line; the tail of this message (presumably `<< expected_nout`) and the
// closing of the calculate_output_ branch are not shown.
" not matching expected output size, which is " 50 for (
int in_idx = 0; in_idx < def.input_size(); ++in_idx) {
51 for (
int out_idx = 0; out_idx < def.output_size(); ++out_idx) {
// Rule A: an input and output that share a blob name are rejected unless
// in-place use of this (in_idx, out_idx) pair is allowed or enforced.
54 if (def.input(in_idx) == def.output(out_idx) &&
55 (!inplace_allowed_(in_idx, out_idx)
56 && !inplace_enforced_(in_idx, out_idx))) {
57 LOG(ERROR) <<
"Input index " << in_idx <<
" and output idx " << out_idx
58 <<
" (" << def.input(in_idx) <<
")" 59 <<
" are set to be in-place but this is actually not " 60 <<
"supported by op " << def.type();
// Rule B: a pair the schema enforces as in-place must share the blob name.
63 if (def.input(in_idx) != def.output(out_idx) &&
64 inplace_enforced_(in_idx, out_idx)) {
65 LOG(ERROR) <<
"Input index " << in_idx <<
" (" << def.input(in_idx) <<
")" 66 <<
" and output idx " << out_idx
67 <<
// NOTE(review): this looks like it should be def.output(out_idx). The
// message is about output `out_idx`, but it indexes outputs with `in_idx`.
// in_idx ranges over [0, input_size()), so this reads past the last output
// whenever there are more inputs than outputs.
" (" << def.output(in_idx) <<
")" 68 <<
// NOTE(review): "90 num_inputs_allowed_ = func;" at the end of the next line
// is a fragment of a different function (source line 90; presumably the
// predicate-taking NumInputs overload) fused onto Verify's last message.
" are not in-place but should be as required by op " 90 num_inputs_allowed_ = func;
// NOTE(review): lossy listing -- these are body fragments of several setter
// overloads whose signatures, `return *this` lines and closing braces are not
// shown (gaps between source lines 97/112, 119/124, 124/129).
//
// Predicate for a set of allowed input counts: accepts n exactly when
// allowed_input_nums contains it (count(n) != 0 converts to true).
// Presumably allowed_input_nums is a std::set<int> -- TODO confirm.
96 [allowed_input_nums](
int n)->
bool {
97 return allowed_input_nums.count(n);
// Stores the output-count predicate that Verify() calls on def.output_size().
112 num_outputs_allowed_ = func;
// Same membership-test predicate, for allowed output counts.
118 [allowed_output_nums](
int n)->
bool {
119 return allowed_output_nums.count(n);
// Stores the joint (input count, output count) predicate consulted by
// Verify() and by the NumInputsOutputs member.
124 num_inputs_outputs_allowed_ = func;
// Stores the output calculator used by Verify() and CalculateOutput().
129 calculate_output_ = calc;
// NOTE(review): lossy listing -- the `return *this` lines, some `return
// AllowInplace(` openers and the closing braces/parentheses of these
// functions are not shown.
//
// In-place configuration. Verify() reads both predicates for every
// (input index, output index) pair:
//   - inplace_allowed_:  sharing a blob name is permitted;
//   - inplace_enforced_: sharing a blob name is required.
//
// Registers an arbitrary predicate for permitted in-place pairs.
137 OpSchema& OpSchema::AllowInplace(std::function<
bool(
int,
int)> inplace) {
138 inplace_allowed_ = inplace;
// Permits exactly the listed (in, out) pairs. [inplace] copies the set into
// the closure, so the predicate stays valid after this call returns.
142 OpSchema& OpSchema::AllowInplace(
set<std::pair<int, int>> inplace) {
144 [inplace](
int in,
int out)->
bool {
145 return inplace.count(std::make_pair(in, out));
// Permits input i to be computed in place into output i, for every i.
149 OpSchema& OpSchema::AllowOneToOneInplace() {
150 return AllowInplace([](
int in,
int out) {
return in == out; });
// Registers an arbitrary predicate for pairs that must be in place.
153 OpSchema& OpSchema::EnforceInplace(std::function<
bool(
int,
int)> inplace) {
154 inplace_enforced_ = inplace;
// Requires exactly the listed (in, out) pairs to share a blob; the set is
// copied into the closure as above.
158 OpSchema& OpSchema::EnforceInplace(
set<std::pair<int, int>> inplace) {
159 return EnforceInplace(
160 [inplace](
int in,
int out)->
bool {
161 return inplace.count(std::make_pair(in, out));
// Requires input i and output i to share a blob, for every i.
165 OpSchema& OpSchema::EnforceOneToOneInplace() {
166 return EnforceInplace([](
int in,
int out) {
return in == out; });
// NOTE(review): lossy listing -- function signatures (source lines 169,
// 175-176), the wrapping TensorInferenceFunction(...) calls, `return out;`
// and `return *this;` lines, and closing braces are not shown.
//
// Tail of the TensorInferenceFunction setter: stores the user-supplied
// shape-inference callback.
170 TensorInferenceFunctionType
function) {
171 tensor_inference_function_ =
function;
// Inference lambda that returns a copy of all input shapes, so output i gets
// input i's type and shape. Presumably the body of IdenticalTypeAndShape();
// its signature is not in this listing -- TODO confirm.
177 [](
const OperatorDef&,
const vector<TensorShape>& input_types) {
178 return vector<TensorShape>(input_types);
// Single output that copies input `idx`'s full shape and type.
// NOTE(review): idx is not bounds-checked against input_types.size().
182 OpSchema& OpSchema::IdenticalTypeAndShapeOfInput(
int idx) {
184 [idx](
const OperatorDef&,
const vector<TensorShape>& input_types) {
185 vector<TensorShape> out(1);
186 out[0] = input_types[idx];
// Single output with one dimension equal to dimension `dim` of input `idx`,
// and the same data type as that input.
// NOTE(review): neither idx nor dim is bounds-checked here.
191 OpSchema& OpSchema::IdenticalTypeAndShapeOfInputDim(
int idx,
int dim) {
193 [idx, dim](
const OperatorDef&,
const vector<TensorShape>& input_types) {
194 vector<TensorShape> out(1);
195 out[0].add_dims(input_types[idx].dims(dim));
196 out[0].set_data_type(input_types[idx].data_type());
// Single output with data type `dt`. In the visible lines no dims are added,
// so the output has no dimensions.
201 OpSchema& OpSchema::ScalarType(::caffe2::TensorProto_DataType dt) {
203 [dt](
const OperatorDef&,
const vector<TensorShape>& input_types) {
204 vector<TensorShape> out(1);
205 out[0].set_data_type(dt);
// NOTE(review): lossy listing -- the closing braces of the `if` blocks and
// the `return *this;` / `}` lines are not shown.
//
// Appends a (name, description) entry to arg_desc_. These entries are
// printed under "Arguments:" by operator<<.
215 OpSchema& OpSchema::Arg(
const char* name,
const char* description) {
216 arg_desc_.emplace_back(name, description);
// Records the name/description of input slot n, growing input_desc_ so that
// index n exists first.
// NOTE(review): n is never checked for n >= 0. Assuming input_desc_ is a
// std::vector, a negative n is converted to a huge unsigned value in the
// comparison below. The branch is then taken, resize(n + 1) gets a bogus
// size, and input_desc_[n] indexes out of bounds. Consider rejecting n < 0.
220 OpSchema& OpSchema::Input(
const int n,
const char* name,
const char* description) {
221 if (input_desc_.size() <= n) {
222 input_desc_.resize(n + 1);
224 input_desc_[n] = std::make_pair(name, description);
// Same as Input(), for output slot n in output_desc_; the same negative-n
// caveat applies.
228 OpSchema& OpSchema::Output(
const int n,
const char* name,
const char* description) {
229 if (output_desc_.size() <= n) {
230 output_desc_.resize(n + 1);
232 output_desc_[n] = std::make_pair(name, description);
// NOTE(review): lossy listing -- body fragment of CalculateOutput(int
// num_input). The signature (~source line 243) and the statement at source
// line 245 are missing; presumably line 245 returns the fixed output count
// when min_output_ == max_output_ -- TODO confirm.
//
// Fallback order: a fixed output count, then the registered calculator,
// otherwise the kCannotComputeNumOutputs sentinel.
244 if (min_output_ == max_output_) {
246 }
else if (calculate_output_) {
247 return calculate_output_(num_input);
249 return kCannotComputeNumOutputs;
// NOTE(review): lossy listing -- closing braces, the `else` branches paired
// with the "(no explicit description available)" lines, source lines
// 283-291 (presumably docstring printing) and `return out;` are not shown.
//
// Writes a human-readable description of the schema: an "Arguments:"
// section, then "Inputs:"/"Outputs:" sections when the schema admits any,
// and finally the "Defined at <file>:<line>" location.
253 std::ostream& operator<<(std::ostream& out,
const OpSchema& schema) {
254 if (!schema.arg_desc_.empty()) {
255 out <<
"Arguments:" << std::endl;
256 for (
const auto& it : schema.arg_desc_) {
257 out <<
" " << it.first <<
" : " << it.second << std::endl;
// Inputs are listed only when the schema admits at least one input.
260 if (schema.max_input_ > 0) {
261 out <<
"Inputs:" << std::endl;
262 if (!schema.input_desc_.empty()) {
263 for (
int i = 0; i < schema.input_desc_.size(); ++i) {
264 const auto& p = schema.input_desc_[i];
// The null checks cover slots that were never described. Presumably these
// are the gaps left when Input() grows input_desc_ past an index nobody set.
265 out <<
" " << i <<
", " << (p.first ? p.first :
"(unnamed)") <<
" : " 266 << (p.second ? p.second :
"(no doc)") << std::endl;
269 out <<
" (no explicit description available)" << std::endl;
// Outputs: same layout as the inputs section.
272 if (schema.max_output_ > 0) {
273 out <<
"Outputs:" << std::endl;
274 if (!schema.output_desc_.empty()) {
275 for (
int i = 0; i < schema.output_desc_.size(); ++i) {
276 const auto& p = schema.output_desc_[i];
277 out <<
" " << i <<
", " << (p.first ? p.first :
"(unnamed)") <<
" : " 278 << (p.second ? p.second :
"(no doc)") << std::endl;
281 out <<
" (no explicit description available)" << std::endl;
288 out <<
"(no documentation yet)" << std::endl;
// Source location where the schema was registered.
292 out <<
"Defined at " << schema.file_ <<
":" << schema.line_ << std::endl;
// NOTE(review): lossy listing -- `return map;` and the closing brace are not
// shown.
//
// Returns the registry's name -> OpSchema table. The function-local static
// is constructed on the first call, which avoids static-initialization-order
// problems for registrations made from other translation units; C++11
// guarantees that initialization is thread-safe.
297 CaffeMap<string, OpSchema>& OpSchemaRegistry::map() {
298 static CaffeMap<string, OpSchema> map;
OpSchema & NumInputs(int n)
Sets a single allowed number of inputs, n.
OpSchema & NumInputsOutputs(std::function< bool(int, int)> func)
Relationship between inputs and outputs is checked with a specified function.
OpSchema & OutputCalculator(std::function< int(int)> calc)
Set the output calculator to a user-defined function.
bool Verify(const OperatorDef &def) const
Verifies if an operator definition protobuf matches the pattern specified in the schema.
Simple registry implementation in Caffe2 that uses static variables to register object creators durin...
int CalculateOutput(int num_input) const
A function to allow one to get the number of outputs based on the number of inputs, if this schema supports it.
OpSchema & SameNumberOfOutput()
Set the number of outputs to be the same as the number of inputs.
OpSchema & TensorInferenceFunction(TensorInferenceFunctionType function)
Sets the tensor inference function, which is a std::function object defined in operator_schema.h.
OpSchema & NumOutputs(int n)
Sets a single allowed number of outputs, n.
const char * doc() const
Returns the docstring of the op schema.
A class to record the schema of an op.
OpSchema & IdenticalTypeAndShape()
Sets the tensor inference function to produce outputs with the same types and shapes as the inputs.