Caffe2 - Python API
A deep learning, cross platform ML framework
resnet.py
1 
3 
4 from __future__ import absolute_import
5 from __future__ import division
6 from __future__ import print_function
7 
8 
9 '''
10 Utility for creating ResNets
11 See "Deep Residual Learning for Image Recognition" by He, Zhang et al., 2015
12 '''
13 
14 
class ResNetBuilder(object):
    '''
    Helper class for constructing residual blocks.

    The builder keeps track of the most recently produced blob
    (``prev_blob``) and appends operators to ``model`` one after another.
    ``model`` only needs to expose ``Conv``, ``Relu``, ``SpatialBN`` and
    ``Sum`` methods (presumably a Caffe2 CNN model helper -- confirm against
    the caller).

    Attributes:
        model: object the operators are added to.
        prev_blob: output blob of the last operator added by this builder.
        comp_count: number of residual components completed so far; used in
            blob names.
        comp_idx: index of the current layer inside the component being
            built; used in blob names.
        is_test: forwarded to every SpatialBN operator.
        spatial_bn_mom: momentum forwarded to every SpatialBN operator.
        no_bias: 1 if convolutions should be created without bias, else 0.
    '''

    def __init__(self, model, prev_blob, no_bias, is_test, spatial_bn_mom=0.9):
        self.model = model
        self.comp_count = 0
        self.comp_idx = 0
        self.prev_blob = prev_blob
        self.is_test = is_test
        self.spatial_bn_mom = spatial_bn_mom
        # Normalise any truthy/falsy value to the integer flag handed to Conv.
        self.no_bias = 1 if no_bias else 0

    def add_conv(self, in_filters, out_filters, kernel, stride=1, pad=0):
        '''
        Append a convolution on top of ``prev_blob`` and return its output.

        The layer index ``comp_idx`` is advanced first so that every
        convolution inside a component gets a distinct blob name.
        '''
        self.comp_idx += 1
        self.prev_blob = self.model.Conv(
            self.prev_blob,
            'comp_%d_conv_%d' % (self.comp_count, self.comp_idx),
            in_filters,
            out_filters,
            weight_init=("MSRAFill", {}),
            kernel=kernel,
            stride=stride,
            pad=pad,
            no_bias=self.no_bias,
        )
        return self.prev_blob

    def add_relu(self):
        '''
        Append a ReLU whose output blob is its input blob and return it.
        '''
        self.prev_blob = self.model.Relu(
            self.prev_blob,
            self.prev_blob,  # in-place
        )
        return self.prev_blob

    def add_spatial_bn(self, num_filters):
        '''
        Append a spatial batch normalization on top of ``prev_blob`` and
        return its output blob.
        '''
        self.prev_blob = self.model.SpatialBN(
            self.prev_blob,
            'comp_%d_spatbn_%d' % (self.comp_count, self.comp_idx),
            num_filters,
            epsilon=1e-3,
            momentum=self.spatial_bn_mom,
            is_test=self.is_test,
        )
        return self.prev_blob

    def _project_shortcut(
        self,
        shortcut_blob,
        in_filters,
        out_filters,
        down_sampling,
        spatial_batch_norm,
    ):
        '''
        Map the shortcut blob to ``out_filters`` feature maps with a 1x1
        convolution (strided when down-sampling), optionally followed by a
        spatial batch normalization. Returns the projected blob.
        '''
        projected = self.model.Conv(
            shortcut_blob,
            'shortcut_projection_%d' % self.comp_count,
            in_filters,
            out_filters,
            weight_init=("MSRAFill", {}),
            kernel=1,
            stride=(2 if down_sampling else 1),
            no_bias=self.no_bias,
        )
        if spatial_batch_norm:
            projected = self.model.SpatialBN(
                projected,
                'shortcut_projection_%d_spatbn' % self.comp_count,
                out_filters,
                epsilon=1e-3,
                # Same momentum as every other SpatialBN of this builder.
                momentum=self.spatial_bn_mom,
                is_test=self.is_test,
            )
        return projected

    def _merge_with_shortcut(self, shortcut_blob, last_conv):
        '''
        Sum the shortcut with the residual branch, apply ReLU and close the
        current component by advancing ``comp_count``.
        '''
        self.prev_blob = self.model.Sum(
            [shortcut_blob, last_conv],
            'comp_%d_sum_%d' % (self.comp_count, self.comp_idx)
        )
        self.comp_idx += 1
        self.add_relu()

        # Keep track of number of high level components of this ResNetBuilder
        self.comp_count += 1

    def add_bottleneck(
        self,
        input_filters,   # num of feature maps from preceding layer
        base_filters,    # num of filters internally in the component
        output_filters,  # num of feature maps to output
        down_sampling=False,
        spatial_batch_norm=True,
    ):
        '''
        Add a "bottleneck" component as described in He et al. Figure 3
        (right): 1x1 conv, 3x3 conv, 1x1 conv, summed with the shortcut.

        When ``down_sampling`` is truthy the 3x3 convolution and the
        shortcut projection use stride 2. When ``spatial_batch_norm`` is
        truthy every convolution is followed by a SpatialBN.

        NOTE(review): the shortcut is only projected when the filter counts
        differ, so ``down_sampling`` is expected to be combined with a
        change in the number of filters -- confirm with callers.
        '''
        self.comp_idx = 0
        shortcut_blob = self.prev_blob
        stride = 2 if down_sampling else 1

        # 1x1
        self.add_conv(
            input_filters,
            base_filters,
            kernel=1,
            stride=1
        )
        if spatial_batch_norm:
            self.add_spatial_bn(base_filters)
        self.add_relu()

        # 3x3 (note the pad, required for keeping dimensions)
        self.add_conv(
            base_filters,
            base_filters,
            kernel=3,
            stride=stride,
            pad=1
        )
        if spatial_batch_norm:
            self.add_spatial_bn(base_filters)
        self.add_relu()

        # 1x1
        last_conv = self.add_conv(base_filters, output_filters, kernel=1)
        if spatial_batch_norm:
            last_conv = self.add_spatial_bn(output_filters)

        # Summation with input signal (shortcut). Whenever the number of
        # feature maps differs between input and output (not only when it
        # grows) the shortcut needs a projection, otherwise the two blobs
        # fed to Sum would not have matching channel counts.
        if output_filters != input_filters:
            shortcut_blob = self._project_shortcut(
                shortcut_blob,
                input_filters,
                output_filters,
                down_sampling,
                spatial_batch_norm,
            )

        self._merge_with_shortcut(shortcut_blob, last_conv)

    def add_simple_block(
        self,
        input_filters,
        num_filters,
        down_sampling=False,
        spatial_batch_norm=True
    ):
        '''
        Add a basic residual block: two 3x3 convolutions summed with the
        shortcut.

        When ``down_sampling`` is truthy the first convolution and the
        shortcut projection use stride 2. When ``spatial_batch_norm`` is
        truthy every convolution is followed by a SpatialBN.
        '''
        self.comp_idx = 0
        shortcut_blob = self.prev_blob

        # 3x3
        self.add_conv(
            input_filters,
            num_filters,
            kernel=3,
            stride=(2 if down_sampling else 1),
            pad=1
        )
        if spatial_batch_norm:
            self.add_spatial_bn(num_filters)
        self.add_relu()

        last_conv = self.add_conv(num_filters, num_filters, kernel=3, pad=1)
        if spatial_batch_norm:
            last_conv = self.add_spatial_bn(num_filters)

        # Change of dimensions, need a projection for the shortcut
        if num_filters != input_filters:
            shortcut_blob = self._project_shortcut(
                shortcut_blob,
                input_filters,
                num_filters,
                down_sampling,
                spatial_batch_norm,
            )

        self._merge_with_shortcut(shortcut_blob, last_conv)
