from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import logging
import time

from caffe2.python import core, workspace, experiment_util, \
    data_parallel_model, dyndep
from caffe2.python import timeout_guard, cnn

import caffe2.python.models.resnet as resnet
'''
Parallelized multi-GPU distributed trainer for Resnet 50. Can be used to train
on imagenet data, for example.

To run the trainer in single-machine multi-gpu mode by setting num_shards = 1.

To run the trainer in multi-machine multi-gpu mode with M machines,
run the same program on all machines, specifying num_shards = M, and
shard_id = a unique integer in the set [0, M-1].

For rendezvous (the trainer processes have to know about each other),
you can either use a directory path that is visible to all processes
(e.g. NFS directory), or use a Redis instance. Use the former by
passing the `file_store_path` argument. Use the latter by passing the
`redis_host` and `redis_port` arguments.
'''

# NOTE(review): lines missing from the source just above this point likely
# initialized dynamically-loaded op libraries (`dyndep` is imported but not
# visibly used) -- verify against the original file.

# Module-level logger for the trainer; DEBUG so per-iteration timing is shown.
log = logging.getLogger("resnet50_trainer")
log.setLevel(logging.DEBUG)
def AddImageInput(model, reader, batch_size, img_size):
    '''
    Image input operator that loads data from reader and
    applies certain transformations to the images.

    Args:
        model: a caffe2 model helper to add the input ops to.
        reader: DB reader blob produced by model.CreateDB.
        batch_size: number of images to load per device per iteration.
        img_size: square crop size applied to each image.
    '''
    # NOTE(review): the transformation kwargs below were on lines missing from
    # the source (only batch_size was visible); reconstructed -- verify
    # against the original file.
    data, label = model.ImageInput(
        reader,
        ["data", "label"],
        batch_size=batch_size,
        use_caffe_datum=True,
        mean=128.,
        std=128.,
        scale=256,
        crop=img_size,
        mirror=1,
    )

    # Images are inputs, not learned parameters: block backprop into them.
    data = model.StopGradient(data, data)
def AddMomentumParameterUpdate(train_model, LR):
    '''
    Add the momentum-SGD update.

    Args:
        train_model: model helper whose parameters get momentum-SGD updates.
        LR: learning-rate blob consumed by MomentumSGDUpdate.
    '''
    params = train_model.GetParams()
    assert(len(params) > 0)

    for param in params:
        param_grad = train_model.param_to_grad[param]
        # One momentum accumulator blob per parameter, initialized to zero.
        param_momentum = train_model.param_init_net.ConstantFill(
            [param], param + '_momentum', value=0.0
        )

        # Update param, its gradient and its momentum in place.
        # NOTE(review): momentum/nesterov hyperparameters were on lines
        # missing from the source; values reconstructed -- verify against
        # the original file.
        train_model.net.MomentumSGDUpdate(
            [param_grad, param_momentum, LR, param],
            [param_grad, param_momentum, param],
            momentum=0.9,
            nesterov=1,
        )
def RunEpoch(
    args,
    epoch,
    train_model,
    test_model,
    total_batch_size,
    num_shards,
    expname,
    explog
):
    '''
    Run one epoch of the trainer.
    TODO: add checkpointing here.

    Returns the next epoch index (epoch + 1).
    '''
    log.info("Starting epoch {}/{}".format(epoch, args.num_epochs))
    epoch_iters = int(args.epoch_size / total_batch_size / num_shards)
    for i in range(epoch_iters):
        # First iteration compiles/initializes nets and is much slower, so it
        # gets a larger watchdog timeout than steady-state iterations.
        timeout = 600.0 if i == 0 else 60.0
        # NOTE(review): the guarded net-run below was on lines missing from
        # the source; reconstructed -- verify against the original file.
        with timeout_guard.CompleteInTimeOrDie(timeout):
            t1 = time.time()
            workspace.RunNet(train_model.net.Proto().name)
            dt = time.time() - t1

        fmt = "Finished iteration {}/{} of epoch {} ({:.2f} images/sec)"
        log.info(fmt.format(i + 1, epoch_iters, epoch, total_batch_size / dt))

    num_images = epoch * epoch_iters * total_batch_size
    # Report metrics from the first device only; all devices hold synced nets.
    prefix = "gpu_{}".format(train_model._devices[0])
    # NOTE(review): exact blob names fetched here were partly missing from
    # the source; 'accuracy'/'loss'/'LR' reconstructed -- verify.
    accuracy = workspace.FetchBlob(prefix + '/accuracy')
    loss = workspace.FetchBlob(prefix + '/loss')
    learning_rate = workspace.FetchBlob(prefix + '/LR')
    test_accuracy = 0
    if (test_model is not None):
        # Run 100 test iterations and average accuracy over all devices.
        ntests = 0
        for _ in range(0, 100):
            workspace.RunNet(test_model.net.Proto().name)
            for g in test_model._devices:
                test_accuracy += workspace.FetchBlob(
                    "gpu_{}".format(g) + '/accuracy'
                )
                ntests += 1
        test_accuracy /= ntests
    else:
        # Sentinel meaning "no test set supplied".
        test_accuracy = (-1)

    explog.log(
        input_count=num_images,
        batch_count=(i + epoch * epoch_iters),
        additional_values={
            'accuracy': accuracy,
            'loss': loss,
            'learning_rate': learning_rate,
            'test_accuracy': test_accuracy,
        }
    )
    assert loss < 40, "Exploded gradients :("

    return epoch + 1
