Caffe2 - Python API
A deep learning, cross-platform ML framework
context.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 import threading
9 
10 
class ContextInfo(object):
    """Per-class bookkeeping for a stack of currently active context objects.

    The stack lives in a ``threading.local``, so every thread sees its own
    independent stack of entered instances.

    Attributes:
        cls: the context class described; called with no arguments to build
            a default instance when ``allow_default`` is set.
        allow_default: whether ``get_active`` may lazily create and push a
            default ``cls()`` instance when the stack is empty.
        arg_name: stored as given; not read anywhere in this file.
    """

    def __init__(self, cls, allow_default, arg_name):
        self.cls = cls
        self.allow_default = allow_default
        self.arg_name = arg_name
        # Holder whose ``obj`` attribute is the calling thread's stack list.
        self._local_stack = threading.local()

    @property
    def _stack(self):
        """Return the calling thread's stack, creating it on first access."""
        if not hasattr(self._local_stack, 'obj'):
            self._local_stack.obj = []
        return self._local_stack.obj

    def enter(self, value):
        """Push ``value`` as the innermost active context for this thread."""
        self._stack.append(value)

    def exit(self, value):
        """Pop the innermost active context and check that it is ``value``.

        The pop always happens before the check, so a mismatched exit still
        removes the top element.

        Raises:
            AssertionError: if the stack is empty, or if the popped object
                does not compare equal to ``value``.
        """
        stack = self._stack
        if not stack:
            raise AssertionError('Context %s is empty.' % self.cls)
        # The pop must NOT live inside an ``assert``: under ``python -O``
        # assert statements are removed entirely, which would skip the pop
        # and leave exited contexts on the stack forever.
        popped = stack.pop()
        if not (popped == value):
            raise AssertionError(
                'Context %s exited out of order.' % self.cls)

    def get_active(self, required=True):
        """Return this thread's innermost active context.

        Args:
            required: when False and nothing is active, return None instead
                of creating a default or failing.

        Returns:
            The top of the stack. If the stack is empty, ``required`` is True
            and ``allow_default`` is set, a fresh ``self.cls()`` is pushed
            and returned.

        Raises:
            AssertionError: if the stack is empty, ``required`` is True and
                ``allow_default`` is False.
        """
        if not self._stack:
            if not required:
                return None
            # Explicit raise (not ``assert``) so ``python -O`` does not fall
            # through and construct a context that has no default.
            if not self.allow_default:
                raise AssertionError(
                    'Context %s is required but none is active.' % self.cls)
            self.enter(self.cls())
        return self._stack[-1]
39 
40 
class ContextManager(object):
    """Registry mapping each context class to its ``ContextInfo``."""

    def __init__(self):
        # context class -> ContextInfo describing it
        self._ctxs = {}

    def register(self, ctx_info):
        """Record ``ctx_info`` under its class; a class may register only once.

        Raises:
            AssertionError: if ``ctx_info`` is not a ``ContextInfo`` or its
                class has already been registered.
        """
        assert isinstance(ctx_info, ContextInfo)
        key = ctx_info.cls
        assert key not in self._ctxs, 'Context %s already registered' % key
        self._ctxs[key] = ctx_info

    def get(self, cls):
        """Return the ``ContextInfo`` registered for ``cls``.

        Raises:
            AssertionError: if ``cls`` was never registered.
        """
        registry = self._ctxs
        assert cls in registry, 'Context %s not registered.' % cls
        return registry[cls]
54 
55 
# Single module-level registry shared by every context class in the process.
_CONTEXT_MANAGER = ContextManager()


def context_manager():
    """Return the module-level ``ContextManager`` registry."""
    # Reading a module global needs no ``global`` declaration; that is only
    # required when the name is rebound, which never happens here.
    return _CONTEXT_MANAGER