198 
199 
200 # The conv1 and final_avg kernel/stride args provide a basic mechanism for
201 # adapting resnet50 for different sizes of input images.
def create_resnet50(
    model,
    data,
    num_input_channels,
    num_labels,
    label=None,
    is_test=False,
    no_loss=False,
    no_bias=0,
    conv1_kernel=7,
    conv1_stride=2,
    final_avg_kernel=7,
):
    '''
    Add a 50-layer residual network to ``model`` on top of the ``data`` blob.

    Layout: conv1 + batch norm + ReLU + max-pool, four stages of bottleneck
    components (3, 4, 6 and 3 components), average pooling and a fully
    connected layer with ``num_labels`` outputs.

    Returns:
        - the FC output blob when ``no_loss`` is truthy;
        - ``(softmax, loss)`` from SoftmaxWithLoss when ``label`` is given;
        - the Softmax output blob otherwise.
    '''
    stem_bn_blob = 'conv1_spatbn_relu'

    # Stem: conv1, batch norm, in-place ReLU, then max-pooling.
    model.Conv(
        data,
        'conv1',
        num_input_channels,
        64,
        weight_init=("MSRAFill", {}),
        kernel=conv1_kernel,
        stride=conv1_stride,
        pad=3,
        no_bias=no_bias,
    )
    model.SpatialBN(
        'conv1',
        stem_bn_blob,
        64,
        epsilon=1e-3,
        momentum=0.1,
        is_test=is_test,
    )
    model.Relu(stem_bn_blob, stem_bn_blob)
    model.MaxPool(stem_bn_blob, 'pool1', kernel=3, stride=2)

    # Residual stages, ref Table 1 in He et al. (2015).
    builder = ResNetBuilder(
        model,
        'pool1',
        no_bias=no_bias,
        is_test=is_test,
        spatial_bn_mom=0.1,
    )

    # (internal filters, output filters, number of bottleneck components)
    stages = (
        (64, 256, 3),     # conv2_x
        (128, 512, 4),    # conv3_x
        (256, 1024, 6),   # conv4_x
        (512, 2048, 3),   # conv5_x
    )
    incoming = 64
    for stage_no, (inner, outgoing, repeats) in enumerate(stages):
        # The first component of a stage changes the number of feature maps;
        # every stage but the first one also down-samples there.
        builder.add_bottleneck(
            incoming, inner, outgoing, down_sampling=(stage_no > 0)
        )
        for _ in range(repeats - 1):
            builder.add_bottleneck(outgoing, inner, outgoing)
        incoming = outgoing

    # Final layers. Per the original note the "image" is reduced to 7x7 at
    # this point (presumably for the default kernel/stride settings).
    pooled = model.AveragePool(
        builder.prev_blob, 'final_avg', kernel=final_avg_kernel, stride=1,
    )
    logits = model.FC(
        pooled, 'last_out_L{}'.format(num_labels), 2048, num_labels
    )

    if no_loss:
        return logits

    if label is None:
        # For inference, we just return softmax
        return model.Softmax(logits, "softmax")

    # If we create model for training, use softmax-with-loss
    softmax, loss = model.SoftmaxWithLoss(
        [logits, label],
        ["softmax", "loss"],
    )
    return (softmax, loss)