def Train(args):
    """
    Build the data-parallel ResNet-50 training (and optional test) nets and
    run the epoch loop.

    NOTE(review): large parts of this function were missing from the mangled
    source and have been reconstructed; verify kwargs marked below against
    the original file.
    """
    # Either use the explicit --gpus list, or the first --num_gpus devices.
    if args.gpus is not None:
        gpus = [int(x) for x in args.gpus.split(',')]
        num_gpus = len(gpus)
    else:
        gpus = range(args.num_gpus)
        num_gpus = args.num_gpus

    log.info("Running on GPUs: {}".format(gpus))

    # Verify valid batch size.
    total_batch_size = args.batch_size
    batch_per_device = total_batch_size // num_gpus
    assert \
        total_batch_size % num_gpus == 0, \
        "Number of GPUs must divide batch size"

    # Round the epoch size down to an exact multiple of the global batch.
    global_batch_size = total_batch_size * args.num_shards
    epoch_iters = int(args.epoch_size / global_batch_size)
    args.epoch_size = epoch_iters * global_batch_size
    log.info("Using epoch size: {}".format(args.epoch_size))

    # Model helper for the training net.
    # NOTE(review): order/name/use_cudnn kwargs reconstructed -- verify.
    train_model = cnn.CNNModelHelper(
        order="NCHW",
        name="resnet50",
        use_cudnn=True,
        cudnn_exhaustive_search=True,
        ws_nbytes_limit=(args.cudnn_workspace_limit_mb * 1024 * 1024),
    )

    num_shards = args.num_shards
    shard_id = args.shard_id
    if num_shards > 1:
        # Distributed run: create a rendezvous so shards can find each other.
        store_handler = "store_handler"
        if args.redis_host is not None:
            # Use a Redis instance for rendezvous.
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "RedisStoreHandlerCreate", [], [store_handler],
                    host=args.redis_host,
                    port=args.redis_port,
                    prefix=args.run_id,
                )
            )
        else:
            # Use a shared filesystem path (e.g. NFS) for rendezvous.
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "FileStoreHandlerCreate", [], [store_handler],
                    path=args.file_store_path,
                )
            )
        # NOTE(review): shard_id/engine entries reconstructed -- verify.
        rendezvous = dict(
            kv_handler=store_handler,
            shard_id=shard_id,
            num_shards=num_shards,
            engine="GLOO",
            exit_nets=None,
        )
    else:
        rendezvous = None

    # Builds the per-device forward pass, loss and accuracy ops.
    def create_resnet50_model_ops(model, loss_scale):
        [softmax, loss] = resnet.create_resnet50(
            model,
            "data",
            num_input_channels=args.num_channels,
            num_labels=args.num_labels,
            label="label",
        )
        # Scale the loss so gradients average correctly across devices.
        loss = model.Scale(loss, scale=loss_scale)
        model.Accuracy([softmax, "label"], "accuracy")
        return [loss]

    # Builds the momentum-SGD update with a stepped learning-rate schedule.
    def add_parameter_update_ops(model):
        model.AddWeightDecay(args.weight_decay)
        ITER = model.Iter("ITER")
        # Decay the LR every 30 epochs, expressed in iterations.
        stepsz = int(30 * args.epoch_size / total_batch_size / num_shards)
        # NOTE(review): policy/stepsize/gamma kwargs reconstructed -- verify.
        LR = model.net.LearningRate(
            [ITER],
            "LR",
            base_lr=args.base_learning_rate,
            policy="step",
            stepsize=stepsz,
            gamma=0.1,
        )
        AddMomentumParameterUpdate(model, LR)

    # Input DB reader; shared by all GPUs on this shard.
    reader = train_model.CreateDB(
        "reader",
        db=args.train_data,
        db_type=args.db_type,
        num_shards=num_shards,
        shard_id=shard_id,
    )

    def add_image_input(model):
        AddImageInput(
            model,
            reader,
            batch_size=batch_per_device,
            img_size=args.image_size,
        )

    # Parallelize the model across devices (and shards, via rendezvous).
    data_parallel_model.Parallelize_GPU(
        train_model,
        input_builder_fun=add_image_input,
        forward_pass_builder_fun=create_resnet50_model_ops,
        param_update_builder_fun=add_parameter_update_ops,
        devices=gpus,
        rendezvous=rendezvous,
        optimize_gradient_memory=True,
    )

    # Optionally build a test net sharing the same forward-pass builder.
    test_model = None
    if (args.test_data is not None):
        log.info("----- Create test net ----")
        test_model = cnn.CNNModelHelper(
            order="NCHW",
            name="resnet50_test",
            use_cudnn=True,
            cudnn_exhaustive_search=True,
        )

        test_reader = test_model.CreateDB(
            "test_reader",
            db=args.test_data,
            db_type=args.db_type,
        )

        def test_input_fn(model):
            AddImageInput(
                model,
                test_reader,
                batch_size=batch_per_device,
                img_size=args.image_size,
            )

        # No param_update_builder_fun: the test net never updates weights.
        data_parallel_model.Parallelize_GPU(
            test_model,
            input_builder_fun=test_input_fn,
            forward_pass_builder_fun=create_resnet50_model_ops,
            param_update_builder_fun=None,
            devices=gpus,
        )
        workspace.RunNetOnce(test_model.param_init_net)
        workspace.CreateNet(test_model.net)

    # Experiment log name encodes the key hyperparameters.
    expname = "resnet50_gpu%d_b%d_L%d_lr%.2f_v2" % (
        args.num_gpus,
        total_batch_size,
        args.num_labels,
        args.base_learning_rate,
    )
    explog = experiment_util.ModelTrainerLog(expname, args)

    # Run training one epoch at a time; RunEpoch returns the next epoch index.
    epoch = 0
    while epoch < args.num_epochs:
        epoch = RunEpoch(
            args,
            epoch,
            train_model,
            test_model,
            total_batch_size,
            num_shards,
            expname,
            explog,
        )