62 
63 
def __enter__(self):
    """Enter hook installed on classes decorated with ``define_context``.

    Runs the class's own pre-existing ``__enter__`` first (if it had one),
    then pushes ``self`` onto the stack registered for ``self._ctx_class``.
    Always returns ``self``; any value returned by the original
    ``__enter__`` is discarded.
    """
    original_enter = self._prev_enter
    if original_enter is not None:
        original_enter()
    info = context_manager().get(self._ctx_class)
    info.enter(self)
    return self
69 
70 
def __exit__(self, *args):
    """Exit hook installed on classes decorated with ``define_context``.

    Pops ``self`` from the stack registered for ``self._ctx_class``, then
    runs the class's own pre-existing ``__exit__`` (if it had one),
    forwarding the same arguments the ``with`` statement supplied.

    Returns:
        Whatever the original ``__exit__`` returned, or None if there was
        none. Propagating it matters: a truthy return from ``__exit__``
        suppresses the in-flight exception, and dropping it would silently
        change the wrapped class's behavior.
    """
    context_manager().get(self._ctx_class).exit(self)
    original_exit = self._prev_exit
    if original_exit is None:
        return None
    return original_exit(*args)
75 
76 
@classmethod
def current(cls, value=None, required=True):
    """Return the instance of ``cls`` to use; installed by ``define_context``.

    Thin delegate to ``get_active_context``. If ``value`` is given, it is
    type-checked against ``cls`` and returned unchanged. Otherwise the
    innermost active ``cls`` instance is returned, or None when ``required``
    is False and nothing is active.
    """
    resolved = get_active_context(cls, val=value, required=required)
    return resolved
80 
81 
class define_context(object):
    """Class decorator that turns a class into a managed, stackable context.

    Example::

        @define_context(allow_default=True)
        class MyCtx(object):
            pass

    Decorating a class has these effects:

    - the class is registered with the module-level registry;
    - ``__enter__``/``__exit__`` are replaced so that ``with`` blocks push
      and pop instances (any ``__enter__``/``__exit__`` the class already
      had is kept and chained);
    - a ``current`` classmethod is attached that returns the innermost
      active instance.

    Args:
        arg_name: stored on the class's ``ContextInfo``; not read anywhere in
            this file.
        allow_default: if True, asking for the current context when none is
            active creates one via ``cls()``.
    """

    def __init__(self, arg_name=None, allow_default=False):
        self.arg_name = arg_name
        self.allow_default = allow_default

    def __call__(self, cls):
        # A class hierarchy may define a context at most once.
        assert not hasattr(cls, '_ctx_class'), (
            '%s parent class (%s) already defines context.' % (
                cls, cls._ctx_class))
        info = ContextInfo(cls, self.allow_default, self.arg_name)
        context_manager().register(info)
        # Capture the pre-existing hooks BEFORE they are overwritten below,
        # so the patched versions can chain to them.
        cls._prev_enter = getattr(cls, '__enter__', None)
        cls._prev_exit = getattr(cls, '__exit__', None)
        cls._ctx_class = cls
        cls.__enter__ = __enter__
        cls.__exit__ = __exit__
        cls.current = current
        return cls
100 
101 
def get_active_context(cls, val=None, required=True):
    """Resolve which instance of context class ``cls`` should be used.

    Args:
        cls: a context class registered via ``define_context``.
        val: an explicit instance. When not None, it is type-checked and
            returned unchanged, bypassing the active stack.
        required: forwarded to ``ContextInfo.get_active`` when ``val`` is
            None.

    Raises:
        AssertionError: if ``cls`` is not registered (checked even when
            ``val`` is supplied), if ``val`` is not an instance of ``cls``,
            or if nothing is active, ``required`` is True and ``cls`` allows
            no default.
    """
    # Registry lookup comes first so an unregistered class is rejected even
    # when an explicit instance is passed in.
    info = context_manager().get(cls)
    if val is None:
        return info.get_active(required=required)
    assert isinstance(val, cls), (
        'Wrong context type. Expected: %s, got %s.' % (cls, type(val)))
    return val
def _stack(self)
Definition: context.py:19
def enter(self, value)
Definition: context.py:24