271 
272 
def create_resnet_32x32(
    model, data, num_input_channels, num_groups, num_labels, is_test=False
):
    '''
    Create residual net for smaller images (sec 4.2 of He et al. (2015))
    num_groups = 'n' in the paper

    Adds conv1 + batch norm + ReLU, three groups of simple residual blocks
    with 16, 32 and 64 filters (``2 * num_groups`` blocks per group),
    average pooling, a fully connected layer with ``num_labels`` outputs and
    a softmax. Returns the softmax output blob.
    '''
    # conv1 (this variant has no max-pooling after the first convolution)
    model.Conv(data, 'conv1', num_input_channels, 16, kernel=3, stride=1)
    model.SpatialBN('conv1', 'conv1_spatbn', 16, epsilon=1e-3, is_test=is_test)
    model.Relu('conv1_spatbn', 'relu1')

    # Number of filters per group as described in sec 4.2
    filters = [16, 32, 64]

    # ResNetBuilder requires ``no_bias``; 0 keeps the convolution biases,
    # matching the plain Conv used for conv1 above.
    builder = ResNetBuilder(model, 'relu1', no_bias=0, is_test=is_test)
    prev_filters = 16
    for groupidx, group_filters in enumerate(filters):
        for blockidx in range(2 * num_groups):
            first_in_group = (blockidx == 0)
            builder.add_simple_block(
                prev_filters if first_in_group else group_filters,
                group_filters,
                # Only the first block of the 2nd and 3rd group down-samples.
                down_sampling=(first_in_group and groupidx > 0),
            )
        prev_filters = group_filters

    # Final layers
    model.AveragePool(builder.prev_blob, 'final_avg', kernel=8, stride=1)
    model.FC('final_avg', 'last_out', 64, num_labels)
    softmax = model.Softmax('last_out', 'softmax')
    return softmax
def add_spatial_bn(self, num_filters)
Definition: resnet.py:51
def add_conv(self, in_filters, out_filters, kernel, stride=1, pad=0)
Definition: resnet.py:29
def create_resnet_32x32(model, data, num_input_channels, num_groups, num_labels, is_test=False)
Definition: resnet.py:275
def add_relu(self)
Definition: resnet.py:44