from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import logging
import os

from caffe2.python import core, context
from caffe2.python.net_builder import ops
from caffe2.python.task import Node, Task, TaskGroup, TaskOutput, WorkspaceType
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
__BLOB_NAMES_NET__ = 'get_blob_list'


@context.define_context()
class Job(object):
    """
    A Job defines three TaskGroups: the `init_group`, the `epoch_group` and the
    `exit_group`, which will be run by a JobRunner.

    The `init_group` will be run only once at startup. Its role is to
    initialize globally persistent blobs such as model weights and
    accumulators.

    The `epoch_group` will be run in a loop after the `init_group`. The loop
    will exit when any of the stop signals added with `add_stop_signal` is
    True at the end of an epoch.

    The `exit_group` will be run only once at the very end of the job, when
    one of the stopping criteria for the `epoch_group` was met. The role of
    this group is to save the results of training at the end of the job.

    Jobs are context-driven, so that Tasks can be added to the active Job
    without having to explicitly pass the job object around.

    Example of usage:

        def build_reader(partitions):
            with Job.current().init_group:
                reader = HiveReader(init_reader, ..., partitions)
                Task(step=init_reader)
            with Job.current().epoch_group:
                limited_reader = ReaderWithLimit(reader, num_iter=10000)
                data_queue = pipe(limited_reader, num_threads=8)
            Job.current().add_stop_signal(limited_reader.data_finished())
            return data_queue

        def build_hogwild_trainer(reader, model):
            with Job.current().init_group:
                Task(step=model.param_init_net)
            with Job.current().epoch_group:
                pipe(reader, processor=model, num_threads=8)
            with Job.current().exit_group:
                Task(step=model.save_model_net)

        with Job() as job:
            reader = build_reader(partitions)
            model = build_model(params)
            build_hogwild_trainer(reader, model)
    """
    def __init__(self,
                 init_group=None, epoch_group=None,
                 exit_group=None, stop_signals=None,
                 nodes_to_checkpoint=None):
        self.init_group = init_group or TaskGroup(
            workspace_type=WorkspaceType.GLOBAL)
        self.epoch_group = epoch_group or TaskGroup()
        self.exit_group = exit_group or TaskGroup()
        self.stop_signals = stop_signals or []
        self._nodes_to_checkpoint = nodes_to_checkpoint

    def nodes_to_checkpoint(self):
        if self._nodes_to_checkpoint:
            return self._nodes_to_checkpoint
        else:
            return self.init_group.used_nodes()

    def compile(self, session_class):
        return Job(
            init_group=session_class.compile(self.init_group),
            epoch_group=session_class.compile(self.epoch_group),
            exit_group=session_class.compile(self.exit_group),
            stop_signals=self.stop_signals,
            nodes_to_checkpoint=self.nodes_to_checkpoint())

    def __enter__(self):
        self.epoch_group.__enter__()
        return self

    def __exit__(self, *args):
        self.epoch_group.__exit__()

    def add_stop_signal(self, output):
        if isinstance(output, core.BlobReference):
            t = Task(outputs=[output], group=self.epoch_group)
            output = t.outputs()[0]
        assert isinstance(output, TaskOutput)
        self.stop_signals.append(output)


class CheckpointManager(object):
    """
    Controls saving and loading of workspaces on every epoch boundary of a job.
    If a CheckpointManager instance is passed to JobRunner, then JobRunner will
    call `init`, `read` and `save` at different moments in between epoch runs.
    """
    def __init__(self, db, db_type):
        self._db = db
        self._db_type = db_type
        # Blob that will hold the list of blob names to be checkpointed.
        self._blob_names = core.BlobReference(__BLOB_NAMES_NET__)
        self._names_output = None

    def init(self, nodes=None, retrieve_from_epoch=None):
        """
        Build a Task that will be run once after the job's `init_group` is run.
        This task will determine which blobs need to be checkpointed.
        If retrieve_from_epoch is not None, then the checkpoint metadata is
        retrieved from a previously saved checkpoint.
        """
        assert nodes is None or len(nodes) == 1, (
            'CheckpointManager only supports single node.')
        with Task(outputs=[self._blob_names]) as task:
            if retrieve_from_epoch is None:
                # Record the names of all blobs currently in the workspace.
                ops.GetAllBlobNames(
                    [],
                    self._blob_names,
                    include_shared=False)
            else:
                # Recover the blob list from a previously saved checkpoint.
                ops.Load(
                    [], self._blob_names,
                    db=self._db_name(retrieve_from_epoch),
                    db_type=self._db_type,
                    absolute_path=True)
        self._names_output = task.outputs()[0]
        return task

    def blob_list(self):
        assert self._names_output
        return self._names_output.fetch().tolist()

    def _db_name(self, epoch):
        # e.g. self._db='/tmp/ckpt', epoch=5 -> '/tmp/ckpt.000005'
        return '%s.%06d' % (self._db, epoch)

    def load(self, epoch):
        """
        Build a Task that will be run by JobRunner when the job is to be
        resumed from a given epoch. This task will run a Load op that will
        load and deserialize all relevant blobs from persistent storage.
        """
        logger.info('Load from %s' % self._db_name(epoch))
        with Task() as task:
            ops.Load(
                [], self.blob_list(),
                db=self._db_name(epoch),
                db_type=self._db_type,
                absolute_path=True)
        return task

    def load_blobs_from_checkpoint(self, blob_names, epoch):
        """
        Builds a Task that loads only the necessary blobs from a checkpoint of
        the given epoch. The necessary blobs are given in the blob_names
        argument.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: The checkpoint epoch to load from.

        Returns:
            A Task which loads the specified blobs from the checkpoint of the
            given epoch.
        """
        logger.info('Load from %s' % self._db_name(epoch))
        with Task() as task:
            ops.Load(
                [], blob_names,
                db=self._db_name(epoch),
                db_type=self._db_type,
                absolute_path=True,
                # Tolerate requested blobs that are absent from the checkpoint.
                allow_incomplete=True)
        return task

    def check_db_exists(self, epoch):
        logger.info('Check existence of %s' % self._db_name(epoch))
        with Task() as task:
            existence = ops.Const(False)
            # DBExists sets `existence` to True if the checkpoint database
            # for this epoch is present on disk.
            ops.DBExists(
                [], [existence],
                db_name=self._db_name(epoch),
                db_type=self._db_type,
                absolute_path=True)
            task.add_output(existence)
        return task

    def save(self, epoch):
        """
        Build a Task that is run once after `init_group` and after each
        epoch is run. This will execute a Save op to serialize and persist
        blobs present in the global workspace.
        """
        logger.info('Save to %s' % self._db_name(epoch))
        with Task() as task:
            ops.Save(
                self.blob_list(), [],
                db=self._db_name(epoch),
                db_type=self._db_type, absolute_path=True)
        return task
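
# Minimal single-node usage sketch (hypothetical db path; assumes `session` is
# a Session that can run the Tasks built here):
#
#     manager = CheckpointManager(db='/tmp/model_ckpt', db_type='leveldb')
#     session.run(manager.init())      # records which blobs to checkpoint
#     session.run(manager.save(0))     # writes /tmp/model_ckpt.000000
#     session.run(manager.load(0))     # restores all blobs from that file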


class MultiNodeCheckpointManager(object):
    """
    Coordinates checkpointing across multiple nodes. Each of `init`, `load`
    and `save` will build TaskGroups which will trigger checkpointing on each
    of the nodes involved in a distributed job.
    """
    def __init__(
            self, db_prefix, db_type, node_manager_class=CheckpointManager):
        self._node_managers = None
        self._db_prefix = db_prefix
        self._db_type = db_type
        self._node_manager_class = node_manager_class

    def _task_group(self, func, *args, **kw):
        assert self._node_managers is not None, 'init must be called first.'
        with TaskGroup(WorkspaceType.GLOBAL) as task_group:
            for node, manager in self._node_managers:
                with Node(node):
                    func(manager, *args, **kw)
            return task_group

    def init(self, nodes, retrieve_from_epoch=None):
        if self._node_managers is not None:
            assert [node for node, _ in self._node_managers] == nodes
            return
        self._node_managers = []
        for node in nodes:
            with Node(node):
                # Each node checkpoints to its own db under the shared prefix.
                manager = self._node_manager_class(
                    db=os.path.join(self._db_prefix, node),
                    db_type=self._db_type)
                self._node_managers.append((node, manager))
        return self._task_group(
            self._node_manager_class.init,
            retrieve_from_epoch=retrieve_from_epoch)

    def load(self, epoch):
        return self._task_group(self._node_manager_class.load, epoch)
262 """Loads the necessary blobs from the checkpoints to the current node. 265 blob_names: A list of strings. Each string is the name of a 267 epoch: An integer. The checkpoint epoch to load from. 268 session: A Session object to execute the Load ops. 280 assert self.
_node_managers is not None,
'must initialize node managers' 282 existence_task = manager.check_db_exists(epoch)
283 session.run(existence_task)
284 existence = existence_task.outputs()[0].fetch()
286 logger.info(
'DB %s does not exist!' % manager._db_name(epoch))
288 load_task = manager.load_blobs_from_checkpoint(blob_names, epoch)
289 session.run(load_task)
290 logger.info(
'Successfully loaded from checkpoints.')

    def save(self, epoch):
        return self._task_group(self._node_manager_class.save, epoch)
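
# Sketch of coordinated multi-node checkpointing (hypothetical node names and
# db prefix; each of init/load/save builds a TaskGroup to be run by a session):
#
#     ckpt = MultiNodeCheckpointManager(db_prefix='/ckpts', db_type='leveldb')
#     session.run(ckpt.init(nodes=['trainer_0', 'trainer_1']))
#     session.run(ckpt.save(epoch=1))  # trainer_0 saves /ckpts/trainer_0.000001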


class JobRunner(object):
    """
    Implement the runtime logic for jobs with checkpointing at the level of
    epoch. Can be used to run either single-host or distributed jobs. The job
    runner is a callable to be called once from the client, passing a Session
    as argument. This call will block until the Job execution is complete.

    If a checkpoint_manager is passed, checkpoints will be taken after
    initialization and after each epoch execution. If, in addition,
    `resume_from_epoch` is an epoch number, the corresponding checkpoint will
    be loaded and job execution will continue from the given epoch. In
    this case, the job's init_group will not be run.

    Refer to checkpoint_test.py for an example.
    """
    def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None):
        self.resume_from_epoch = resume_from_epoch
        self.checkpoint = checkpoint_manager
        self.job = job

    def __call__(self, client):
        from_scratch = self.resume_from_epoch is None
        if from_scratch:
            client.run(self.job.init_group)

        if self.checkpoint:
            logger.info('Preparing checkpoint ...')
            client.run(self.checkpoint.init(
                self.job.nodes_to_checkpoint(),
                retrieve_from_epoch=self.resume_from_epoch))
            if from_scratch:
                logger.info('Saving first checkpoint ...')
                client.run(self.checkpoint.save(0))
                logger.info('First checkpoint saved.')
            else:
                logger.info('Loading checkpoint for epoch {} ...'.format(
                    self.resume_from_epoch))
                client.run(self.checkpoint.load(self.resume_from_epoch))
                logger.info('Checkpoint loaded.')

        epoch = 1 if from_scratch else self.resume_from_epoch + 1
        while True:
            logger.info('Starting epoch %d.' % epoch)
            client.run(self.job.epoch_group)
            logger.info('Ran epoch %d.' % epoch)
            stop_signals = [o.fetch() for o in self.job.stop_signals]

            if self.checkpoint:
                logger.info('Saving checkpoint ...')
                client.run(self.checkpoint.save(epoch))
                logger.info('Checkpoint saved.')

            if any(stop_signals):
                logger.info('Stopping.')
                break
            epoch += 1
        client.run(self.job.exit_group)
        return epoch
357 """Loads the necessary blobs from the checkpoints. 359 Checkpoints store the snapshots of the workspace in each node. 360 Sometimes we only need to load a subset of the blobs from the 361 checkpoints. One common scenario is to load only the model blobs from 362 the checkpoints for evaluation purpose. Given the names of the necessary 363 blobs, this function goes over all the checkpoints of all the nodes, but 364 only loads the blobs specified in the blob_names to the current 368 blob_names: A list of strings. Each string is the name of a 370 epoch: An integer. The checkpoint epoch to load from. 371 session: A Session object to execute the load ops. 374 ValueError: When the checkpoint manager is invalid. 377 raise ValueError(
'Checkpoint manager is None')
378 logger.info(
'Loading checkpoint for epoch {} ...'.format(epoch))
379 return self.
checkpoint.load_blobs_locally(self.
job.nodes_to_checkpoint(),
380 blob_names, epoch, session)
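
# Minimal sketch of driving a Job with a JobRunner (assumes `session` is a
# Session instance and hypothetical builder functions; see checkpoint_test.py
# for a complete, runnable example):
#
#     with Job() as job:
#         build_trainer_tasks()            # hypothetical: fills job's groups
#         epoch_limiter(num_epochs=10)
#     checkpoint = CheckpointManager(db='/tmp/ckpt', db_type='leveldb')
#     runner = JobRunner(job, checkpoint_manager=checkpoint)
#     last_epoch = runner(session)         # blocks until the job completes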


def epoch_limiter(num_epochs):
    """
    Creates a task that will output True when a given
    number of epochs has finished.
    """
    with Job.current().init_group:
        init_net = core.Net('epoch_counter_init')
        counter = init_net.CreateCounter([], init_count=num_epochs - 1)
        Task(step=init_net)
    # The countdown Task is created in the ambient epoch_group, since entering
    # a Job context makes its epoch_group the active TaskGroup.
    epoch_net = core.Net('epoch_countdown')
    finished = epoch_net.CountDown(counter)
    output = Task(step=epoch_net, outputs=finished).outputs()[0]
    Job.current().add_stop_signal(output)
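
# Usage sketch (a minimal example, assuming an active Job context so that
# Job.current() is set):
#
#     with Job() as job:
#         # ... add epoch_group tasks ...
#         epoch_limiter(num_epochs=5)   # job will stop after 5 epochs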