340 parser = argparse.ArgumentParser(
341 description=
"Caffe2: Resnet-50 training" 343 parser.add_argument(
"--train_data", type=str, default=
None,
344 help=
"Path to training data or 'everstore_sampler'",
346 parser.add_argument(
"--test_data", type=str, default=
None,
347 help=
"Path to test data")
348 parser.add_argument(
"--db_type", type=str, default=
"lmdb",
349 help=
"Database type (such as lmdb or leveldb)")
350 parser.add_argument(
"--gpus", type=str,
351 help=
"Comma separated list of GPU devices to use")
352 parser.add_argument(
"--num_gpus", type=int, default=1,
353 help=
"Number of GPU devices (instead of --gpus)")
354 parser.add_argument(
"--num_channels", type=int, default=3,
355 help=
"Number of color channels")
356 parser.add_argument(
"--image_size", type=int, default=227,
357 help=
"Input image size (to crop to)")
358 parser.add_argument(
"--num_labels", type=int, default=1000,
359 help=
"Number of labels")
360 parser.add_argument(
"--batch_size", type=int, default=32,
361 help=
"Batch size, total over all GPUs")
362 parser.add_argument(
"--epoch_size", type=int, default=1500000,
363 help=
"Number of images/epoch, total over all machines")
364 parser.add_argument(
"--num_epochs", type=int, default=1000,
366 parser.add_argument(
"--base_learning_rate", type=float, default=0.1,
367 help=
"Initial learning rate.")
368 parser.add_argument(
"--weight_decay", type=float, default=1e-4,
369 help=
"Weight decay (L2 regularization)")
370 parser.add_argument(
"--cudnn_workspace_limit_mb", type=int, default=64,
371 help=
"CuDNN workspace limit in MBs")
372 parser.add_argument(
"--num_shards", type=int, default=1,
373 help=
"Number of machines in distributed run")
374 parser.add_argument(
"--shard_id", type=int, default=0,
376 parser.add_argument(
"--run_id", type=str,
377 help=
"Unique run identifier (e.g. uuid)")
378 parser.add_argument(
"--redis_host", type=str,
379 help=
"Host of Redis server (for rendezvous)")
380 parser.add_argument(
"--redis_port", type=int, default=6379,
381 help=
"Port of Redis server (for rendezvous)")
382 parser.add_argument(
"--file_store_path", type=str, default=
"/tmp",
383 help=
"Path to directory to use for rendezvous")
384 args = parser.parse_args()
if __name__ == '__main__':
    # Script entry point: parse args and train.
    # NOTE(review): the guarded call was missing from the source; main()
    # reconstructed -- verify against the original file.
    main()
# NOTE(review): the signature lines below are an appended cross-reference
# index (external APIs used by this trainer), not executable code from this
# module; preserved here as comments.
# RunNet(name, num_iter=1)
# RunEpoch(args, epoch, train_model, test_model, total_batch_size, num_shards, expname, explog)
# create_resnet50(model, data, num_input_channels, num_labels, label=None, is_test=False, no_loss=False, no_bias=0, conv1_kernel=7, conv1_stride=2, final_avg_kernel=7)
# CompleteInTimeOrDie(timeout_secs)
# AddImageInput(model, reader, batch_size, img_size)
# CreateOperator(operator_type, inputs, outputs, name='', control_input=None, device_option=None, arg=None, engine=None, kwargs)
# CreateNet(net, overwrite=False, input_blobs=None)
# RunOperatorOnce(operator)
# AddMomentumParameterUpdate(train_model, LR)
# Parallelize_GPU(model_helper_obj, input_builder_fun, forward_pass_builder_fun, param_update_builder_fun, devices=range(0, workspace.NumCudaDevices()), rendezvous=None, net_type='dag', broadcast_computed_params=True, optimize_gradient_memory=False)