from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# sys and numpy are used below; their import lines were lost in extraction.
import sys

import numpy as np

from libfb import pyinit
from caffe2.python import core, cnn, workspace
from caffe2.python import SparseTransformer
import caffe2.python.models.resnet as resnet
18 """Adds the data input part.""" 20 data_uint8, label_orig = model.TensorProtosDBInput(
21 [], [
"data_uint8",
"label_orig"], batch_size=batch_size,
22 db=db, db_type=db_type)
25 data = model.Cast(data_uint8,
"data_nhwc", to=core.DataType.FLOAT)
26 data = model.NHWC2NCHW(data,
"data")
27 data = model.Scale(data, data, scale=float(1. / 256))
28 data = model.StopGradient(data, data)
31 label = model.net.FlattenToVec(label_orig,
"label")
36 """Adds an accuracy op to the model""" 37 accuracy = model.Accuracy([softmax, label],
"accuracy")
42 """Adds training operators to the model.""" 43 xent = model.LabelCrossEntropy([softmax, label],
'xent')
44 loss = model.AveragedLoss(xent,
"loss")
50 model.AddGradientOperators([loss])
52 ITER = model.Iter(
"iter")
57 LR = model.LearningRate(
58 ITER,
"LR", base_lr=-0.01, policy=
"step", stepsize=15000, gamma=0.5)
61 ONE = model.param_init_net.ConstantFill([],
"ONE", shape=[1], value=1.0)
63 for param
in model.params:
66 param_grad = model.param_to_grad[param]
68 model.WeightedSum([param, ONE, param_grad, LR], param)
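
# The WeightedSum above implements plain SGD in a single op:
#     param <- 1.0 * param + LR * param_grad
# Because base_lr is negative (-0.01), the LearningRate op produces a
# negative LR, so this is the usual descent step param <- param - eta * grad.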
72 """This adds a few bookkeeping operators that we can inspect later. 74 These operators do not affect the training procedure: they only collect 75 statistics and prints them to file or to logs. 80 model.Print(
'accuracy', [], to_file=1)
81 model.Print(
'loss', [], to_file=1)
84 for param
in model.params:
85 model.Summarize(param, [], to_file=1)
86 model.Summarize(model.param_to_grad[param], [], to_file=1)


def AlexNet(model, data, args):
    # The conv-layer dimensions were lost in extraction; the values below
    # follow the standard AlexNet channel progression (64-192-384-256-256)
    # and keep the 3x3x256 pool5 output that fc6 (256 * 3 * 3 inputs)
    # requires.
    conv1 = model.Conv(
        data, "conv1", 3, 64, 5,
        ('XavierFill', {}), ('ConstantFill', {}), pad=2)
    relu1 = model.Relu(conv1, "conv1")
    pool1 = model.MaxPool(relu1, "pool1", kernel=3, stride=2)
    conv2 = model.Conv(
        pool1, "conv2", 64, 192, 5,
        ('XavierFill', {}), ('ConstantFill', {}), pad=2)
    relu2 = model.Relu(conv2, "conv2")
    pool2 = model.MaxPool(relu2, "pool2", kernel=3, stride=2)
    conv3 = model.Conv(
        pool2, "conv3", 192, 384, 3,
        ('XavierFill', {}), ('ConstantFill', {}), pad=1)
    relu3 = model.Relu(conv3, "conv3")
    conv4 = model.Conv(
        relu3, "conv4", 384, 256, 3,
        ('XavierFill', {}), ('ConstantFill', {}), pad=1)
    relu4 = model.Relu(conv4, "conv4")
    conv5 = model.Conv(
        relu4, "conv5", 256, 256, 3,
        ('XavierFill', {}), ('ConstantFill', {}), pad=1)
    relu5 = model.Relu(conv5, "conv5")
    pool5 = model.MaxPool(relu5, "pool5", kernel=3, stride=2)
    fc6 = model.FC(
        pool5, "fc6", 256 * 3 * 3, 4096,
        ('XavierFill', {}), ('ConstantFill', {}))
    relu6 = model.Relu(fc6, "fc6")
    fc7 = model.FC(
        relu6, "fc7", 4096, 4096, ('XavierFill', {}), ('ConstantFill', {}))
    relu7 = model.Relu(fc7, "fc7")
    fc8 = model.FC(
        relu7, "fc8", 4096, 10, ('XavierFill', {}), ('ConstantFill', {}))
    softmax = model.Softmax(fc8, "pred")
    return softmax


def AlexNet_Prune(model, data, args):
    # Same convolutional trunk as AlexNet above (reconstructed dimensions).
    conv1 = model.Conv(
        data, "conv1", 3, 64, 5,
        ('XavierFill', {}), ('ConstantFill', {}), pad=2)
    relu1 = model.Relu(conv1, "conv1")
    pool1 = model.MaxPool(relu1, "pool1", kernel=3, stride=2)
    conv2 = model.Conv(
        pool1, "conv2", 64, 192, 5,
        ('XavierFill', {}), ('ConstantFill', {}), pad=2)
    relu2 = model.Relu(conv2, "conv2")
    pool2 = model.MaxPool(relu2, "pool2", kernel=3, stride=2)
    conv3 = model.Conv(
        pool2, "conv3", 192, 384, 3,
        ('XavierFill', {}), ('ConstantFill', {}), pad=1)
    relu3 = model.Relu(conv3, "conv3")
    conv4 = model.Conv(
        relu3, "conv4", 384, 256, 3,
        ('XavierFill', {}), ('ConstantFill', {}), pad=1)
    relu4 = model.Relu(conv4, "conv4")
    conv5 = model.Conv(
        relu4, "conv5", 256, 256, 3,
        ('XavierFill', {}), ('ConstantFill', {}), pad=1)
    relu5 = model.Relu(conv5, "conv5")
    pool5 = model.MaxPool(relu5, "pool5", kernel=3, stride=2)
    # Passing comp_lb from the command line is an assumption: the keyword
    # list of these calls was partially lost in extraction.
    fc6 = model.FC_Prune(
        pool5, "fc6", 256 * 3 * 3, 4096,
        ('XavierFill', {}), ('ConstantFill', {}),
        threshold=args.prune_thres * 2,
        need_compress_rate=True,
        comp_lb=args.comp_lb)
    compress_fc6 = fc6[1]
    model.Print(compress_fc6, [], to_file=0)
    # FC_Prune returns (output, compress_rate); take the output blob
    # explicitly (the extracted code passed the whole pair to Relu).
    relu6 = model.Relu(fc6[0], "fc6")
    fc7 = model.FC_Prune(
        relu6, "fc7", 4096, 4096,
        ('XavierFill', {}), ('ConstantFill', {}),
        threshold=args.prune_thres,
        need_compress_rate=True,
        comp_lb=args.comp_lb)
    compress_fc7 = fc7[1]
    model.Print(compress_fc7, [], to_file=0)
    relu7 = model.Relu(fc7[0], "fc7")
    fc8 = model.FC(
        relu7, "fc8", 4096, 10, ('XavierFill', {}), ('ConstantFill', {}))
    softmax = model.Softmax(fc8, "pred")
    return softmax


def ConvBNReLUDrop(model, currentblob, outputblob,
                   input_dim, output_dim, drop_ratio=None):
    # 3x3 convolution with padding 1 (kernel/pad reconstructed: the VGG
    # blocks below must preserve spatial size between pooling layers).
    currentblob = model.Conv(
        currentblob, outputblob, input_dim, output_dim, 3,
        ('XavierFill', {}), ('ConstantFill', {}), pad=1)
    currentblob = model.SpatialBN(currentblob,
                                  str(currentblob) + '_bn',
                                  output_dim, epsilon=1e-3)
    currentblob = model.Relu(currentblob, currentblob)
    if drop_ratio:
        currentblob = model.Dropout(currentblob,
                                    str(currentblob) + '_dropout',
                                    ratio=drop_ratio)
    return currentblob
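
# Blob naming: ConvBNReLUDrop(model, data, 'conv1', 3, 64, drop_ratio=0.3)
# writes 'conv1' (Conv), then 'conv1_bn' (SpatialBN), applies Relu in place,
# and finally writes 'conv1_bn_dropout' (Dropout).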


def VGG(model, data, args):
    """Adds the VGG-like Kaggle-winner model on CIFAR-10.

    The original blog post about the model:
    http://torch.ch/blog/2015/07/30/cifar.html
    """
    conv1 = ConvBNReLUDrop(model, data, 'conv1', 3, 64, drop_ratio=0.3)
    conv2 = ConvBNReLUDrop(model, conv1, 'conv2', 64, 64)
    pool2 = model.MaxPool(conv2, 'pool2', kernel=2, stride=1)
    conv3 = ConvBNReLUDrop(model, pool2, 'conv3', 64, 128, drop_ratio=0.4)
    conv4 = ConvBNReLUDrop(model, conv3, 'conv4', 128, 128)
    pool4 = model.MaxPool(conv4, 'pool4', kernel=2, stride=2)

    conv5 = ConvBNReLUDrop(model, pool4, 'conv5', 128, 256, drop_ratio=0.4)
    conv6 = ConvBNReLUDrop(model, conv5, 'conv6', 256, 256, drop_ratio=0.4)
    conv7 = ConvBNReLUDrop(model, conv6, 'conv7', 256, 256)
    pool7 = model.MaxPool(conv7, 'pool7', kernel=2, stride=2)

    conv8 = ConvBNReLUDrop(model, pool7, 'conv8', 256, 512, drop_ratio=0.4)
    conv9 = ConvBNReLUDrop(model, conv8, 'conv9', 512, 512, drop_ratio=0.4)
    conv10 = ConvBNReLUDrop(model, conv9, 'conv10', 512, 512)
    pool10 = model.MaxPool(conv10, 'pool10', kernel=2, stride=2)

    conv11 = ConvBNReLUDrop(model, pool10, 'conv11',
                            512, 512, drop_ratio=0.4)
    conv12 = ConvBNReLUDrop(model, conv11, 'conv12',
                            512, 512, drop_ratio=0.4)
    conv13 = ConvBNReLUDrop(model, conv12, 'conv13', 512, 512)
    pool13 = model.MaxPool(conv13, 'pool13', kernel=2, stride=2)
    fc14 = model.FC(
        pool13, "fc14", 512, 512, ('XavierFill', {}), ('ConstantFill', {}))
    relu14 = model.Relu(fc14, "fc14")
    pred = model.FC(
        relu14, "pred", 512, 10, ('XavierFill', {}), ('ConstantFill', {}))
    softmax = model.Softmax(pred, 'softmax')
    return softmax


def ResNet20(model, data, args):
    """Residual net as described in section 4.2 of He et al. (2015)."""
    # num_groups=3 gives the 6n+2 = 20-layer variant (value reconstructed).
    return resnet.create_resnet_32x32(
        model, data,
        num_input_channels=3,
        num_groups=3,
        num_labels=10)


def ResNet110(model, data, args):
    """Residual net as described in section 4.2 of He et al. (2015)."""
    # num_groups=18 gives the 6n+2 = 110-layer variant (value reconstructed).
    return resnet.create_resnet_32x32(
        model, data,
        num_input_channels=3,
        num_groups=18,
        num_labels=10)


def sparse_transform(model):
    print("====================================================")
    print("                 Sparse Transformer                 ")
    print("====================================================")
    # Rewrite the net proto so that the pruned FC layers run as sparse ops.
    # NOTE: the exact SparseTransformer calls that produced `op_list` were
    # lost in extraction; the netbuilder/net2list round-trip below is an
    # assumed reconstruction of this module's API.
    net_root = SparseTransformer.netbuilder(model)
    op_list = SparseTransformer.net2list(net_root)
    del model.net.Proto().op[:]
    model.net.Proto().op.extend(op_list)


def test_sparse(test_model):
    # Rebuild and rerun the transformed net; the CreateNet/RunNet/FetchBlob
    # loop is reconstructed (the original body was partially lost).
    sparse_transform(test_model)
    workspace.CreateNet(test_model.net, overwrite=True)
    sparse_test_accuracy = np.zeros(100)
    for i in range(100):
        workspace.RunNet(test_model.net.Proto().name)
        sparse_test_accuracy[i] = workspace.FetchBlob('accuracy')
    print('Sparse Test Accuracy:')
    print(sparse_test_accuracy)
    print('sparse_test_accuracy: %f' % sparse_test_accuracy.mean())


def trainNtest(model_gen, args):
    print("Running on GPU: %s" % args.gpu)
    # Build and run the training model.
    train_model = cnn.CNNModelHelper(
        order="NCHW",
        name="Cifar_%s" % (args.model),
        cudnn_exhaustive_search=True)
    data, label = AddInput(
        train_model, batch_size=64,
        db=args.train_input_path,
        db_type=args.db_type)
    softmax = model_gen(train_model, data, args)
    # The original fourth argument to AddTrainingOperators was lost in
    # extraction; the model name is passed here as a stand-in.
    AddTrainingOperators(train_model, softmax, label, args.model)
    AddBookkeepingOperators(train_model)
    if args.gpu:
        train_model.param_init_net.RunAllOnGPU()
        train_model.net.RunAllOnGPU()
    workspace.RunNetOnce(train_model.param_init_net)
    workspace.CreateNet(train_model.net)

    # Training schedule: the original constants were lost in extraction,
    # so the values below are placeholders.
    epoch_num = 10
    epoch_iters = 1000
    record = 100
    accuracy = np.zeros(int(epoch_num * epoch_iters / record))
    loss = np.zeros(int(epoch_num * epoch_iters / record))
    for e in range(epoch_num):
        for i in range(epoch_iters):
            workspace.RunNet(train_model.net.Proto().name)
            # Record statistics every `record` iterations (guard
            # reconstructed). The index counts across epochs; the extracted
            # code used int(i / record), which would overwrite earlier
            # epochs.
            if i % record == 0:
                count = int((e * epoch_iters + i) / record)
                accuracy[count] = workspace.FetchBlob('accuracy')
                loss[count] = workspace.FetchBlob('loss')
                print('Train Loss: {}'.format(loss[count]))
                print('Train Accuracy: {}'.format(accuracy[count]))

    # Build and run the test model, sharing the trained parameters through
    # the workspace (init_params=False).
    test_model = cnn.CNNModelHelper(
        order="NCHW", name="cifar10_test", init_params=False)
    data, label = AddInput(
        test_model, batch_size=100,
        db=args.test_input_path,
        db_type=args.db_type)
    softmax = model_gen(test_model, data, args)
    AddAccuracy(test_model, softmax, label)
    if args.gpu:
        test_model.param_init_net.RunAllOnGPU()
        test_model.net.RunAllOnGPU()
    workspace.RunNetOnce(test_model.param_init_net)
    workspace.CreateNet(test_model.net)
    test_accuracy = np.zeros(100)
    for i in range(100):
        workspace.RunNet(test_model.net.Proto().name)
        test_accuracy[i] = workspace.FetchBlob('accuracy')
    print('Train Accuracy:')
    print(accuracy)
    print('Test Accuracy:')
    print(test_accuracy)
    print('test_accuracy: %f' % test_accuracy.mean())
    if args.model == 'AlexNet_Prune':
        test_sparse(test_model)


MODEL_TYPE_FUNCTIONS = {
    'AlexNet': AlexNet,
    'AlexNet_Prune': AlexNet_Prune,
    'VGG': VGG,
    'ResNet-110': ResNet110,
    'ResNet-20': ResNet20,
}


if __name__ == '__main__':
    # Caffe2 flag: keep allocated memory when tensors shrink.
    sys.argv.append('--caffe2_keep_on_shrink')
    parser = pyinit.FbcodeArgumentParser(description='cifar-10 Tutorial')
    parser.add_argument("--model", type=str, default='AlexNet',
                        choices=MODEL_TYPE_FUNCTIONS.keys(),
                        help="The model to train.")
    parser.add_argument("--prune_thres", type=float, default=0.0001,
                        help="Pruning threshold for FC layers.")
    parser.add_argument("--comp_lb", type=float, default=0.02,
                        help="Compression lower bound for FC layers.")
    # action='store_true' replaces the original `type=bool`, which argparse
    # would treat as truthy for any non-empty string.
    parser.add_argument("--gpu", action='store_true',
                        help="Whether to run on GPU.")
    parser.add_argument("--train_input_path", type=str,
                        help="Path to the database for training data.")
    parser.add_argument("--test_input_path", type=str,
                        help="Path to the database for test data.")
    parser.add_argument("--db_type", type=str, default="lmdb",
                        help="Database type.")
    args = parser.parse_args()
    trainNtest(MODEL_TYPE_FUNCTIONS[args.model], args)
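
# Example invocation (binary name and paths are placeholders):
#   cifar_tutorial --model AlexNet_Prune \
#       --train_input_path /path/to/cifar10_train_lmdb \
#       --test_input_path /path/to/cifar10_test_lmdb \
#       --db_type lmdb --gpu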