1 #ifndef CAFFE2_CORE_OPERATOR_SCHEMA_H_ 2 #define CAFFE2_CORE_OPERATOR_SCHEMA_H_ 6 #include <initializer_list> 11 #include "caffe2/core/common.h" 12 #include "caffe2/core/registry.h" 13 #include "caffe2/proto/caffe2.pb.h" 19 constexpr
int kCannotComputeNumOutputs = -1;  // Sentinel (-1): output count not computable from the inputs. NOTE(review): presumably what CalculateOutput() returns when no calculator is set — confirm.
38 OpSchema() : file_(
"unknown"), line_(0) {}
40 : file_(file), line_(line) {}
45 inline const string&
file()
const {
return file_; }
50 inline int line()
const {
return line_; }
55 inline const char*
doc()
const {
56 return doc_.empty() ? nullptr : doc_.c_str();
// Verifies whether the operator definition protobuf `def` matches the pattern
// specified in this schema; returns the result as a bool.
63 bool Verify(
const OperatorDef& def)
const;
// In-place configuration setters; each returns OpSchema& (presumably *this, for chaining).
//   AllowInplace   - which (int, int) index pairs are permitted to run in-place.
//   EnforceInplace - which (int, int) index pairs are required to run in-place.
// Both accept either a predicate over two ints or an explicit set of int pairs.
// NOTE(review): the pair is presumably (input index, output index) — confirm in the .cc.
// Without these calls, the inplace_allowed_ / inplace_enforced_ defaults return false.
123 OpSchema& AllowInplace(std::function<
bool(
int,
int)> inplace);
124 OpSchema& AllowInplace(
set<std::pair<int, int>> inplace);
127 OpSchema& EnforceInplace(std::function<
bool(
int,
int)> inplace);
128 OpSchema& EnforceInplace(
set<std::pair<int, int>> inplace);
135 typedef std::function<
136 vector<TensorShape>(
const OperatorDef&,
const vector<TensorShape>&)>
137 TensorInferenceFunctionType;
// Tensor-inference convenience setters; each returns OpSchema& (presumably *this).
// NOTE(review): semantics below are inferred from the names only — confirm in the .cc:
//   IdenticalTypeAndShapeOfInput(idx)         - outputs take the type/shape of input `idx`.
//   IdenticalTypeAndShapeOfInputDim(idx, dim) - output derived from dimension `dim` of input `idx`.
//   ScalarType(dt)                            - outputs are scalars of data type `dt`.
148 OpSchema& IdenticalTypeAndShapeOfInput(
int idx);
149 OpSchema& IdenticalTypeAndShapeOfInputDim(
int idx,
int dim);
150 OpSchema& ScalarType(::caffe2::TensorProto_DataType dt);
157 const OperatorDef& def,
158 const vector<TensorShape> input_type_shape)
const {
159 return tensor_inference_function_(def, input_type_shape);
// Documentation setters: record a name and description for an operator argument
// (Arg), or for the n-th input (Input) / n-th output (Output).
// NOTE(review): the *_desc_ members hold raw `const char*` pairs, so if these
// setters store the pointers directly, callers must pass strings that outlive the
// schema (e.g. string literals) — confirm in the .cc.
164 OpSchema& Arg(
const char* name,
const char* description);
165 OpSchema& Input(
const int n,
const char* name,
const char* description);
166 OpSchema& Output(
const int n,
const char* name,
const char* description);
// Writes a textual rendering of `schema` to `out` (presumably returning `out`,
// per stream-operator convention). NOTE(review): the format is defined in the .cc.
177 friend std::ostream& operator<<(std::ostream& out,
const OpSchema& schema);
179 const std::vector<std::pair<const char*, const char*>>& arg_desc() {
182 const std::vector<std::pair<const char*, const char*>>& input_desc() {
185 const std::vector<std::pair<const char*, const char*>>& output_desc() {
// Documentation entries, each a pair of raw C strings (presumably name, description).
// NOTE(review): only pointers are stored, so the pointed-to strings must outlive
// this schema (string literals in practice — confirm at call sites).
std::vector<std::pair<const char*, const char*>> arg_desc_;
std::vector<std::pair<const char*, const char*>> input_desc_;
std::vector<std::pair<const char*, const char*>> output_desc_;

// Upper bounds on the input/output counts; the default (largest int) is
// effectively "unbounded".
int max_input_{std::numeric_limits<int>::max()};
int max_output_{std::numeric_limits<int>::max()};

// Count predicates; the defaults accept every count / every combination.
std::function<bool(int)> num_inputs_allowed_ = [](int) -> bool { return true; };
std::function<bool(int)> num_outputs_allowed_ = [](int) -> bool { return true; };
std::function<bool(int, int)> num_inputs_outputs_allowed_ =
    [](int, int) -> bool { return true; };

// Output-count calculator; starts empty (no calculator installed).
std::function<int(int)> calculate_output_{};

// In-place predicates over an index pair; by default nothing is allowed or
// enforced in-place.
std::function<bool(int, int)> inplace_allowed_ =
    [](int, int) -> bool { return false; };
std::function<bool(int, int)> inplace_enforced_ =
    [](int, int) -> bool { return false; };
212 TensorInferenceFunctionType tensor_inference_function_ =
213 [](
const OperatorDef& def,
const vector<TensorShape>&) {
214 vector<TensorShape> out;
215 for(
int i=0; i<def.output_size(); i++) {
217 ts.set_unknown_shape(
true);
230 const string& key,
const string&
file,
const int line) {
233 const auto& schema = m[key];
234 std::cerr <<
"Trying to register schema with name " 235 << key <<
" from file " << file <<
" line " << line
236 <<
", but it is already registered from file " 237 << schema.file() <<
" line " << schema.line();
240 m.emplace(std::make_pair(key,
OpSchema(file, line)));
244 static const OpSchema* Schema(
const string& key) {
// Returns (by reference) the registry's map from schema key — the operator
// name — to its OpSchema. NOTE(review): presumably backed by a function-local
// static defined in the .cc; confirm.
267 static CaffeMap<string, OpSchema>& map();
271 inline TensorShape CreateTensorShape(
273 ::caffe2::TensorProto_DataType dt) {
278 ts.set_data_type(dt);
283 inline vector<TIndex> GetDimsVector(
const TensorShape& shape) {
285 for (
auto d : shape.dims()) {
293 #define OPERATOR_SCHEMA(name) \ 294 static OpSchema& CAFFE_ANONYMOUS_VARIABLE(name) = \ 295 OpSchemaRegistry::NewSchema(#name, __FILE__, __LINE__) 296 #define OPERATOR_SCHEMA_STR(name) \ 297 static OpSchema& CAFFE_ANONYMOUS_VARIABLE(schema_registration) = \ 298 OpSchemaRegistry::NewSchema(name, __FILE__, __LINE__) 300 #endif // CAFFE2_CORE_OPERATOR_SCHEMA_H_ OpSchema & NumInputs(int n)
Sets a single, fixed number of inputs: exactly n.
A registry to hold all the operator schemas.
vector< TensorShape > InferTensor(const OperatorDef &def, const vector< TensorShape > input_type_shape) const
A function to allow one to infer the type and shape from the op schema.
OpSchema & NumInputsOutputs(std::function< bool(int, int)> func)
Relationship between inputs and outputs is checked with a specified function.
OpSchema & OutputCalculator(std::function< int(int)> calc)
Set the output calculator to a user-defined function.
bool Verify(const OperatorDef &def) const
Verifies if an operator definition protobuf matches the pattern specified in the schema.
Simple registry implementation in Caffe2 that uses static variables to register object creators during program initialization time.
int CalculateOutput(int num_input) const
A function to allow one to get the number of outputs based on the number of inputs, if this schema supports it.
OpSchema & SameNumberOfOutput()
Set the number of outputs to be the same as the number of inputs.
OpSchema & TensorInferenceFunction(TensorInferenceFunctionType function)
Sets the tensor inference function, which is a std::function object defined in operator_schema.h.
OpSchema & NumOutputs(int n)
Sets a single, fixed number of outputs: exactly n.
const char * doc() const
Returns the docstring of the op schema.
int line() const
Returns the line in file that the op schema is registered from.
A class to record the schema of an op.
OpSchema & IdenticalTypeAndShape()
Sets the tensor inference function to produce the same output as the input.
const string & file() const
Returns the file that the op schema is registered from.