#include <cstdio>

#include "caffe2/core/common_gpu.h"
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/pool_op.h"
#include "caffe2/cuda_rtc/common_rtc.h"

namespace caffe2 {
namespace {

// Max-pool forward kernel for NCHW layout, kept as a printf-style template:
// %s receives a unique kernel name and the %d placeholders receive the
// concrete shape and pooling parameters, so NVRTC compiles a kernel that is
// fully specialized for one input shape.
const char kMaxPoolForwardNCHWSource[] = R"(
"( 19 __global__ void %s(const float* bottom_data, float* top_data) { 20 const int nthreads = %d; 21 const int channels = %d; 22 const int height = %d; 24 const int pooled_height = %d; 25 const int pooled_width = %d; 26 const int kernel_h = %d; 27 const int kernel_w = %d; 28 const int stride_h = %d; 29 const int stride_w = %d; 32 for (int index = blockIdx.x * blockDim.x + threadIdx.x; 33 index < nthreads; index += blockDim.x * gridDim.x) { 34 int pw = index %% pooled_width; 35 int ph = (index / pooled_width) %% pooled_height; 36 int c = (index / (pooled_width * pooled_height)) %% channels; 37 int n = index / (pooled_width * pooled_height * channels); 38 int hstart = ph * stride_h - pad_t; 39 int wstart = pw * stride_w - pad_l; 40 int hend = min(hstart + kernel_h, height); 41 int wend = min(wstart + kernel_w, width); 42 hstart = max(hstart, 0); 43 wstart = max(wstart, 0); 44 float maxval = -1.0e37f; 45 const float* bdata_offset = bottom_data + n * channels * height * width; 46 for (int h = hstart; h < hend; ++h) { 47 for (int w = wstart; w < wend; ++w) { 49 bdata_offset[c * height * width + h * width + w], maxval); 52 top_data[index] = maxval; 58 const char kMaxPoolBackwardNCHWSource[] = R
"( 61 const float* const bottom_data, const float* const top_data, 62 const float* const top_diff, float* const bottom_diff) { 63 const int nthreads = %d; 65 const int channels = %d; 66 const int height = %d; 68 const int pooled_height = %d; 69 const int pooled_width = %d; 70 const int kernel_h = %d; 71 const int kernel_w = %d; 72 const int stride_h = %d; 73 const int stride_w = %d; 76 for (int index = blockIdx.x * blockDim.x + threadIdx.x; 77 index < nthreads; index += blockDim.x * gridDim.x) { 78 const int w = index %% width + pad_l; 79 const int h = (index / width) %% height + pad_t; 80 const int c = (index / width / height) %% channels; 81 const int n = index / width / height / channels; 82 const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1; 83 const int phend = min(h / stride_h + 1, pooled_height); 84 const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1; 85 const int pwend = min(w / stride_w + 1, pooled_width); 86 const int top_offset = 87 (n * channels + c) * pooled_height * pooled_width; 88 bottom_diff[index] = 0; 89 for (int ph = phstart; ph < phend; ++ph) { 90 for (int pw = pwstart; pw < pwend; ++pw) { 91 int top_local_offset = top_offset + ph * pooled_width + pw; 92 if (bottom_data[index] == top_data[top_local_offset]) { 93 bottom_diff[index] += top_diff[top_local_offset]; 102 class MaxPoolRTCFunction :
class MaxPoolRTCFunction : public CudaRTCFunction<MaxPoolRTCFunction> {
 public:
  MaxPoolRTCFunction() : CudaRTCFunction(), name_(GetUniqueName()) {}

  template <typename... Args>
  string KernelName(Args... /*args*/) {
    return name_;
  }

  template <typename... Args>
  string GetSource(Args... args);

 private:
  string name_;
};
class MaxPoolGradientRTCFunction
    : public CudaRTCFunction<MaxPoolGradientRTCFunction> {
 public:
  MaxPoolGradientRTCFunction() : CudaRTCFunction(), name_(GetUniqueName()) {}

  template <typename... Args>
  string KernelName(Args... /*args*/) {
    return name_;
  }

  template <typename... Args>
  string GetSource(Args... args);

 private:
  string name_;
};
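// CudaRTCFunction (common_rtc.h) is a CRTP base: each derived class supplies
// KernelName() and GetSource(), while the Compile() and Launch() calls used
// by the operators below are inherited from the base, which drives NVRTC
// compilation and kernel launching.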
template <>
string MaxPoolRTCFunction::GetSource(
    const int output_size,
    const int channels,
    const int height,
    const int width,
    const int pooled_height,
    const int pooled_width,
    const int kernel_h,
    const int kernel_w,
    const int stride_h,
    const int stride_w,
    const int pad_t,
    const int pad_l) {
  char buffer[65536];
  int nbytes = snprintf(
      buffer, 65536, kMaxPoolForwardNCHWSource, name_.c_str(), output_size,
      channels, height, width, pooled_height, pooled_width, kernel_h, kernel_w,
      stride_h, stride_w, pad_t, pad_l);
  DCHECK_GE(nbytes, 0);
  DCHECK_LT(nbytes, 65536);
  return string(buffer);
}
template <>
string MaxPoolGradientRTCFunction::GetSource(
    const int output_size,
    const int num,
    const int channels,
    const int height,
    const int width,
    const int pooled_height,
    const int pooled_width,
    const int kernel_h,
    const int kernel_w,
    const int stride_h,
    const int stride_w,
    const int pad_t,
    const int pad_l) {
  char buffer[65536];
  int nbytes = snprintf(
      buffer, 65536, kMaxPoolBackwardNCHWSource, name_.c_str(), output_size,
      num, channels, height, width, pooled_height, pooled_width, kernel_h,
      kernel_w, stride_h, stride_w, pad_t, pad_l);
  DCHECK_GE(nbytes, 0);
  DCHECK_LT(nbytes, 65536);
  return string(buffer);
}

}  // namespace
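// For illustration only (the actual unique name comes from GetUniqueName()):
// a call like GetSource(96, 3, 4, 4, 2, 2, 2, 2, 2, 2, 0, 0) with a
// hypothetical name "maxpool_0" would produce a kernel beginning
//   extern "C" __global__ void maxpool_0(const float* bottom_data, ...)
// with nthreads = 96, channels = 3, etc. baked in as compile-time constants,
// which is what lets NVRTC specialize and optimize per shape.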
class MaxPoolRTCOp final : public ConvPoolOpBase<CUDAContext> {
 public:
  MaxPoolRTCOp(const OperatorDef& operator_def, Workspace* ws)
      : ConvPoolOpBase<CUDAContext>(operator_def, ws) {
    CAFFE_ENFORCE_EQ(
        order_, StorageOrder::NCHW, "Currently only NCHW is supported.");
  }
  ~MaxPoolRTCOp() override {}

  bool RunOnDeviceWithOrderNCHW() override {
    auto& X = Input(0);
    auto* Y = Output(0);
    ConvPoolOpBase::SetOutputSize(X, Y, X.dim32(1));

    if (input_dims_ != X.dims()) {
      // Shape changed (or first run): regenerate the source and recompile.
      VLOG(1) << "MaxPool RTC recompiling";
      CAFFE_ENFORCE_LT(Y->size(), std::numeric_limits<int>::max());
      func_.Compile(
          static_cast<int>(Y->size()),
          X.dim32(1),
          X.dim32(2),
          X.dim32(3),
          Y->dim32(2),
          Y->dim32(3),
          kernel_h(),
          kernel_w(),
          stride_h(),
          stride_w(),
          pad_t(),
          pad_l());
      input_dims_ = X.dims();
    }
    // Carry out the pooling computation.
    func_.Launch(
        CAFFE_GET_BLOCKS(Y->size()), 1, 1, CAFFE_CUDA_NUM_THREADS,
        1, 1, 0, context_.cuda_stream(),
        X.data<float>(), Y->mutable_data<float>());
    return true;
  }

  bool RunOnDeviceWithOrderNHWC() override {
    LOG(FATAL) << "Not implemented.";
    return false;
  }

 private:
  MaxPoolRTCFunction func_;
  vector<TIndex> input_dims_;
};
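// Design note: input_dims_ caches the shape the kernel was last compiled for,
// so a net with fixed shapes pays the NVRTC compilation cost only on the
// first run; any shape change triggers a recompile on the next invocation.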
class MaxPoolGradientRTCOp final : public ConvPoolOpBase<CUDAContext> {
 public:
  MaxPoolGradientRTCOp(const OperatorDef& operator_def, Workspace* ws)
      : ConvPoolOpBase<CUDAContext>(operator_def, ws) {
    CAFFE_ENFORCE_EQ(
        order_, StorageOrder::NCHW, "Currently only NCHW is supported.");
  }
  ~MaxPoolGradientRTCOp() override {}

  bool RunOnDeviceWithOrderNCHW() override {
    auto& X = Input(0);
    auto& Y = Input(1);
    auto& dY = Input(2);
    CAFFE_ENFORCE_EQ(dY.ndim(), 4);
    auto* dX = Output(0);
    dX->ResizeLike(X);
    ConvPoolOpBase<CUDAContext>::ComputePads({X.dim32(2), X.dim32(3)});
    if (input_dims_ != X.dims()) {
      VLOG(1) << "MaxPoolGradient RTC recompiling";
      CAFFE_ENFORCE_LT(X.size(), std::numeric_limits<int>::max());
      func_.Compile(
          static_cast<int>(X.size()),
          X.dim32(0),
          X.dim32(1),
          X.dim32(2),
          X.dim32(3),
          dY.dim32(2),
          dY.dim32(3),
          kernel_h(),
          kernel_w(),
          stride_h(),
          stride_w(),
          pad_t(),
          pad_l());
      input_dims_ = X.dims();
    }
    func_.Launch(
        CAFFE_GET_BLOCKS(X.size()), 1, 1, CAFFE_CUDA_NUM_THREADS, 1, 1,
        0, context_.cuda_stream(),
        X.data<float>(), Y.data<float>(), dY.data<float>(),
        dX->mutable_data<float>());
    return true;
  }

  bool RunOnDeviceWithOrderNHWC() override {
    LOG(FATAL) << "Not implemented.";
    return false;
  }

 private:
  MaxPoolGradientRTCFunction func_;
  vector<TIndex> input_dims_;
};
REGISTER_CUDA_OPERATOR_WITH_ENGINE(MaxPool, NVRTC, MaxPoolRTCOp);
REGISTER_CUDA_OPERATOR_WITH_ENGINE(MaxPoolGradient, NVRTC, MaxPoolGradientRTCOp);

}  // namespace caffe2
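// Usage sketch (illustrative, not part of this file): these operators are
// selected through the engine field of an OperatorDef, e.g.
//
//   OperatorDef def;
//   def.set_type("MaxPool");
//   def.set_engine("NVRTC");
//   def.add_input("X");
//   def.add_output("Y");
//
// Without engine == "NVRTC", the default CUDA MaxPool from pool_op is used.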