#ifndef CAFFE2_CORE_OPERATOR_H_
#define CAFFE2_CORE_OPERATOR_H_

#include "caffe2/core/blob.h"
#include "caffe2/core/common.h"
#include "caffe2/core/net.h"
#include "caffe2/core/operator_gradient.h"
#include "caffe2/core/operator_schema.h"
#include "caffe2/core/registry.h"
#include "caffe2/core/tensor.h"
#include "caffe2/core/workspace.h"
#include "caffe2/utils/proto_utils.h"
#include "caffe2/proto/caffe2.pb.h"

namespace caffe2 {

class OperatorBase {
 public:
  explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws);
  virtual ~OperatorBase() noexcept {}

  // Checks if the operator has an argument of the given name.
  inline bool HasArgument(const string& name) const {
    return arg_helper_.HasArgument(name);
  }
  // Returns the value of the argument named `name`, or `default_value` if the
  // argument is not present.
  template <typename T>
  inline T GetSingleArgument(const string& name, const T& default_value) const {
    return arg_helper_.template GetSingleArgument<T>(name, default_value);
  }
  // Checks whether an argument named `name` exists and holds a value of type T.
  template <typename T>
  inline bool HasSingleArgumentOfType(const string& name) const {
    return arg_helper_.template HasSingleArgumentOfType<T>(name);
  }
  // Returns the repeated argument named `name` as a vector, or `default_value`
  // if the argument is not present.
  template <typename T>
  inline vector<T> GetRepeatedArgument(
      const string& name,
      const vector<T>& default_value = {}) const {
    return arg_helper_.template GetRepeatedArgument<T>(name, default_value);
  }
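  // Example (illustrative sketch, not part of the original header): how a
  // derived operator might read its arguments with the helpers above. The
  // operator and argument names here are hypothetical.
  //
  //   // OperatorDef (pbtxt):  type: "MyOp"  arg { name: "scale" f: 2.0 }
  //   //                                     arg { name: "shape" ints: 1 ints: 8 }
  //   float scale = GetSingleArgument<float>("scale", 1.0f);
  //   vector<int> shape = GetRepeatedArgument<int>("shape");
  //   bool has_axis = HasArgument("axis");  // false for the def above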
  // Returns input blob `idx` as a const reference of type T. If the blob does
  // not hold a T, the thrown exception is annotated with the offending blob's
  // name before being rethrown.
  template <typename T>
  inline const T& Input(int idx) {
    DCHECK_LT(idx, inputs_.size());
    try {
      return inputs_.at(idx)->template Get<T>();
    } catch (EnforceNotMet& enf) {
      enf.AppendMessage(".\nOffending Blob name: ");
      enf.AppendMessage(operator_def_.input(idx));
      enf.AppendMessage(".\n");
      throw enf;
    }
  }
  // Returns output blob `idx` as a mutable pointer of type T, creating the
  // underlying object if necessary.
  template <typename T>
  inline T* Output(int idx) {
    return outputs_.at(idx)->template GetMutable<T>();
  }
  inline const Blob& InputBlob(int idx) {
    return *inputs_.at(idx);
  }
  inline Blob* OutputBlob(int idx) {
    return outputs_.at(idx);
  }
  template <typename T>
  inline bool InputIsType(int idx) {
    return inputs_.at(idx)->template IsType<T>();
  }
  template <typename T>
  inline bool OutputIsType(int idx) {
    return outputs_.at(idx)->template IsType<T>();
  }
  inline int InputSize() { return inputs_.size(); }
  inline int OutputSize() { return outputs_.size(); }
  inline const vector<const Blob*>& Inputs() const { return inputs_; }
  inline const vector<Blob*>& Outputs() { return outputs_; }
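  // Example (illustrative sketch, not part of the original header): typed
  // access to input/output blobs from inside a derived operator. It assumes
  // the TensorCPU alias from tensor.h; the error message text is hypothetical.
  //
  //   CAFFE_ENFORCE(InputSize() == 1, "This op expects exactly one input.");
  //   if (InputIsType<TensorCPU>(0)) {
  //     const auto& X = Input<TensorCPU>(0);  // annotates blob name on type mismatch
  //     auto* Y = Output<TensorCPU>(0);       // creates the TensorCPU if needed
  //     Y->ResizeLike(X);
  //   }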
  // The base Run() is not implemented; concrete operators (or Operator<Context>
  // below) are expected to override it.
  virtual bool Run(int stream_id = 0) {
    CAFFE_NOT_IMPLEMENTED;
  }

  virtual bool RunAsync(int stream_id = 0) {
    return Run(stream_id);
  }
  // Annotates an exception with the input/output blob (if any) that caused it.
  virtual void AddRelatedBlobInfo(EnforceNotMet* err) {
    if (err->caller() != nullptr) {
      bool found_input = false;
      for (int i = 0; i < inputs_.size(); i++) {
        if (inputs_[i]->GetRaw() == err->caller()) {
          found_input = true;
          err->AppendMessage(
              "\n** while accessing input: " + def().input(i));
          break;
        }
      }
      for (int i = 0; i < outputs_.size(); i++) {
        if (outputs_[i]->GetRaw() == err->caller()) {
          if (found_input) {
            err->AppendMessage("\n OR ");
          }
          err->AppendMessage(
              "\n** while accessing output: " + def().output(i));
          break;
        }
      }
    }
  }
  inline const OperatorDef& def() const {
    return operator_def_;
  }
 private:
  OperatorDef operator_def_;
  ArgumentHelper arg_helper_;
  vector<const Blob*> inputs_;
  vector<Blob*> outputs_;
};
// If your operator subclasses OperatorBase directly and needs no extra
// constructor work, this macro declares a trivial constructor/destructor.
#define USE_SIMPLE_BASE_CTOR_DTOR(name)                 \
  name(const OperatorDef& operator_def, Workspace* ws)  \
      : OperatorBase(operator_def, ws) {}               \
  virtual ~name() noexcept {}

// Initializes a member variable from a single argument in the OperatorDef;
// intended for use in constructor initializer lists.
#define OP_SINGLE_ARG(type, name, variable, default) \
  variable(OperatorBase::GetSingleArgument<type>(name, (default)))

// INPUT_TAGS and OUTPUT_TAGS declare enums so that inputs and outputs can be
// referred to by name instead of raw indices.
#define INPUT_TAGS(first_input, ...) \
  enum _InputTags { first_input = 0, __VA_ARGS__ }
#define OUTPUT_TAGS(first_input, ...) \
  enum _OutputTags { first_input = 0, __VA_ARGS__ }

// Operator is the class you usually want to derive from: it is templated on a
// Context (e.g. CPUContext or CUDAContext) and provides typed tensor accessors
// plus device switching around RunOnDevice().
template <class Context>
class Operator : public OperatorBase {
 public:
  explicit Operator(const OperatorDef& operator_def, Workspace* ws)
      : OperatorBase(operator_def, ws),
        context_(operator_def.device_option()) {
    // Switch to the operator's device so that member construction in derived
    // classes runs on the right device.
    context_.SwitchToDevice(0);
  }
  inline const Tensor<Context>& Input(int idx) {
    return OperatorBase::template Input<Tensor<Context>>(idx);
  }
  inline Tensor<Context>* Output(int idx) {
    return OperatorBase::template Output<Tensor<Context>>(idx);
  }
  // Run() switches to the device, carries out the computation with
  // RunOnDevice(), and then waits for the device to finish. Derived classes
  // should implement RunOnDevice() rather than overriding Run().
  bool Run(int stream_id = 0) final {
    try {
      context_.SwitchToDevice(stream_id);
      bool started = RunOnDevice();
      bool finished = context_.FinishDeviceComputation();
      if (!finished) {
        // Failing to finish device computation usually signals an
        // unrecoverable device error, so abort.
        LOG(FATAL) << "Computation on device returned error in operator\n"
                   << ProtoDebugString(this->def());
      }
      return (started && finished);
    } catch (EnforceNotMet& err) {
      err.AppendMessage("Error from operator: \n" + ProtoDebugString(def()));
      AddRelatedBlobInfo(&err);
      throw;
    }
  }
  // RunAsync() switches to the device and launches the computation without
  // waiting for it to finish.
  bool RunAsync(int stream_id = 0) final {
    try {
      context_.SwitchToDevice(stream_id);
      return RunOnDevice();
    } catch (EnforceNotMet& err) {
      err.AppendMessage("Error from operator: \n" + ProtoDebugString(def()));
      AddRelatedBlobInfo(&err);
      throw;
    }
  }
  // The actual per-device computation; every concrete operator must implement it.
  virtual bool RunOnDevice() = 0;

 protected:
  Context context_;
};
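// Example (illustrative sketch, not part of the original header): a minimal
// CPU operator that scales its single input by a float argument. The operator
// name "Scale2" and its "scale" argument are hypothetical; the convenience
// macros it uses (USE_OPERATOR_FUNCTIONS, OP_SINGLE_ARG, REGISTER_CPU_OPERATOR)
// are defined just below and later in this header.
//
//   class Scale2Op final : public Operator<CPUContext> {
//    public:
//     USE_OPERATOR_FUNCTIONS(CPUContext);
//     Scale2Op(const OperatorDef& operator_def, Workspace* ws)
//         : Operator<CPUContext>(operator_def, ws),
//           OP_SINGLE_ARG(float, "scale", scale_, 1.0f) {}
//
//     bool RunOnDevice() override {
//       const auto& X = Input(0);   // const TensorCPU&
//       auto* Y = Output(0);        // TensorCPU*
//       Y->ResizeLike(X);
//       const float* x = X.data<float>();
//       float* y = Y->mutable_data<float>();
//       for (TIndex i = 0; i < X.size(); ++i) {
//         y[i] = scale_ * x[i];
//       }
//       return true;
//     }
//
//    private:
//     float scale_;
//   };
//
//   // In a .cc file:
//   REGISTER_CPU_OPERATOR(Scale2, Scale2Op);
//   OPERATOR_SCHEMA(Scale2).NumInputs(1).NumOutputs(1);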
// Convenience macros that pull the commonly used OperatorBase / Operator
// members into the scope of a derived operator class.
#define USE_OPERATOR_BASE_FUNCTIONS            \
  using OperatorBase::HasArgument;             \
  using OperatorBase::GetSingleArgument;       \
  using OperatorBase::HasSingleArgumentOfType; \
  using OperatorBase::GetRepeatedArgument;     \
  using OperatorBase::def;                     \
  using OperatorBase::InputIsType;             \
  using OperatorBase::InputSize;               \
  using OperatorBase::OutputSize

#define USE_OPERATOR_FUNCTIONS(context) \
  USE_OPERATOR_BASE_FUNCTIONS;          \
  using Operator<context>::context_;    \
  using Operator<context>::Input;       \
  using Operator<context>::Output

#define USE_OPERATOR_CONTEXT_FUNCTIONS USE_OPERATOR_FUNCTIONS(Context)

#define USE_SIMPLE_CTOR_DTOR(name)                      \
  name(const OperatorDef& operator_def, Workspace* ws)  \
      : Operator<Context>(operator_def, ws) {}          \
  virtual ~name() noexcept {}

// Granting DispatchHelper friendship lets it call an operator's private
// DoRunWithType / DoRunWithValue implementations.
#define USE_DISPATCH_HELPER                           \
  template <typename FirstArg, typename... ExtraArgs> \
  friend struct DispatchHelper

// A compile-time list of integer values to dispatch on.
template <int... Values>
struct FixedValues {};
// A compile-time list of tensor element types to dispatch on.
template <typename... Types>
struct TensorTypes {};

// Special tag: when listed last in TensorTypes, unmatched types are routed to
// DoRunWithOtherType instead of failing.
struct GenericTensorImplementation {};

// Same as TensorTypes, but dispatches to DoRunWithType2 / DoRunWithOtherType2.
template <typename... Types>
struct TensorTypes2 {};
// DispatchHelper is specialized below for FixedValues<...> and, via the macro
// further down, for TensorTypes<...>.
template <typename Sizes, typename... ExtraArgs>
struct DispatchHelper;
template <int FirstVal, int... Values, typename... ExtraArgs>
struct DispatchHelper<FixedValues<FirstVal, Values...>, ExtraArgs...> {
  template <typename Op>
  static bool call(Op* op, int value) {
    if (FirstVal == value) {
      return op->template DoRunWithValue<ExtraArgs..., FirstVal>();
    }
    return DispatchHelper<FixedValues<Values...>, ExtraArgs...>::
        template call<Op>(op, value);
  }
};
template <typename... ExtraArgs>
struct DispatchHelper<FixedValues<>, ExtraArgs...> {
  template <typename Op>
  static bool call(Op* op, TIndex size) {
    // No value matched: fall back to the catch-all instantiation.
    return op->template DoRunWithValue<ExtraArgs..., -1>();
  }
};
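// Example (illustrative sketch, not part of the original header): dispatching
// on a small set of compile-time integer values. The operator name is
// hypothetical; DoRunWithValue<-1> handles any value not in the list.
//
//   template <class Context>
//   class UnrollOp final : public Operator<Context> {
//    public:
//     USE_OPERATOR_CONTEXT_FUNCTIONS;
//     USE_SIMPLE_CTOR_DTOR(UnrollOp);
//     USE_DISPATCH_HELPER;
//
//     bool RunOnDevice() override {
//       // Picks DoRunWithValue<2>, <4>, or <8> when ndim matches, else <-1>.
//       return DispatchHelper<FixedValues<2, 4, 8>>::call(this, Input(0).ndim());
//     }
//
//    private:
//     template <int N>  // N == -1 is the generic fallback
//     bool DoRunWithValue() { /* kernel specialized for N dimensions */ return true; }
//   };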
// Defines DispatchHelper specializations for a TensorTypes-style list: walk
// the list at compile time, compare each entry against the runtime TypeMeta,
// and call the matching DoRunWithType instantiation.
#define CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER(                                 \
    TensorTypes, DoRunWithType, DoRunWithOtherType)                            \
  template <typename FirstType, typename... Types, typename... ExtraArgs>      \
  struct DispatchHelper<TensorTypes<FirstType, Types...>, ExtraArgs...> {      \
    template <typename Op>                                                     \
    static bool call(Op* op, const TypeMeta& meta) {                           \
      static_assert(                                                           \
          !std::is_same<GenericTensorImplementation, FirstType>::value,        \
          "GenericTensorImplementation must be the last in TensorTypes list"); \
      if (meta.Match<FirstType>()) {                                           \
        return op->template DoRunWithType<ExtraArgs..., FirstType>();          \
      }                                                                        \
      return DispatchHelper<TensorTypes<Types...>, ExtraArgs...>::             \
          template call<Op>(op, meta);                                         \
    }                                                                          \
    template <typename Op, typename Context>                                   \
    static bool call(Op* op, const Tensor<Context>& tensor) {                  \
      return call<Op>(op, tensor.meta());                                      \
    }                                                                          \
  };                                                                           \
  template <typename... ExtraArgs>                                             \
  struct DispatchHelper<TensorTypes<>, ExtraArgs...> {                         \
    template <typename Op>                                                     \
    static bool call(Op* /* unused */, const TypeMeta& meta) {                 \
      CAFFE_THROW("Unsupported type of tensor: ", meta.name());                \
    }                                                                          \
    template <typename Op, typename Context>                                   \
    static bool call(Op* op, const Tensor<Context>& tensor) {                  \
      return call<Op>(op, tensor.meta());                                      \
    }                                                                          \
  };                                                                           \
  template <typename... ExtraArgs>                                             \
  struct DispatchHelper<                                                       \
      TensorTypes<GenericTensorImplementation>,                                \
      ExtraArgs...> {                                                          \
    template <typename Op>                                                     \
    static bool call(Op* op, const TypeMeta& meta) {                           \
      return op->template DoRunWithOtherType<ExtraArgs...>();                  \
    }                                                                          \
    template <typename Op, typename Context>                                   \
    static bool call(Op* op, const Tensor<Context>& tensor) {                  \
      return call<Op>(op, tensor.meta());                                      \
    }                                                                          \
  };

CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER(
    TensorTypes,
    DoRunWithType,
    DoRunWithOtherType)
CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER(
    TensorTypes2,
    DoRunWithType2,
    DoRunWithOtherType2)
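// Example (illustrative sketch, not part of the original header): dispatching
// on the runtime element type of a tensor. The operator name is hypothetical.
// Appending GenericTensorImplementation to the type list would instead route
// unmatched dtypes to DoRunWithOtherType.
//
//   template <class Context>
//   class LogTypeOp final : public Operator<Context> {
//    public:
//     USE_OPERATOR_CONTEXT_FUNCTIONS;
//     USE_SIMPLE_CTOR_DTOR(LogTypeOp);
//     USE_DISPATCH_HELPER;
//
//     bool RunOnDevice() override {
//       // Calls DoRunWithType<float>() or DoRunWithType<int>() based on the
//       // dtype of Input(0); any other dtype throws "Unsupported type".
//       return DispatchHelper<TensorTypes<float, int>>::call(this, Input(0));
//     }
//
//    private:
//     template <typename T>
//     bool DoRunWithType() { /* type-specific implementation */ return true; }
//   };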
#undef CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER

// The operator registries: each device type has its own registry, keyed by the
// operator name string.
typedef Registry<std::string, OperatorBase, const OperatorDef&, Workspace*>
    OperatorRegistry;
typedef Registry<std::string, OperatorBase, const OperatorDef&, Workspace*>* (
    *RegistryFunction)();

// Returns the map from device type (the DeviceType proto value) to the
// operator registry for that device.
std::map<int32_t, OperatorRegistry*>* gDeviceTypeRegistry();
struct DeviceTypeRegisterer {
  explicit DeviceTypeRegisterer(int32_t type, RegistryFunction func) {
    if (gDeviceTypeRegistry()->count(type)) {
      std::cerr << "Device type " << type
                << " registered twice. This should not happen. Did you have "
                   "duplicated numbers assigned to different devices?";
      std::exit(1);
    }
    // Call the registry function to get the actual registry pointer.
    gDeviceTypeRegistry()->emplace(type, func());
  }
};
#define CAFFE_REGISTER_DEVICE_TYPE(type, registry_function) \
  namespace {                                                \
  static DeviceTypeRegisterer CAFFE_ANONYMOUS_VARIABLE(      \
      DeviceType)(type, &registry_function);                 \
  }

// The CPU operator registry.
CAFFE_DECLARE_REGISTRY(
    CPUOperatorRegistry,
    OperatorBase,
    const OperatorDef&,
    Workspace*);
#define REGISTER_CPU_OPERATOR_CREATOR(key, ...) \
  CAFFE_REGISTER_CREATOR(CPUOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_CPU_OPERATOR(name, ...) \
  CAFFE_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__)
#define REGISTER_CPU_OPERATOR_STR(str_name, ...) \
  CAFFE_REGISTER_TYPED_CLASS(CPUOperatorRegistry, str_name, __VA_ARGS__)

#define REGISTER_CPU_OPERATOR_WITH_ENGINE(name, engine, ...) \
  CAFFE_REGISTER_CLASS(CPUOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)

// The CUDA operator registry.
CAFFE_DECLARE_REGISTRY(
    CUDAOperatorRegistry,
    OperatorBase,
    const OperatorDef&,
    Workspace*);
#define REGISTER_CUDA_OPERATOR_CREATOR(key, ...) \
  CAFFE_REGISTER_CREATOR(CUDAOperatorRegistry, key, __VA_ARGS__)
#define REGISTER_CUDA_OPERATOR(name, ...) \
  CAFFE_REGISTER_CLASS(CUDAOperatorRegistry, name, __VA_ARGS__)
#define REGISTER_CUDA_OPERATOR_STR(str_name, ...) \
  CAFFE_REGISTER_TYPED_CLASS(CUDAOperatorRegistry, str_name, __VA_ARGS__)

#define REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, engine, ...) \
  CAFFE_REGISTER_CLASS(                                       \
      CUDAOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)

// cuDNN operators are registered as CUDA operators with the CUDNN engine.
#define REGISTER_CUDNN_OPERATOR(name, ...) \
  REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, CUDNN, __VA_ARGS__)

// An exception thrown by an operator constructor to signal that it does not
// support a given setting (see OPERATOR_NEEDS_FEATURE below).
class UnsupportedOperatorFeature : public std::exception {
 public:
  UnsupportedOperatorFeature(const string& msg) : msg_(msg) {}
  const char* what() const noexcept override {
    return msg_.c_str();
  }

 private:
  string msg_;
};
// Convenience check for operator constructors: throws UnsupportedOperatorFeature
// when `condition` does not hold.
#define OPERATOR_NEEDS_FEATURE(condition, ...)                           \
  if (!(condition)) {                                                    \
    throw UnsupportedOperatorFeature(::caffe2::MakeString(__VA_ARGS__)); \
  }

// Creates an operator from an OperatorDef, using the registry that matches the
// def's device option.
unique_ptr<OperatorBase> CreateOperator(
    const OperatorDef& operator_def,
    Workspace* ws);
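// Example (illustrative sketch, not part of the original header): rejecting an
// unsupported setting from a constructor, and instantiating an operator from a
// definition. The operator type "Scale2", the `order_` member, and the blob
// names are hypothetical.
//
//   // Inside an operator constructor:
//   OPERATOR_NEEDS_FEATURE(
//       order_ == StorageOrder::NCHW, "This engine only supports NCHW order.");
//
//   // Creating an operator; the device is taken from def's device_option:
//   OperatorDef def;
//   def.set_type("Scale2");
//   def.add_input("X");
//   def.add_output("Y");
//   Workspace ws;
//   // ... fill blob "X" in ws with a float TensorCPU ...
//   unique_ptr<OperatorBase> op = CreateOperator(def, &ws);
//   op->Run();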
// Infers the shapes and types of all blobs produced by the given nets, using
// the blobs already present in the workspace as the starting point.
TensorShapes InferBlobShapesAndTypesFromWorkspace(
    Workspace* ws,
    const vector<std::unique_ptr<NetDef>>& nets);

// Same as above, but the starting blob dimensions are supplied explicitly.
TensorShapes InferBlobShapesAndTypesFromMap(
    const CaffeMap<std::string, std::vector<TIndex>>& blob_dimensions,
    const vector<std::unique_ptr<NetDef>>& nets);
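// Example (illustrative sketch, not part of the original header): inferring
// blob shapes for a net from explicitly supplied input dimensions, without
// running it. The blob name and dimensions are hypothetical.
//
//   NetDef net;  // the net to analyze (assumed filled in elsewhere)
//   CaffeMap<std::string, std::vector<TIndex>> dims;
//   dims["X"] = {64, 3, 224, 224};
//   vector<std::unique_ptr<NetDef>> nets;
//   nets.emplace_back(new NetDef(net));
//   TensorShapes shapes = InferBlobShapesAndTypesFromMap(dims, nets);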
}  // namespace caffe2

#endif  // CAFFE2_CORE_OPERATOR_H_