Caffe2 - C++ API
A deep learning, cross-platform ML framework
common_rtc.h
1 #ifndef CAFFE2_CUDA_RTC_COMMON_RTC_H_
2 #define CAFFE2_CUDA_RTC_COMMON_RTC_H_
3 
4 #include <cuda.h>
5 #include <nvrtc.h>
6 
// Evaluates an NVRTC call exactly once and, if it did not return
// NVRTC_SUCCESS, reports the call site's file and line together with
// nvrtcGetErrorString() through LOG(FATAL).
//
// The temporary uses a macro-specific name so that an argument expression
// mentioning a caller variable called `result` still refers to the caller's
// variable, and the argument is parenthesized so any expression is accepted.
#define NVRTC_CHECK(condition)                                          \
  do {                                                                  \
    const nvrtcResult nvrtc_check_result_ = (condition);                \
    if (nvrtc_check_result_ != NVRTC_SUCCESS) {                         \
      LOG(FATAL) << "Error at: " << __FILE__ << ":" << __LINE__ << ": " \
                 << nvrtcGetErrorString(nvrtc_check_result_);           \
    }                                                                   \
  } while (0)
15 
16 namespace caffe2 {
17 
18 template <typename Derived>
20  public:
21  CudaRTCFunction() : module_loaded_(false) {}
22  ~CudaRTCFunction() {
23  if (module_loaded_) {
24  CUDA_DRIVERAPI_ENFORCE(cuModuleUnload(module_));
25  }
26  }
27 
28  // TODO: this function is nontrivial and since CudaRTCFunction uses CRTP, it
29  // may potentially increase the binary size. In that case, move common parts
30  // into a separate function.
31  template <typename... Args>
32  void Compile(Args... args) {
33  string src = static_cast<Derived*>(this)->GetSource(args...);
34  string name = static_cast<Derived*>(this)->KernelName(args...);
35  VLOG(1) << "function name: " << name;
36  VLOG(1) << "function src:\n" << src;
37  // Actually do the compiling.
38  nvrtcProgram prog;
39  NVRTC_CHECK(nvrtcCreateProgram(
40  &prog, src.c_str(), nullptr, 0, nullptr, nullptr));
41  // Compile the program.
42  // TODO(Yangqing): how to find the current gpu architecture instead of hard
43  // coding it?
44  const char *nvrtc_opts[] = {"--gpu-architecture=compute_35",
45  "--use_fast_math"};
46  nvrtcResult compile_result = nvrtcCompileProgram(
47  prog, 2, nvrtc_opts);
48  if (compile_result != NVRTC_SUCCESS) {
49  size_t log_size;
50  NVRTC_CHECK(nvrtcGetProgramLogSize(prog, &log_size));
51  vector<char> nvrtc_log(log_size);
52  NVRTC_CHECK(nvrtcGetProgramLog(prog, nvrtc_log.data()));
53  LOG(FATAL) << "Compilation failure for nvrtc("
54  << nvrtcGetErrorString(compile_result) << "): \n"
55  << nvrtc_log.data();
56  }
57  size_t ptx_size;
58  NVRTC_CHECK(nvrtcGetPTXSize(prog, &ptx_size));
59  vector<char> nvrtc_ptx(ptx_size);
60  NVRTC_CHECK(nvrtcGetPTX(prog, nvrtc_ptx.data()));
61  NVRTC_CHECK(nvrtcDestroyProgram(&prog));
62  // After compilation, load the module.
63  if (module_loaded_) {
64  CUDA_DRIVERAPI_ENFORCE(cuModuleUnload(module_));
65  }
66  CUDA_DRIVERAPI_ENFORCE(
67  cuModuleLoadDataEx(&module_, nvrtc_ptx.data(), 0, 0, 0));
68  module_loaded_ = true;
69  CUDA_DRIVERAPI_ENFORCE(
70  cuModuleGetFunction(&kernel_, module_, name.c_str()));
71  }
72 
73  template <typename... Args>
74  void Launch(unsigned int gx, unsigned int gy, unsigned int gz,
75  unsigned int bx, unsigned int by, unsigned int bz,
76  unsigned int shared_mem, cudaStream_t stream,
77  Args... args) {
78  CAFFE_ENFORCE(
79  module_loaded_, "Cannot call Launch before a module is loaded.");
80  void * args_voidp[] = {&args...};
81  CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel(
82  kernel_, gx, gy, gz, bx, by, bz, shared_mem, stream, args_voidp, 0));
83  }
84 
85  void LaunchEx(unsigned int gx, unsigned int gy, unsigned int gz,
86  unsigned int bx, unsigned int by, unsigned int bz,
87  unsigned int shared_mem, cudaStream_t stream,
88  void** extra) {
89  CAFFE_ENFORCE(
90  module_loaded_, "Cannot call Launch before a module is loaded.");
91  CUDA_DRIVERAPI_ENFORCE(cuLaunchKernel(
92  kernel_, gx, gy, gz, bx, by, bz, shared_mem, stream, nullptr, extra));
93  }
94 
95  private:
96  bool module_loaded_;
97  CUmodule module_;
98  CUfunction kernel_;
99 };
100 
// Builds a kernel name of the form "_cuda_kernel_" followed by 20 ASCII
// letters, each picked with rand().
// TODO: this is in no way unique and is just a hack right now.
inline string GetUniqueName() {
  static const char kLetters[] =
      "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ";
  static constexpr int kSuffixLength = 20;

  string name("_cuda_kernel_");
  int remaining = kSuffixLength;
  while (remaining-- > 0) {
    // sizeof counts the trailing NUL, so subtract one to index letters only.
    name.push_back(kLetters[rand() % (sizeof(kLetters) - 1)]);
  }
  return name;
}
114 
115 } // namespace caffe2
116 
117 #endif // CAFFE2_CUDA_RTC_COMMON_RTC_H_
Simple registry implementation in Caffe2 that uses static variables to register object creators durin...