Caffe2 - Python API
A deep learning, cross platform ML framework
session.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 
9 from caffe2.python import core, workspace
10 from caffe2.python.task import Cluster, Task, TaskGroup, WorkspaceType
11 
12 
class CompiledRunnable(object):
    """
    Pairs a compiled object with the Session subclass that produced it.

    Instances are built by `Session.compile()`. `obj` holds whatever that
    session class's `_compile_task_group` returned. `session_class` is
    checked when the runnable is handed back to `compile()`, so it cannot
    be reused with a different session type.
    """

    def __init__(self, obj, session_class):
        self.session_class = session_class
        self.obj = obj
18 
19 
class Session(object):
    """
    Allows running Nets, ExecutionSteps, Plans, Tasks and TaskGroups.
    A session can potentially run in multiple nodes concurrently.


    Example:
        from caffe2.python import core
        from caffe2.python.task import Node, Task, TaskGroup, WorkspaceType

        net = core.Net('test1')
        net.Add([net.Const(1), net.Const(2)])

        net2 = net.Clone()
        step = core.execution_step('step1', [net2])

        with TaskGroup(WorkspaceType.GLOBAL) as init_tg:
            with Node('node1'):
                n1setup = core.Net('n1setup')
                n1msg = n1setup.Const('Hello from node 1.')
                Task(step=n1setup)

        with TaskGroup() as private_tg:
            with Node('node1'):
                n1 = core.Net('n1')
                n1.Print(n1msg, 0)
                Task(step=n1)
            with Node('node2'):
                n2 = core.Net('n2')
                n2.Print(n2.Const('Hello from node 2.'), 0)
                Task(step=n2)

        session = LocalSession()
        session.run(net)
        session.run(step)
        session.run(init_tg)
        session.run(private_tg)


    Global Workspace:
        At the beginning of the session, a global workspace is created and
        kept alive for the duration of the session.


    Private Workspace:
        Tasks can be run either directly on the global workspace, or they can
        instantiate a private child workspace that is released after each run.

    Blob visibility:
        Tasks running in different nodes in parallel will always run under
        different workspaces, so it must be assumed that they won't be able to
        access each other's blobs. On the other hand, tasks running on the same
        node are guaranteed to run on the same workspace within a run.
    """

    # Compiled runnables keyed by (session class, runnable). This dict lives
    # on the base class and is shared by every subclass, so the session class
    # must be part of the key. Otherwise a runnable compiled for one session
    # type would be returned to another.
    _compiled_cache = {}

    def __init__(self):
        self._open = True

    def is_open(self):
        """Return True until `close()` has been called."""
        return self._open

    @classmethod
    def _cache_key(cls, runnable):
        """Return the compile-cache key for `runnable`, or None if unhashable."""
        key = (cls, runnable)
        try:
            hash(key)
        except TypeError:
            # E.g. a plain list of nets: still compilable via
            # core.execution_step, just not cacheable.
            return None
        return key

    @classmethod
    def compile(cls, runnable):
        """
        Compile `runnable` for this session class.

        Args:
            runnable: a CompiledRunnable, TaskGroup, Task, ExecutionStep, or
                anything accepted by `core.execution_step` (e.g. a Net or a
                list of nets). Anything that is not a TaskGroup is wrapped in
                a new TaskGroup using WorkspaceType.GLOBAL.

        Returns:
            A CompiledRunnable whose `session_class` is `cls`. Results for
            hashable runnables are cached per (session class, runnable).

        Raises:
            AssertionError: if `runnable` is a CompiledRunnable produced for a
                different session class.
        """
        if isinstance(runnable, CompiledRunnable):
            assert cls == runnable.session_class, (
                'Runnable was compiled for different session type. ' +
                'Need: %s, got: %s' % (
                    cls.__name__, runnable.session_class.__name__))
            return runnable

        cache_key = cls._cache_key(runnable)
        if cache_key is not None:
            cached = cls._compiled_cache.get(cache_key)
            if cached is not None:
                return cached

        if isinstance(runnable, TaskGroup):
            tg = runnable
        else:
            tg = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
            if isinstance(runnable, Task):
                tg.add(runnable)
            elif isinstance(runnable, core.ExecutionStep):
                tg.add(Task(step=runnable))
            else:
                step = core.execution_step('runnable', runnable)
                tg.add(Task(step=step))
        compiled = CompiledRunnable(
            cls._compile_task_group(tg), session_class=cls)
        if cache_key is not None:
            cls._compiled_cache[cache_key] = compiled
        return compiled

    def run(self, runnable):
        """
        Compile `runnable` (or reuse a cached compilation) and execute it.

        Raises:
            AssertionError: if the session has already been closed.
        """
        assert self.is_open(), 'Session is closed.'
        self._run_compiled(self.compile(runnable).obj)

    def close(self):
        """Close the session. Calling it again is a no-op."""
        if self.is_open():
            self._do_close()
            self._open = False

    def fetch_output(self, output):
        """Subclass hook; not implemented on the base class."""
        raise NotImplementedError()

    def _run_compiled(self, task_group):
        """Execute an object produced by `_compile_task_group` (subclass hook)."""
        raise NotImplementedError()

    @classmethod
    def _compile_task_group(cls, task_group):
        """Turn a TaskGroup into a runnable form; the default is the identity."""
        return task_group

    def _do_close(self):
        """Subclass hook for releasing resources; called once by `close()`."""
        pass

    def __enter__(self):
        assert self._open, 'Session already closed.'
        return self

    def __exit__(self, ex_type, value, traceback):
        # Only a clean exit closes the session. If an exception propagates
        # out of the `with` block, the session is left open.
        if ex_type is None:
            self.close()
