Caffe2 - C++ API
A deep learning, cross-platform ML framework
operator.h
1 #ifndef CAFFE2_CORE_OPERATOR_H_
2 #define CAFFE2_CORE_OPERATOR_H_
3 
4 #include <climits>
5 #include <cstddef>
6 #include <exception>
7 #include <typeinfo>
8 #include <vector>
9 
10 #include "caffe2/core/blob.h"
11 #include "caffe2/core/common.h"
12 #include "caffe2/core/net.h"
13 #include "caffe2/core/operator_gradient.h"
14 #include "caffe2/core/operator_schema.h"
15 #include "caffe2/core/registry.h"
16 #include "caffe2/core/tensor.h"
17 #include "caffe2/core/workspace.h"
18 #include "caffe2/utils/proto_utils.h"
19 #include "caffe2/proto/caffe2.pb.h"
20 
21 namespace caffe2 {
22 
23 class OperatorBase {
24  public:
25  explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws);
26  virtual ~OperatorBase() noexcept {}
27 
28  /** @brief Checks if the operator has an argument of the given name.
29   */
30  inline bool HasArgument(const string& name) const {
31  return arg_helper_.HasArgument(name);
32  }
33 
34  // Functions that deal with arguments. Basically, this allows us to map an
35  // argument name to a specific type of argument that we are trying to access.
36  template <typename T>
37  inline T GetSingleArgument(const string& name, const T& default_value) const {
38  return arg_helper_.template GetSingleArgument<T>(name, default_value);
39  }
40  template <typename T>
41  inline bool HasSingleArgumentOfType(const string& name) const {
42  return arg_helper_.template HasSingleArgumentOfType<T>(name);
43  }
44  template <typename T>
45  inline vector<T> GetRepeatedArgument(
46  const string& name,
47  const vector<T>& default_value = {}) const {
48  return arg_helper_.template GetRepeatedArgument<T>(name, default_value);
49  }
50 
51  // Get the inputs and outputs as specific types.
52  template <typename T>
53  inline const T& Input(int idx) {
54  DCHECK_LT(idx, inputs_.size());
55  try {
56  return inputs_.at(idx)->template Get<T>();
57  } catch (::caffe2::EnforceNotMet& enf) {
58  enf.AppendMessage(".\nOffending Blob name: ");
59  enf.AppendMessage(operator_def_.input(idx));
60  enf.AppendMessage(".\n");
61  throw enf;
62  }
63  }
64 
65  template <typename T>
66  inline T* Output(int idx) {
67  return outputs_.at(idx)->template GetMutable<T>();
68  }
69 
70  inline const Blob& InputBlob(int idx) {
71  return *inputs_.at(idx);
72  }
73 
74  inline Blob* OutputBlob(int idx) {
75  return outputs_.at(idx);
76  }
77 
78  template <typename T>
79  inline bool InputIsType(int idx) {
80  return inputs_.at(idx)->template IsType<T>();
81  }
82 
83  template <typename T>
84  inline bool OutputIsType(int idx) {
85  return outputs_.at(idx)->template IsType<T>();
86  }
87 
88  inline int InputSize() { return inputs_.size(); }
89  inline int OutputSize() { return outputs_.size(); }
90  inline const vector<const Blob*>& Inputs() const { return inputs_; }
91  inline const vector<Blob*>& Outputs() { return outputs_; }
92 
93  virtual bool Run(int /* unused */ stream_id = 0) {
94  CAFFE_NOT_IMPLEMENTED;
95  }
96 
97  virtual bool RunAsync(int /* unused */ stream_id = 0) {
98  return Run(stream_id);
99  }
100 
101  virtual void AddRelatedBlobInfo(EnforceNotMet* err) {
102  bool found_input = false;
103  if (err->caller() != nullptr) {
104  for (int i = 0; i < inputs_.size(); i++) {
105  if (inputs_[i]->GetRaw() == err->caller()) {
106  found_input = true;
107  err->AppendMessage("\n** while accessing input: " + def().input(i));
108  break;
109  }
110  }
111  for (int i = 0; i < outputs_.size(); i++) {
112  if (outputs_[i]->GetRaw() == err->caller()) {
113  if (found_input) {
114  err->AppendMessage("\n OR ");
115  }
116  err->AppendMessage("\n** while accessing output: " + def().output(i));
117  break;
118  }
119  }
120  }
121  }
122 
123  inline const OperatorDef& def() const {
124  return operator_def_;
125  }
126  inline const ArgumentHelper& arg_helper() const {
127  return arg_helper_;
128  }
129 
130  private:
131  OperatorDef operator_def_;
132  ArgumentHelper arg_helper_;
133  vector<const Blob*> inputs_;
134  vector<Blob*> outputs_;
135 
136  DISABLE_COPY_AND_ASSIGN(OperatorBase);
137 };
138 
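// ---------------------------------------------------------------------------
// Usage sketch (not part of this header): a minimal operator built directly on
// OperatorBase, reading one argument and one typed input. The operator name
// and its "limit" argument are hypothetical; TensorCPU is the Tensor<CPUContext>
// alias from tensor.h.
class PrintShapeOp final : public OperatorBase {
 public:
  PrintShapeOp(const OperatorDef& operator_def, Workspace* ws)
      : OperatorBase(operator_def, ws),
        limit_(GetSingleArgument<int>("limit", 1)) {}

  bool Run(int /* stream_id */ = 0) override {
    // Input<T>(i) enforces that blob i actually holds a T.
    const auto& X = Input<TensorCPU>(0);
    for (int i = 0; i < X.ndim() && i < limit_; ++i) {
      LOG(INFO) << "dim " << i << " = " << X.dim(i);
    }
    return true;
  }

 private:
  int limit_;
};
// ---------------------------------------------------------------------------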
139 // If your operator does not need any specialized constructor or destructor,
140 // you can simply use this to save two lines of code.
141 #define USE_SIMPLE_BASE_CTOR_DTOR(name) \
142  name(const OperatorDef& operator_def, Workspace* ws) \
143  : OperatorBase(operator_def, ws) {} \
144  virtual ~name() noexcept {}
145 
146 // OP_SINGLE_ARG provides a shorter way to initialize member variables in
147 // class constructors (see the sketch below).
148 #define OP_SINGLE_ARG(type, name, variable, default) \
149  variable(OperatorBase::GetSingleArgument<type>(name, (default)))
150 
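// ---------------------------------------------------------------------------
// Usage sketch (not part of this header): OP_SINGLE_ARG in a constructor
// initializer list. The operator and its "min"/"max" arguments are
// hypothetical; Run() is omitted.
class ClipValuesOp final : public OperatorBase {
 public:
  ClipValuesOp(const OperatorDef& operator_def, Workspace* ws)
      : OperatorBase(operator_def, ws),
        OP_SINGLE_ARG(float, "min", min_, 0.0f),
        OP_SINGLE_ARG(float, "max", max_, 1.0f) {}

 private:
  float min_;
  float max_;
};
// ---------------------------------------------------------------------------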
151 // INPUT_TAGS and OUTPUT_TAGS are optional features to name the indices of the
152 // operator's inputs and outputs, in order to avoid confusion. For example, for
153 // a fully connected layer that has input, weight and bias, you can define its
154 // input tags as:
155 // INPUT_TAGS(INPUT, WEIGHT, BIAS);
156 // And in the code, instead of doing
157 // auto& weight = Input(1);
158 // you can now do
159 // auto& weight = Input(WEIGHT);
160 // to make it clearer. (A fuller sketch follows the macros below.)
161 #define INPUT_TAGS(first_input, ...) \
162  enum _InputTags { first_input = 0, __VA_ARGS__ }
163 #define OUTPUT_TAGS(first_input, ...) \
164  enum _OutputTags { first_input = 0, __VA_ARGS__ }
165 
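// ---------------------------------------------------------------------------
// Usage sketch (not part of this header): the tags are declared inside the
// operator class body and are ordinary enum values, so Input(WEIGHT) is just
// Input(1). The class is hypothetical; Operator<Context> is defined below.
//
//   template <class Context>
//   class FullyConnectedOp final : public Operator<Context> {
//    public:
//     ...
//     bool RunOnDevice() override {
//       const auto& X = Input(INPUT);
//       const auto& W = Input(WEIGHT);
//       const auto& b = Input(BIAS);
//       ...
//     }
//     INPUT_TAGS(INPUT, WEIGHT, BIAS);
//     OUTPUT_TAGS(OUTPUT);
//   };
// ---------------------------------------------------------------------------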
166 
167 // Operator is the class that you usually want to derive from, if your
168 // operator will run on different devices. You should then implement the
169 // RunOnDevice() function.
170 template <class Context>
171 class Operator : public OperatorBase {
172  public:
173  explicit Operator(const OperatorDef& operator_def, Workspace* ws)
174  : OperatorBase(operator_def, ws),
175  context_(operator_def.device_option()) {
176  // In the constructor, we switch to the device so that the child class
177  // constructors will run on that device.
178  context_.SwitchToDevice(0);
179  }
180  virtual ~Operator() noexcept {}
181 
182  inline const Tensor<Context>& Input(int idx) {
183  return OperatorBase::template Input<Tensor<Context> >(idx); }
184  inline Tensor<Context>* Output(int idx) {
185  return OperatorBase::template Output<Tensor<Context> >(idx);
186  }
187 
188  // The Run() function of Operator switches to the device, and then carries out
189  // the actual computation with RunOnDevice(). You should implement RunOnDevice
190  // instead of Run().
191  bool Run(int stream_id = 0) final {
192  try {
193  context_.SwitchToDevice(stream_id);
194  bool started = RunOnDevice();
195  bool finished = context_.FinishDeviceComputation();
196  if (!finished) {
197  // FinishDeviceComputation() returning error basically means that there
198  // is something wrong with the device (like CUDA) that usually cannot be
199  // recovered, so we should log FATAL.
200  LOG(FATAL) << "Computation on device returned error in operator\n"
201  << ProtoDebugString(this->def());
202  }
203  return (started && finished);
204  } catch (EnforceNotMet& err) {
205  err.AppendMessage("Error from operator: \n" + ProtoDebugString(def()));
206  AddRelatedBlobInfo(&err);
207  throw;
208  }
209  }
210 
211  bool RunAsync(int stream_id = 0) final {
212  try {
213  context_.SwitchToDevice(stream_id);
214  return RunOnDevice();
215  } catch (EnforceNotMet& err) {
216  err.AppendMessage("Error from operator: \n" + ProtoDebugString(def()));
217  AddRelatedBlobInfo(&err);
218  throw;
219  }
220  }
221 
222  virtual bool RunOnDevice() = 0;
223 
224  protected:
225  Context context_;
226 };
227 
228 #define USE_OPERATOR_BASE_FUNCTIONS \
229  /* using override */ using OperatorBase::HasArgument; \
230  /* using override */ using OperatorBase::GetSingleArgument; \
231  /* using override */ using OperatorBase::HasSingleArgumentOfType; \
232  /* using override */ using OperatorBase::GetRepeatedArgument; \
233  /* using override */ using OperatorBase::def; \
234  /* using override */ using OperatorBase::InputIsType; \
235  /* using override */ using OperatorBase::InputSize; \
236  /* using override */ using OperatorBase::OutputSize
237 
238 #define USE_OPERATOR_FUNCTIONS(context) \
239  USE_OPERATOR_BASE_FUNCTIONS; \
240  /* using override */ using Operator<context>::context_; \
241  /* using override */ using Operator<context>::Input; \
242  /* using override */ using Operator<context>::Output
243 
244 #define USE_OPERATOR_CONTEXT_FUNCTIONS USE_OPERATOR_FUNCTIONS(Context)
245 
246 #define USE_SIMPLE_CTOR_DTOR(name) \
247  name(const OperatorDef& operator_def, Workspace* ws) \
248  : Operator<Context>(operator_def, ws) {} \
249  virtual ~name() noexcept {}
250 
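// ---------------------------------------------------------------------------
// Usage sketch (not part of this header): a minimal device-templated operator
// assembled from the pieces above. NegateOp is hypothetical, and the body
// assumes math::Scale from caffe2/utils/math.h (not included by this header).
template <class Context>
class NegateOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(NegateOp);

  bool RunOnDevice() override {
    const auto& X = Input(0);
    auto* Y = Output(0);
    Y->ResizeLike(X);
    // Y = -X, computed with whatever device Context represents.
    math::Scale<float, Context>(
        X.size(),
        -1.0f,
        X.template data<float>(),
        Y->template mutable_data<float>(),
        &context_);
    return true;
  }
};
// ---------------------------------------------------------------------------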
251 // Helpers to implement runtime op polymorphism. Often it's convenient to make
252 // an op work on different input types (e.g. i32 vs i64 indices) or to
253 // special-case it for a particular input size (e.g. ScatterWeightedSum for a
254 // block size of 1 doesn't need to call Eigen).
255 //
256 // DispatchHelper provides compile-time generation of nested "if" statements,
257 // e.g. `DispatchHelper<FixedValues<1, 4>>::call(this, block_size);`
258 // unrolls into:
259 //   if (block_size == 1) {
260 //     return DoRunWithValue<1>();
261 //   } else if (block_size == 4) {
262 //     return DoRunWithValue<4>();
263 //   } else {
264 //     return DoRunWithValue<-1>();
265 //   }
266 //
267 // The DoRunWithValue implementation can use its template argument in
268 // compile-time "if" statements, or proxy to functions in math.h, which often
269 // provide fixed-size implementations.
270 //
271 // Similarly, `DispatchHelper<TensorTypes<int32_t, int64_t>>::call(this,
272 // Input(0))` provides branching based on the type of the first input and
273 // calls DoRunWithType. (Usage sketches follow the definitions below.)
274 //
275 // Note that the same instance of the Op class is used, since it is the
276 // method, not the class, that is templated. We might consider adding static class-level polymorphism later.
277 //
278 // Convenient macro USE_DISPATCH_HELPER is provided for declaring friendship in
279 // case DoRunWithValue or DoRunWithType are declared non-public.
280 
281 #define USE_DISPATCH_HELPER \
282  template <typename FirstArg, typename... ExtraArgs> \
283  friend struct DispatchHelper
284 
285 template <int... Values>
286 struct FixedValues {};
287 
288 template <typename... Types>
289 struct TensorTypes {};
290 
291 // Special tag that can be listed in TensorTypes to denote that a special
292 // implementation in 'DoRunWithOtherType' needs to be called instead of
293 // failing. Obviously this needs to be the last item in the list, e.g.
294 // TensorTypes<float, double, GenericTensorImplementation>
295 struct GenericTensorImplementation {};
296 
297 // Same as TensorTypes but call DoRunWithType2
298 template <typename... Types>
299 struct TensorTypes2 {};
300 
301 template <typename Sizes, typename... ExtraArgs>
302 struct DispatchHelper;
303 
304 template <int FirstVal, int... Values, typename... ExtraArgs>
305 struct DispatchHelper<FixedValues<FirstVal, Values...>, ExtraArgs...> {
306  template <typename Op>
307  static bool call(Op* op, int value) {
308  if (FirstVal == value) {
309  return op->template DoRunWithValue<ExtraArgs..., FirstVal>();
310  }
311  return DispatchHelper<FixedValues<Values...>, ExtraArgs...>::template call<
312  Op>(op, value);
313  }
314 };
315 
316 template <typename... ExtraArgs>
317 struct DispatchHelper<FixedValues<>, ExtraArgs...> {
318  template <typename Op>
319  static bool call(Op* op, TIndex size) {
320  return op->template DoRunWithValue<ExtraArgs..., -1>();
321  }
322 };
323 
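// ---------------------------------------------------------------------------
// Usage sketch (not part of this header): dispatching on a fixed set of block
// sizes inside a hypothetical op. Sizes 1 and 4 get specialized instantiations;
// anything else falls through to the generic -1 case.
//
//   bool RunOnDevice() override {
//     return DispatchHelper<FixedValues<1, 4>>::call(this, block_size_);
//   }
//
//   template <int FixedSize>
//   bool DoRunWithValue() {
//     // FixedSize is 1, 4, or -1 here, usable as a compile-time constant.
//     ...
//   }
// ---------------------------------------------------------------------------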
324 #define CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER( \
325  TensorTypes, DoRunWithType, DoRunWithOtherType) \
326  template <typename FirstType, typename... Types, typename... ExtraArgs> \
327  struct DispatchHelper<TensorTypes<FirstType, Types...>, ExtraArgs...> { \
328  template <typename Op> \
329  static bool call(Op* op, const TypeMeta& meta) { \
330  static_assert( \
331  !std::is_same<GenericTensorImplementation, FirstType>::value, \
332  "GenericTensorImplementation must be the last in TensorTypes list"); \
333  if (meta.Match<FirstType>()) { \
334  return op->template DoRunWithType<ExtraArgs..., FirstType>(); \
335  } \
336  return DispatchHelper<TensorTypes<Types...>, ExtraArgs...>:: \
337  template call<Op>(op, meta); \
338  } \
339  template <typename Op, typename Context> \
340  static bool call(Op* op, const Tensor<Context>& tensor) { \
341  return call<Op>(op, tensor.meta()); \
342  } \
343  }; \
344  \
345  template <typename... ExtraArgs> \
346  struct DispatchHelper<TensorTypes<>, ExtraArgs...> { \
347  template <typename Op> \
348  static bool call(Op* /* unused */, const TypeMeta& meta) { \
349  CAFFE_THROW("Unsupported type of tensor: ", meta.name()); \
350  } \
351  template <typename Op, typename Context> \
352  static bool call(Op* op, const Tensor<Context>& tensor) { \
353  return call<Op>(op, tensor.meta()); \
354  } \
355  }; \
356  \
357  template <typename... ExtraArgs> \
358  struct DispatchHelper< \
359  TensorTypes<GenericTensorImplementation>, \
360  ExtraArgs...> { \
361  template <typename Op> \
362  static bool call(Op* op, const TypeMeta& meta) { \
363  return op->template DoRunWithOtherType<ExtraArgs...>(); \
364  } \
365  template <typename Op, typename Context> \
366  static bool call(Op* op, const Tensor<Context>& tensor) { \
367  return call<Op>(op, tensor.meta()); \
368  } \
369  };
370 CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER(
371  TensorTypes,
372  DoRunWithType,
373  DoRunWithOtherType)
374 CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER(
375  TensorTypes2,
376  DoRunWithType2,
377  DoRunWithOtherType2)
378 #undef CAFFE2_DEFINE_TENSOR_TYPES_DISPATCHER
379 
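// ---------------------------------------------------------------------------
// Usage sketch (not part of this header): branching on the dtype of Input(0)
// inside a hypothetical op.
//
//   bool RunOnDevice() override {
//     return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
//         this, Input(0));
//   }
//
//   template <typename T>
//   bool DoRunWithType() {
//     const T* indices = Input(0).template data<T>();
//     ...
//   }
//
//   USE_DISPATCH_HELPER;  // needed if DoRunWithType is not public
// ---------------------------------------------------------------------------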
380 // The device type registry. This works in two phases:
381 // (1) gDeviceTypeRegistry() maps the device type values to the actual operator
382 // registry function.
383 // (2) Then, one can call the operator registry function to further create the
384 // operators.
385 typedef Registry<std::string, OperatorBase, const OperatorDef&, Workspace*>
386     OperatorRegistry;
387 typedef Registry<std::string, OperatorBase, const OperatorDef&, Workspace*>* (
388     *RegistryFunction)();
389 std::map<int32_t, OperatorRegistry*>* gDeviceTypeRegistry();
390 
391 struct DeviceTypeRegisterer {
392   explicit DeviceTypeRegisterer(int32_t type, RegistryFunction func) {
393  if (gDeviceTypeRegistry()->count(type)) {
394  std::cerr << "Device type " << type
395  << "registered twice. This should not happen. Did you have "
396  "duplicated numbers assigned to different devices?";
397  std::exit(1);
398  }
399  // Calling the registry function to get the actual registry pointer.
400  gDeviceTypeRegistry()->emplace(type, func());
401  }
402 };
403 
404 #define CAFFE_REGISTER_DEVICE_TYPE(type, registry_function) \
405  namespace { \
406  static DeviceTypeRegisterer CAFFE_ANONYMOUS_VARIABLE( \
407  DeviceType)(type, &registry_function); \
408  }
409 
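// ---------------------------------------------------------------------------
// Usage sketch (not part of this header): caffe2's .cc files tie a device enum
// value to its registry roughly like this (exact form may differ by version):
//
//   CAFFE_REGISTER_DEVICE_TYPE(DeviceType::CPU, CPUOperatorRegistry);
// ---------------------------------------------------------------------------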
410 // The operator registry. Since we are not expecting a great number of devices,
411 // we will simply have an if-then style dispatch and delegate the actual
412 // operator creation to device-specific registries.
413 // Note that although we have CUDA and CUDNN here, the registerers themselves do
414 // not depend on specific cuda or cudnn libraries. This means that we will be
415 // able to compile it even when there is no cuda available - we simply do not
416 // link any cuda or cudnn operators.
417 CAFFE_DECLARE_REGISTRY(
418  CPUOperatorRegistry,
419  OperatorBase,
420  const OperatorDef&,
421  Workspace*);
422 #define REGISTER_CPU_OPERATOR_CREATOR(key, ...) \
423  CAFFE_REGISTER_CREATOR(CPUOperatorRegistry, key, __VA_ARGS__)
424 #define REGISTER_CPU_OPERATOR(name, ...) \
425  CAFFE_REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__)
426 #define REGISTER_CPU_OPERATOR_STR(str_name, ...) \
427  CAFFE_REGISTER_TYPED_CLASS(CPUOperatorRegistry, str_name, __VA_ARGS__)
428 
429 #define REGISTER_CPU_OPERATOR_WITH_ENGINE(name, engine, ...) \
430  CAFFE_REGISTER_CLASS(CPUOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)
431 
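// ---------------------------------------------------------------------------
// Usage sketch (not part of this header): registration normally lives in the
// operator's .cc file. The names are hypothetical; the engine variant is
// registered under the key "Negate_ENGINE_MKL".
//
//   REGISTER_CPU_OPERATOR(Negate, NegateOp<CPUContext>);
//   REGISTER_CPU_OPERATOR_WITH_ENGINE(Negate, MKL, MKLNegateOp);
// ---------------------------------------------------------------------------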
432 CAFFE_DECLARE_REGISTRY(
433  CUDAOperatorRegistry,
434  OperatorBase,
435  const OperatorDef&,
436  Workspace*);
437 #define REGISTER_CUDA_OPERATOR_CREATOR(key, ...) \
438  CAFFE_REGISTER_CREATOR(CUDAOperatorRegistry, key, __VA_ARGS__)
439 #define REGISTER_CUDA_OPERATOR(name, ...) \
440  CAFFE_REGISTER_CLASS(CUDAOperatorRegistry, name, __VA_ARGS__)
441 #define REGISTER_CUDA_OPERATOR_STR(str_name, ...) \
442  CAFFE_REGISTER_TYPED_CLASS(CUDAOperatorRegistry, str_name, __VA_ARGS__)
443 
444 #define REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, engine, ...) \
445  CAFFE_REGISTER_CLASS( \
446  CUDAOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)
447 
448 // Macros for cuDNN, since we use it often.
449 #define REGISTER_CUDNN_OPERATOR(name, ...) \
450  REGISTER_CUDA_OPERATOR_WITH_ENGINE(name, CUDNN, __VA_ARGS__)
451 
452 // An exception that can be thrown by an operator constructor that notifies
453 // that it does not support the given setting. This is usually used for
454 // specific engines that only implement a subset of the features required by
455 // the original operator schema.
456 // TODO(jiayq): make more feature-complete exception message.
457 class UnsupportedOperatorFeature : public std::exception {
458  public:
459  UnsupportedOperatorFeature(const string& msg) : msg_(msg) {}
460  const char* what() const noexcept override {
461  return msg_.c_str();
462  }
463 
464  private:
465  string msg_;
466 };
467 
468 // A helper macro that should ONLY be used in the operator constructor to check
469 // if needed features are met. If not, throws the UnsupportedOperatorFeature
470 // exception with the given message.
471 #define OPERATOR_NEEDS_FEATURE(condition, ...) \
472  if (!(condition)) { \
473  throw UnsupportedOperatorFeature(::caffe2::MakeString(__VA_ARGS__)); \
474  }
475 
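// ---------------------------------------------------------------------------
// Usage sketch (not part of this header): guarding an engine-specific operator
// constructor. The condition and message are hypothetical.
//
//   OPERATOR_NEEDS_FEATURE(
//       order_ == StorageOrder::NCHW, "This engine only supports NCHW.");
// ---------------------------------------------------------------------------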
476 // Creates an operator with the given operator definition.
477 // Throws on error and never returns nullptr.
478 unique_ptr<OperatorBase> CreateOperator(
479  const OperatorDef& operator_def, Workspace* ws);
480 
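// ---------------------------------------------------------------------------
// Usage sketch (not part of this header): building and running a single
// operator outside of a net. The "Negate" operator type is hypothetical, and
// the input would need real data before a meaningful run.
//
//   OperatorDef def;
//   def.set_type("Negate");
//   def.add_input("X");
//   def.add_output("Y");
//
//   Workspace ws;
//   ws.CreateBlob("X")->GetMutable<TensorCPU>()->Resize(4);
//   unique_ptr<OperatorBase> op = CreateOperator(def, &ws);
//   CAFFE_ENFORCE(op->Run());
// ---------------------------------------------------------------------------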
481 TensorShapes InferBlobShapesAndTypesFromWorkspace(
482  Workspace* ws,
483  const vector<std::unique_ptr<NetDef>>& nets);
484 
485 TensorShapes InferBlobShapesAndTypesFromMap(
486  const CaffeMap<std::string, std::vector<TIndex>>& blob_dimensions,
487  const vector<std::unique_ptr<NetDef>>& nets);
488 
489 } // namespace caffe2
490 
491 #endif // CAFFE2_CORE_OPERATOR_H_