Caffe2 - Python API
A deep learning, cross-platform ML framework
checkpoint.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import logging
from caffe2.python import core, context
from caffe2.python.net_builder import ops
from caffe2.python.task import Node, Task, TaskGroup, TaskOutput, WorkspaceType

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# The name of the special net that is used to store all the blob names in the
# workspace.
__BLOB_NAMES_NET__ = 'get_blob_list'

class Job(object):
    """
    A Job defines three TaskGroups: the `init_group`, the `epoch_group` and
    the `exit_group`, which will be run by a JobRunner.

    The `init_group` will be run only once at startup. Its role is to
    initialize globally persistent blobs such as model weights, accumulators
    and data file lists.

    The `epoch_group` will be run in a loop after init_group. The loop will
    exit when any of the stop signals added with `add_stop_signal` is True
    at the end of an epoch.

    The `exit_group` will be run only once at the very end of the job, when
    one of the stopping criteria for `epoch_group` has been met. The role of
    this group is to save the results of training at the end of the job.

    Jobs are context-driven, so that Tasks can be added to the active Job
    without having to explicitly pass the job object around.

    Example of usage:

        def build_reader(partitions):
            with Job.current().init_group:
                reader = HiveReader(init_reader, ..., partitions)
                Task(step=init_reader)
            with Job.current().epoch_group:
                limited_reader = ReaderWithLimit(reader, num_iter=10000)
                data_queue = pipe(limited_reader, num_threads=8)
                Job.current().add_stop_signal(limited_reader.data_finished())
            return data_queue

        def build_hogwild_trainer(reader, model):
            with Job.current().init_group:
                Task(step=model.param_init_net)
            with Job.current().epoch_group:
                pipe(reader, processor=model, num_threads=8)
            with Job.current().exit_group:
                Task(step=model.save_model_net)

        with Job() as job:
            reader = build_reader(partitions)
            model = build_model(params)
            build_hogwild_trainer(reader, model)
    """
    def __init__(self,
                 init_group=None, epoch_group=None,
                 exit_group=None, stop_signals=None,
                 nodes_to_checkpoint=None):
        self.init_group = init_group or TaskGroup(
            workspace_type=WorkspaceType.GLOBAL)
        self.epoch_group = epoch_group or TaskGroup()
        self.exit_group = exit_group or TaskGroup()
        self.stop_signals = stop_signals or []
        self._nodes_to_checkpoint = nodes_to_checkpoint

    def nodes_to_checkpoint(self):
        if self._nodes_to_checkpoint:
            return self._nodes_to_checkpoint
        else:
            return self.init_group.used_nodes()

    def compile(self, session_class):
        return Job(
            init_group=session_class.compile(self.init_group),
            epoch_group=session_class.compile(self.epoch_group),
            exit_group=session_class.compile(self.exit_group),
            stop_signals=self.stop_signals,
            nodes_to_checkpoint=self.nodes_to_checkpoint())

    def __enter__(self):
        self.epoch_group.__enter__()
        return self

    def __exit__(self, *args):
        self.epoch_group.__exit__()

    def add_stop_signal(self, output):
        if isinstance(output, core.BlobReference):
            t = Task(outputs=[output], group=self.epoch_group)
            output = t.outputs()[0]
        assert isinstance(output, TaskOutput)
        self.stop_signals.append(output)


class CheckpointManager(object):
    """
    Controls saving and loading of workspaces on every epoch boundary of a
    job. If a CheckpointManager instance is passed to JobRunner, then
    JobRunner will call `init`, `load` and `save` at different moments in
    between epoch runs.
    """
    def __init__(self, db, db_type):
        self._db = db
        self._db_type = db_type
        # Make sure these blobs are the first in the checkpoint file.
        self._net = core.Net('!!checkpoint_mngr')
        self._blob_names = self._net.AddExternalInput('blob_names')
        self._names_output = None

    def init(self, nodes=None, retrieve_from_epoch=None):
        """
        Build a Task that will be run once after the job's `init_group` is
        run. This task will determine which blobs need to be checkpointed.
        If retrieve_from_epoch is not None, then the checkpoint metadata is
        retrieved from a previously saved checkpoint.
        """
        assert nodes is None or len(nodes) == 1, (
            'CheckpointManager only supports single node.')
        with Task(outputs=[self._blob_names]) as task:
            if retrieve_from_epoch is None:
                ops.GetAllBlobNames(
                    [],
                    self._blob_names,
                    include_shared=False)
            else:
                ops.Load(
                    [], self._blob_names,
                    db=self._db_name(retrieve_from_epoch),
                    db_type=self._db_type,
                    absolute_path=True)
        self._names_output = task.outputs()[0]
        return task

    def blob_list(self):
        assert self._names_output
        return self._names_output.fetch().tolist()

    def _db_name(self, epoch):
        return '%s.%06d' % (self._db, epoch)

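    # Naming note (illustrative values, not from the original source): with
    # db='/tmp/ckpt', _db_name(3) returns '/tmp/ckpt.000003', so every epoch
    # is written to its own zero-padded DB.
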
    def load(self, epoch):
        """
        Build a Task that will be run by JobRunner when the job is to be
        resumed from a given epoch. This task will run a Load op that will
        load and deserialize all relevant blobs from persistent storage.
        """
        logger.info('Load from %s' % self._db_name(epoch))
        with Task() as task:
            ops.Load(
                [],
                self.blob_list(),
                db=self._db_name(epoch),
                db_type=self._db_type,
                absolute_path=True)
        return task

    def load_blobs_from_checkpoint(self, blob_names, epoch):
        """
        Builds a Task that loads only the necessary blobs from a checkpoint
        of the given epoch. The necessary blobs are given in the blob_names
        argument.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: The checkpoint epoch to load from.

        Returns:
            A Task which loads the specified blobs from the checkpoint of
            the given epoch.
        """
        logger.info('Load from %s' % self._db_name(epoch))
        with Task() as task:
            ops.Load(
                [],
                blob_names,
                db=self._db_name(epoch),
                db_type=self._db_type,
                absolute_path=True,
                allow_incomplete=True)
        return task

    def check_db_exists(self, epoch):
        logger.info('Check existence of %s' % self._db_name(epoch))
        with Task() as task:
            existence = ops.Const(False)
            ops.DBExists(
                [],
                [existence],
                db_name=self._db_name(epoch),
                db_type=self._db_type,
                absolute_path=True)
            task.add_output(existence)
        return task

    def save(self, epoch):
        """
        Build a Task that is run once after `init_group` and after each
        epoch is run. This will execute a Save op to serialize and persist
        blobs present in the global workspace.
        """
        logger.info('Save to %s' % self._db_name(epoch))
        with Task() as task:
            ops.Save(
                self.blob_list(), [], db=self._db_name(epoch),
                db_type=self._db_type, absolute_path=True)
        return task


class MultiNodeCheckpointManager(object):
    """
    Coordinates checkpointing across multiple nodes.
    Each of `init`, `load` and `save` will build TaskGroups which will
    trigger checkpointing on each of the nodes involved in a distributed job.
    """
    def __init__(
            self, db_prefix, db_type, node_manager_class=CheckpointManager):
        self._node_manager_class = node_manager_class
        self._node_managers = None
        self._db_prefix = db_prefix
        self._db_type = db_type

    def _task_group(self, func, *args, **kw):
        assert self._node_managers is not None, 'init must be called first.'
        with TaskGroup(WorkspaceType.GLOBAL) as task_group:
            for node, manager in self._node_managers:
                with Node(node):
                    func(manager, *args, **kw)
        return task_group

    def init(self, nodes, retrieve_from_epoch=None):
        if self._node_managers is not None:
            assert [node for node, _ in self._node_managers] == nodes
            return
        self._node_managers = []
        for node in nodes:
            with Node(node):
                manager = self._node_manager_class(
                    db=os.path.join(self._db_prefix, node),
                    db_type=self._db_type)
                self._node_managers.append((node, manager))
        return self._task_group(
            self._node_manager_class.init,
            nodes=[node],
            retrieve_from_epoch=retrieve_from_epoch)

    def load(self, epoch):
        return self._task_group(self._node_manager_class.load, epoch)

    def load_blobs_locally(self, nodes, blob_names, epoch, session):
        """Loads the necessary blobs from the checkpoints to the current node.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: An integer. The checkpoint epoch to load from.
            session: A Session object to execute the Load ops.
        """
        if self._node_managers is not None:
            assert [node for node, _ in self._node_managers] == nodes
        else:
            self._node_managers = []
            for node in nodes:
                with Node(node):
                    manager = self._node_manager_class(
                        db=os.path.join(self._db_prefix, node),
                        db_type=self._db_type)
                    self._node_managers.append((node, manager))
        assert self._node_managers is not None, 'must initialize node managers'
        for _, manager in self._node_managers:
            existence_task = manager.check_db_exists(epoch)
            session.run(existence_task)
            existence = existence_task.outputs()[0].fetch()
            if not existence:
                logger.info('DB %s does not exist!' % manager._db_name(epoch))
                return False
            load_task = manager.load_blobs_from_checkpoint(blob_names, epoch)
            session.run(load_task)
        logger.info('Successfully loaded from checkpoints.')
        return True

    def save(self, epoch):
        return self._task_group(self._node_manager_class.save, epoch)

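# A minimal sketch of the distributed flow, assuming two hypothetical nodes
# and illustrative paths (none of these names come from this module):
#
#     manager = MultiNodeCheckpointManager(
#         db_prefix='/tmp/checkpoints', db_type='leveldb')
#     session.run(manager.init(nodes=['trainer:0', 'trainer:1']))
#     session.run(manager.save(epoch))  # one DB per node under db_prefix
#     session.run(manager.load(epoch))
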
class JobRunner(object):
    """
    Implements the runtime logic for jobs with checkpointing at the level of
    epochs. Can be used to run either single-host or distributed jobs. The
    job runner is a callable to be called once from the client, passing a
    Session as argument. This call will block until the Job execution is
    complete.

    If a checkpoint_manager is passed, checkpoints will be taken after
    initialization and after each epoch execution. If, in addition,
    `resume_from_epoch` is an epoch number, the corresponding checkpoint will
    be loaded and job execution will continue from the given epoch. In
    this case, the job's init_group will not be run.

    Refer to checkpoint_test.py for an example.
    """
    def __init__(self, job, checkpoint_manager=None, resume_from_epoch=None):
        self.resume_from_epoch = resume_from_epoch
        self.checkpoint = checkpoint_manager
        self.job = job

    def __call__(self, client):
        from_scratch = self.resume_from_epoch is None
        if from_scratch:
            client.run(self.job.init_group)

        if self.checkpoint:
            logger.info('Preparing checkpoint ...')
            client.run(self.checkpoint.init(
                self.job.nodes_to_checkpoint(),
                retrieve_from_epoch=self.resume_from_epoch))
            if from_scratch:
                logger.info('Saving first checkpoint ...')
                client.run(self.checkpoint.save(0))
                logger.info('First checkpoint saved.')
            else:
                logger.info('Loading checkpoint for epoch {} ...'.format(
                    self.resume_from_epoch))
                client.run(self.checkpoint.load(self.resume_from_epoch))
                logger.info('Checkpoint loaded.')

        epoch = 1 if from_scratch else self.resume_from_epoch + 1
        while True:
            logger.info('Starting epoch %d.' % epoch)
            client.run(self.job.epoch_group)
            logger.info('Ran epoch %d.' % epoch)
            stop_signals = [o.fetch() for o in self.job.stop_signals]

            if self.checkpoint:
                logger.info('Saving checkpoint ...')
                client.run(self.checkpoint.save(epoch))
                logger.info('Checkpoint saved.')

            if any(stop_signals):
                logger.info('Stopping.')
                break
            epoch += 1
        client.run(self.job.exit_group)
        return epoch

    def load_blobs_from_checkpoints(self, blob_names, epoch, session):
        """Loads the necessary blobs from the checkpoints.

        Checkpoints store snapshots of the workspace on each node.
        Sometimes we only need to load a subset of the blobs from the
        checkpoints. One common scenario is to load only the model blobs
        from the checkpoints for evaluation purposes. Given the names of the
        necessary blobs, this function goes over all the checkpoints of all
        the nodes, but only loads the blobs specified in blob_names into the
        current workspace.

        Args:
            blob_names: A list of strings. Each string is the name of a
                blob.
            epoch: An integer. The checkpoint epoch to load from.
            session: A Session object to execute the Load ops.

        Raises:
            ValueError: When the checkpoint manager is invalid.
        """
        if not self.checkpoint:
            raise ValueError('Checkpoint manager is None')
        logger.info('Loading checkpoint for epoch {} ...'.format(epoch))
        return self.checkpoint.load_blobs_locally(
            self.job.nodes_to_checkpoint(), blob_names, epoch, session)

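# A minimal sketch of driving a Job to completion, in the spirit of the
# docstring above; the session construction is an assumption for
# illustration (see checkpoint_test.py for a complete example):
#
#     with Job() as job:
#         ...  # add Tasks to init_group / epoch_group / exit_group
#     checkpoint = CheckpointManager(db='/tmp/ckpt', db_type='leveldb')
#     runner = JobRunner(job, checkpoint_manager=checkpoint)
#     last_epoch = runner(session)  # blocks until a stop signal fires
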
def epoch_limiter(num_epochs):
    """
    Creates a task that will output True when a given
    number of epochs has finished.
    """
    with Job.current().init_group:
        init_net = core.Net('epoch_counter_init')
        counter = init_net.CreateCounter([], init_count=num_epochs - 1)
        Task(step=init_net)
    epoch_net = core.Net('epoch_countdown')
    finished = epoch_net.CountDown(counter)
    output = Task(step=epoch_net, outputs=finished).outputs()[0]
    Job.current().add_stop_signal(output)
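
# Illustrative use while building a Job (a sketch, not from this module):
# epoch_limiter adds a countdown counter whose output becomes a stop signal,
# so the epoch_group loop ends after `num_epochs` iterations:
#
#     with Job() as job:
#         ...  # build the epoch_group pipeline
#         epoch_limiter(5)  # stop after 5 epochs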