140 
141 
class LocalSession(Session):
    """
    Session that runs in a single node.
    Tasks are all remapped to run in parallel in the 'local' node.

    Currently, LocalSession runs all parallel tasks in the same workspace,
    but this behavior may change in the future. Only tasks pointing to the
    same logical node are guaranteed to always run in the same workspace.
    """
    def __init__(self, ws=None):
        """
        Args:
            ws: workspace the session runs in. When None, the value of
                `workspace.C.Workspace.current` at construction time is used.
        """
        Session.__init__(self)
        # Compare against None explicitly, so a caller-supplied workspace is
        # never replaced just because the object evaluates as falsy.
        self._ws = ws if ws is not None else workspace.C.Workspace.current

    @classmethod
    def _compile_task_group(cls, task_group):
        """
        Flatten `task_group` into a single-step Plan.

        Returns:
            A tuple (plan, output_list, workspace_type) consumed by
            `_run_compiled`.
        """
        # to_task() is called inside a fresh Cluster context, presumably so
        # node assignment is resolved locally.
        # NOTE(review): confirm against caffe2.python.task.Cluster.
        with Cluster():
            task = task_group.to_task()
        plan = core.Plan('task_group_plan')
        plan.AddStep(task.get_step())
        return (plan, task.output_list(), task.workspace_type)

    def _run_compiled(self, compiled):
        """
        Run a (plan, output_list, workspace_type) tuple from
        `_compile_task_group`.

        For WorkspaceType.PRIVATE, the plan runs in a new child workspace of
        the session workspace. Otherwise it runs directly in the session
        workspace.
        """
        plan, output_list, workspace_type = compiled

        # Make sure the output blobs belong to the parent (session) workspace,
        # and bind the task outputs so they are fetched from it.
        outputs = []
        for name in output_list.names():
            blob_name = str(name)
            self._ws.create_blob(blob_name)
            outputs.append(core.BlobReference(blob_name))
        output_list.set_values(outputs, _fetch_func=self._fetch_output)

        if workspace_type == WorkspaceType.PRIVATE:
            task_ws = workspace.C.Workspace(self._ws)
        else:
            task_ws = self._ws
        with workspace.WorkspaceGuard(task_ws):
            task_ws.run(plan)

    def _fetch_output(self, output):
        """Return the fetched value of blob `output` from the session workspace."""
        return self._ws.blobs[str(output)].fetch()
def execution_step(default_name, steps_or_nets, num_iter=None, report_net=None, report_interval=None, concurrent_substeps=None, should_stop_blob=None, only_once=None)
Definition: core.py:2018
def _fetch_output(self, output)
Definition: session.py:178
def compile(cls, runnable)
Definition: session.py:84
dictionary _compiled_cache
Definition: session.py:75
def _run_compiled(self, task_group)
Definition: session.py:123
def WorkspaceGuard(workspace_name)
Definition: workspace.py:344
def _do_close(self)
Definition: session.py:130
def _compile_task_group(cls, task_group)
Definition: session.py:127
def is_open(self)
Definition: session.py:80