Caffe2 - C++ API
A deep learning, cross-platform ML framework
mkl_context.h
1 #ifndef CAFFE2_UTILS_MKL_CONTEXT_H_
2 #define CAFFE2_UTILS_MKL_CONTEXT_H_
3 
4 #include <cstdlib>
5 #include <ctime>
6 #include <random>
7 
8 #include "caffe2/core/context.h"
9 
10 namespace caffe2 {
11 
20 class MKLContext final {
21  public:
22  MKLContext() : random_seed_(math::randomNumberSeed()) {}
23  explicit MKLContext(const DeviceOption& option)
24  : random_seed_(
25  option.has_random_seed() ? option.random_seed()
26  : math::randomNumberSeed()) {
27  CAFFE_ENFORCE_EQ(option.device_type(), MKLDNN);
28  }
29 
30  ~MKLContext() {}
31 
32  inline void SwitchToDevice(int stream_id = 0) {}
33  inline bool FinishDeviceComputation() {
34  return true;
35  }
36 
37  inline std::mt19937& RandGenerator() {
38  if (!random_generator_.get()) {
39  random_generator_.reset(new std::mt19937(random_seed_));
40  }
41  return *random_generator_.get();
42  }
43 
44  inline static void* New(size_t nbytes) {
45  return GetCPUAllocator()->New(nbytes);
46  }
47  inline static void Delete(void* data) {
48  GetCPUAllocator()->Delete(data);
49  }
50 
51  // Two copy functions that deals with cross-device copies.
52  template <class SrcContext, class DstContext>
53  inline void CopyBytes(size_t nbytes, const void* src, void* dst);
54 
55  template <typename T, class SrcContext, class DstContext>
56  inline void Copy(size_t n, const T* src, T* dst) {
57  if (std::is_fundamental<T>::value) {
58  CopyBytes<SrcContext, DstContext>(
59  n * sizeof(T),
60  static_cast<const void*>(src),
61  static_cast<void*>(dst));
62  } else {
63  for (int i = 0; i < n; ++i) {
64  dst[i] = src[i];
65  }
66  }
67  }
68 
69  template <class SrcContext, class DstContext>
70  inline void
71  CopyItems(const TypeMeta& meta, size_t n, const void* src, void* dst) {
72  if (meta.copy()) {
73  meta.copy()(src, dst, n);
74  } else {
75  CopyBytes<SrcContext, DstContext>(n * meta.itemsize(), src, dst);
76  }
77  }
78 
79  protected:
80  // TODO(jiayq): instead of hard-coding a generator, make it more flexible.
81  int random_seed_{1701};
82  std::unique_ptr<std::mt19937> random_generator_;
83 };
84 
85 template <>
86 inline void MKLContext::CopyBytes<MKLContext, MKLContext>(
87  size_t nbytes,
88  const void* src,
89  void* dst) {
90  memcpy(dst, src, nbytes);
91 }
92 
93 template <>
94 inline void MKLContext::CopyBytes<CPUContext, MKLContext>(
95  size_t nbytes,
96  const void* src,
97  void* dst) {
98  memcpy(dst, src, nbytes);
99 }
100 
101 template <>
102 inline void MKLContext::CopyBytes<MKLContext, CPUContext>(
103  size_t nbytes,
104  const void* src,
105  void* dst) {
106  memcpy(dst, src, nbytes);
107 }
108 
109 } // namespace caffe2
110 
111 #endif // CAFFE2_UTILS_MKL_CONTEXT_H_
TypedCopy copy() const
Returns the typed copy function pointer for individual items.
Definition: typeid.h:133
The MKL Context, which is largely the same as the CPUContext.
Definition: mkl_context.h:20
TypeMeta is a thin class that allows us to store the type of a container such as a blob...
Definition: typeid.h:66
Simple registry implementation in Caffe2 that uses static variables to register object creators during program initialization time.
const size_t & itemsize() const
Returns the size of the item.
Definition: typeid.h:121