4 from __future__
import absolute_import
5 from __future__
import division
6 from __future__
import print_function
10 Utility for creating ResNets 11 See "Deep Residual Learning for Image Recognition" by He, Zhang et al. 2015 17 Helper class for constructing residual blocks. 20 def __init__(self, model, prev_blob, no_bias, is_test, spatial_bn_mom=0.9):
27 self.
no_bias = 1
if no_bias
else 0
29 def add_conv(self, in_filters, out_filters, kernel, stride=1, pad=0):
36 weight_init=(
"MSRAFill", {}),
51 def add_spatial_bn(self, num_filters):
63 Add a "bottleneck" component as described in He et al. Figure 3 (right) 71 spatial_batch_norm=True,
84 if spatial_batch_norm:
94 stride=(1
if down_sampling
is False else 2),
98 if spatial_batch_norm:
103 last_conv = self.
add_conv(base_filters, output_filters, kernel=1)
104 if spatial_batch_norm:
110 if (output_filters > input_filters):
111 shortcut_blob = self.
model.Conv(
116 weight_init=(
"MSRAFill", {}),
118 stride=(1
if down_sampling
is False else 2),
121 if spatial_batch_norm:
122 shortcut_blob = self.
model.SpatialBN(
124 'shortcut_projection_%d_spatbn' % self.
comp_count,
132 [shortcut_blob, last_conv],
141 def add_simple_block(
146 spatial_batch_norm=True
156 stride=(1
if down_sampling
is False else 2),
160 if spatial_batch_norm:
164 last_conv = self.
add_conv(num_filters, num_filters, kernel=3, pad=1)
165 if spatial_batch_norm:
169 if (num_filters != input_filters):
170 shortcut_blob = self.
model.Conv(
175 weight_init=(
"MSRAFill", {}),
177 stride=(1
if down_sampling
is False else 2),
180 if spatial_batch_norm:
181 shortcut_blob = self.
model.SpatialBN(
183 'shortcut_projection_%d_spatbn' % self.
comp_count,
190 [shortcut_blob, last_conv],
216 model.Conv(data,
'conv1', num_input_channels, 64, weight_init=(
"MSRAFill", {}),
217 kernel=conv1_kernel, stride=conv1_stride, pad=3, no_bias=no_bias)
219 model.SpatialBN(
'conv1',
'conv1_spatbn_relu', 64,
220 epsilon=1e-3, momentum=0.1, is_test=is_test)
221 model.Relu(
'conv1_spatbn_relu',
'conv1_spatbn_relu')
222 model.MaxPool(
'conv1_spatbn_relu',
'pool1', kernel=3, stride=2)
226 is_test=is_test, spatial_bn_mom=0.1)
229 builder.add_bottleneck(64, 64, 256)
230 builder.add_bottleneck(256, 64, 256)
231 builder.add_bottleneck(256, 64, 256)
234 builder.add_bottleneck(256, 128, 512, down_sampling=
True)
235 for i
in range(1, 4):
236 builder.add_bottleneck(512, 128, 512)
239 builder.add_bottleneck(512, 256, 1024, down_sampling=
True)
240 for i
in range(1, 6):
241 builder.add_bottleneck(1024, 256, 1024)
244 builder.add_bottleneck(1024, 512, 2048, down_sampling=
True)
245 builder.add_bottleneck(2048, 512, 2048)
246 builder.add_bottleneck(2048, 512, 2048)
249 final_avg = model.AveragePool(
250 builder.prev_blob,
'final_avg', kernel=final_avg_kernel, stride=1,
254 last_out = model.FC(final_avg,
'last_out_L{}'.format(num_labels),
261 if (label
is not None):
262 (softmax, loss) = model.SoftmaxWithLoss(
267 return (softmax, loss)
270 return model.Softmax(last_out,
"softmax")
274 model, data, num_input_channels, num_groups, num_labels, is_test=False
277 Create residual net for smaller images (sec 4.2 of He et al. (2015)) 278 num_groups = 'n' in the paper 281 model.Conv(data,
'conv1', num_input_channels, 16, kernel=3, stride=1)
282 model.SpatialBN(
'conv1',
'conv1_spatbn', 16, epsilon=1e-3, is_test=is_test)
283 model.Relu(
'conv1_spatbn',
'relu1')
286 filters = [16, 32, 64]
290 for groupidx
in range(0, 3):
291 for blockidx
in range(0, 2 * num_groups):
292 builder.add_simple_block(
293 prev_filters
if blockidx == 0
else filters[groupidx],
295 down_sampling=(
True if blockidx == 0
and 296 groupidx > 0
else False))
297 prev_filters = filters[groupidx]
300 model.AveragePool(builder.prev_blob,
'final_avg', kernel=8, stride=1)
301 model.FC(
'final_avg',
'last_out', 64, num_labels)
302 softmax = model.Softmax(
'last_out',
'softmax')
def add_spatial_bn(self, num_filters)
def add_conv(self, in_filters, out_filters, kernel, stride=1, pad=0)
def create_resnet_32x32(model, data, num_input_channels, num_groups, num_labels, is_test=False)