Caffe2 - C++ API
A deep-learning, cross-platform ML framework
mkl_operator.h
1 #ifndef CAFFE2_UTILS_MKL_OPERATOR_H_
2 #define CAFFE2_UTILS_MKL_OPERATOR_H_
3 
4 #include "caffe2/core/operator.h"
5 #include "caffe2/proto/caffe2.pb.h"
6 #include "caffe2/utils/mkl/mkl_dnn_cppwrapper.h"
7 #include "caffe2/utils/mkl/mkl_memory.h"
8 
9 namespace caffe2 {
10 
11 CAFFE_DECLARE_REGISTRY(
12  MKLOperatorRegistry,
13  OperatorBase,
14  const OperatorDef&,
15  Workspace*);
16 #define REGISTER_MKL_OPERATOR_CREATOR(key, ...) \
17  CAFFE_REGISTER_CREATOR(MKLOperatorRegistry, key, __VA_ARGS__)
18 #define REGISTER_MKL_OPERATOR(name, ...) \
19  CAFFE_REGISTER_CLASS(MKLOperatorRegistry, name, __VA_ARGS__)
20 #define REGISTER_MKL_OPERATOR_STR(str_name, ...) \
21  CAFFE_REGISTER_TYPED_CLASS(MKLOperatorRegistry, str_name, __VA_ARGS__)
22 
23 #define REGISTER_MKL_OPERATOR_WITH_ENGINE(name, engine, ...) \
24  CAFFE_REGISTER_CLASS(MKLOperatorRegistry, name##_ENGINE_##engine, __VA_ARGS__)
25 
26 namespace mkl {
27 // MKLOperator is the base scaffolding of the operators that uses MKLDNN. It
28 // provides a few operators that are useful to MKLDNN specific implementations.
29 template <typename T>
30 class MKLOperator : public OperatorBase {
31  public:
32  explicit MKLOperator(const OperatorDef& operator_def, Workspace* ws)
33  : OperatorBase(operator_def, ws) {}
34  virtual ~MKLOperator() {}
35 
36  inline const MKLMemory<T>& Input(int idx) {
37  return OperatorBase::template Input<MKLMemory<T>>(idx);
38  }
39  inline MKLMemory<T>* Output(int idx) {
40  return OperatorBase::template Output<MKLMemory<T>>(idx);
41  }
42 
43  // The run function of Operator switches to the device, and then carries out
44  // the actual computation with RunOnDevice(). You should implement RunOnDevice
45  // instead of Run().
46  bool Run(int /* unused */ stream_id) final {
47  // Since MKLDNN does not need to do SwithToDevice and
48  // FinishDeviceComputation,
49  // it is always just a re-route to RunOnDevice().
50  try {
51  return RunOnDevice();
52  } catch (EnforceNotMet& err) {
53  err.AppendMessage("Error from operator: \n" + ProtoDebugString(def()));
54  throw;
55  }
56  }
57 
58  virtual bool RunOnDevice() = 0;
59 
60  inline void ExecutePrimitive() {
61  MKLDNN_SAFE_CALL(mkl::dnnExecute<T>(primitive_, resources_));
62  }
63 
64  protected:
65  // The primitive used in the operator.
66  PrimitiveWrapper<T> primitive_;
67  // Size cache for all the input sizes.
68  vector<vector<TIndex>> input_size_cache_;
69  // An internal MKLMemory buffer. This is usually handy when we have a
70  // single output from the operator. If your operator has multiple outputs
71  // then you should allocate your own buffer.
72  MKLMemory<T> buffer_;
73  // The resources vector that we will need to use;
74  void* resources_[dnnResourceNumber];
75 };
76 } // namespace mkl
77 
// Re-exports the MKLOperator<T> helpers and protected members into a derived
// class, after first expanding USE_OPERATOR_BASE_FUNCTIONS. This is needed
// when the derived class is itself templated on T: names from the dependent
// base MKLOperator<T> are then not found by unqualified lookup. The last
// using-declaration has no trailing semicolon, so write the invocation as
// `USE_MKLOPERATOR_FUNCTIONS(T);`.
#define USE_MKLOPERATOR_FUNCTIONS(T)                                 \
  USE_OPERATOR_BASE_FUNCTIONS;                                       \
  /* using override */ using MKLOperator<T>::ExecutePrimitive;      \
  /* using override */ using MKLOperator<T>::Input;                 \
  /* using override */ using MKLOperator<T>::Output;                \
  /* using override */ using MKLOperator<T>::primitive_;            \
  /* using override */ using MKLOperator<T>::resources_;            \
  /* using override */ using MKLOperator<T>::buffer_;               \
  /* using override */ using MKLOperator<T>::input_size_cache_
87 
// Expands to the boilerplate members of an MKL operator class `ClassName`:
// a constructor that forwards its OperatorDef and Workspace* to
// MKLOperator<DataType>, and an empty virtual destructor.
#define USE_SIMPLE_MKL_CTOR_DTOR(ClassName, DataType)            \
  ClassName(const OperatorDef& operator_def, Workspace* ws)      \
      : MKLOperator<DataType>(operator_def, ws) {}               \
  virtual ~ClassName() {}
92 
93 } // namespace caffe2
94 
95 #endif // CAFFE2_UTILS_MKL_OPERATOR_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:53
Simple registry implementation in Caffe2 that uses static variables to register object creators durin...
A wrapper around an opaque MKL internal resource that has certain layouts and conversion primitives s...
Definition: mkl_memory.h:137