3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
9 from caffe2.python import core, workspace
10 from caffe2.python.task import Cluster, Task, TaskGroup, WorkspaceType
14 """ Wrapper for compiled runnable returned from session.compile() """ 15 def __init__(self, obj, session_class):
22 Allows running Nets, ExecutionSteps, Plans, Tasks and TaskGroups. 23 A session can potentially run in multiple nodes concurrently. 28 from caffe2.python.task import Task, TaskGroup, WorkspaceType 31 net.Add([net.Const(1), net.Const(2)]) 34 step = core.execution_step('step1', [net2]) 36 with TaskGroup(WorkspaceType.GLOBAL) as init_tg: 38 n1setup = net.Net('n1setup') 39 n1msg = n1setup.Const('Hello from node 1.') 42 with TaskGroup() as private_tg: 49 n2.Print(n2.Const('Hello from node 2.'), 0) 52 session = LocalSession() 56 session.run(private_tg) 60 At the beginning of the session, a global workspace is created and kept 61 alive for the duration of the session. 65 Tasks can be run either directly on the global workspace, or they can 66 instantiate a private child workspace that is released after each run. 69 Tasks running in different nodes in parallel will always run under 70 different workspaces, so it must be assumed that they won't be able to 71 access each other's blobs. On the other hand, tasks running on the same 72 node are guaranteed to run on the same workspace within a run. 84 def compile(cls, runnable):
85 if isinstance(runnable, CompiledRunnable):
86 assert cls == runnable.session_class, (
87 'Runnable was compiled for different session type. ' +
88 'Need: %s, got: %s' % (
89 cls.__name__, runnable.session_class.__name__))
95 if isinstance(runnable, TaskGroup):
98 tg = TaskGroup(workspace_type=WorkspaceType.GLOBAL)
99 if isinstance(runnable, Task):
102 tg.add(Task(step=runnable))
105 tg.add(Task(step=step))
111 def run(self, runnable):
112 assert self.is_open(), 'Session is closed.' 120 def fetch_output(self, output):
121 raise NotImplementedError()
123 def _run_compiled(self, task_group):
124 raise NotImplementedError()
127 def _compile_task_group(cls, task_group):
134 assert self._open, 'Session already closed.' 137 def __exit__(self, ex_type, value, traceback):
144 Session that runs in a single node. 145 Tasks are all remapped to run in parallel in the 'local' node. 147 Currently, LocalSession runs all parallel tasks in the same workspace, 148 but this behavior may change in the future. Only tasks pointing to the 149 same logical node are guaranteed to always run in the same workspace. 151 def __init__(self, ws=None):
152 Session.__init__(self)
153 self._ws = ws or workspace.C.Workspace.current
156 def _compile_task_group(cls, task_group):
158 task = task_group.to_task()
160 plan.AddStep(task.get_step())
161 return (plan, task.output_list(), task.workspace_type)
163 def _run_compiled(self, compiled):
164 plan, output_list, workspace_type = compiled
168 for name in output_list.names():
169 self._ws.create_blob(str(name))
171 output_list.set_values(outputs, _fetch_func=self._fetch_output)
173 workspace.C.Workspace(self._ws)
174 if workspace_type == WorkspaceType.PRIVATE else self._ws)
178 def _fetch_output(self, output):
179 return self._ws.blobs[str(output)].fetch()
def execution_step(default_name, steps_or_nets, num_iter=None, report_net=None, report_interval=None, concurrent_substeps=None, should_stop_blob=None, only_once=None)
def _fetch_output(self, output)
def compile(cls, runnable)
dictionary _compiled_cache
def _run_compiled(self, task_group)
def WorkspaceGuard(workspace_name)
def _compile_task_group(cls, task_group)