Caffe2 - Python API
A deep-learning, cross-platform ML framework
optimizers.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import context
9 
10 
@context.define_context(allow_default=True)
class OptimizerContext(object):
    """
    Scope driven way to provide optimizers to layers.

    Name -> optimizer mappings are kept on a stack. A name provided by a
    later push shadows the same name from an earlier push until the later
    mapping is popped. Optimizers are fetched through the 'get_optimizer'
    method.
    """

    def __init__(self):
        # Flattened view of every pushed mapping; on name collisions the most
        # recently pushed mapping wins.
        self._optimizers = {}
        # The pushed mappings themselves, oldest first.
        self._optimizers_list = []

    def _rebuild_optimizers(self):
        """Recompute the flattened view from the stack, oldest mapping first."""
        self._optimizers = {}
        for optimizers in self._optimizers_list:
            self._optimizers.update(optimizers)

    def get_optimizer(self, name):
        """
        Return the optimizer currently provided under `name`.

        Raises:
            AssertionError: if no pushed mapping provides `name`. The check is
                an explicit raise rather than an `assert` statement so it is
                not stripped under `python -O` (which would make this method
                silently return None).
        """
        if name not in self._optimizers:
            raise AssertionError("{} optimizer is not provided!".format(name))
        return self._optimizers[name]

    def push_optimizers(self, optimizers):
        """
        Push a mapping of name -> optimizer onto the stack.

        Overriding an already-provided name is allowed: the new entry shadows
        the old one until this mapping is popped.
        """
        # NOTE(review): the mapping is stored by reference; mutations made to
        # it after the push only show up in lookups once a later pop rebuilds
        # the flattened view — confirm no caller relies on mutating it.
        self._optimizers_list.append(optimizers)
        self._optimizers.update(optimizers)

    def pop_optimizers(self):
        """
        Remove the most recently pushed mapping, restoring any names it shadowed.

        Raises:
            AssertionError: if nothing has been pushed. Checked explicitly so
                the check survives `python -O`.
        """
        if not self._optimizers_list:
            raise AssertionError("No optimizers have been pushed to pop!")
        self._optimizers_list.pop()
        # Rebuild from scratch so names shadowed by the popped mapping
        # become visible again.
        self._rebuild_optimizers()
41 
42 
class Optimizers(object):
    """
    Optimizers context to provide optimizers to layers
    within the context.

    On enter, the given mapping is pushed onto the current OptimizerContext;
    on exit, it is popped from that same OptimizerContext.

    Example usage:
        optimizers = {'optim1': optim1, 'optim2': optim2}
        with Optimizers(optimizers):
            optim = OptimizerContext.current().get_optimizer('optim1')
            layer(optim=optim)
    """

    def __init__(self, optimizers):
        # Mapping of name -> optimizer exposed while the `with` block is active.
        self._optimizers = optimizers
        # OptimizerContexts this instance has pushed onto, innermost last.
        # Recording them lets __exit__ pop from the exact context __enter__
        # pushed to, rather than whatever OptimizerContext.current() returns
        # at exit time, and lets nested re-entry of one instance unwind
        # correctly.
        self._entered_contexts = []

    def __enter__(self):
        optimizer_context = OptimizerContext.current()
        optimizer_context.push_optimizers(self._optimizers)
        # Record only after a successful push so a failed push is not popped.
        self._entered_contexts.append(optimizer_context)
        return self

    def __exit__(self, type, value, traceback):
        # The exception arguments are unused; returning None (falsy) lets any
        # exception raised inside the `with` block propagate.
        self._entered_contexts.pop().pop_optimizers()