Caffe2 - Python API
A deep-learning, cross-platform ML framework
queue_util.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core, dataio
9 from caffe2.python.task import TaskGroup
10 
11 
13  def __init__(self, wrapper):
14  assert wrapper.schema is not None, (
15  'Queue needs a schema in order to be read from.')
16  dataio.Reader.__init__(self, wrapper.schema())
17  self._wrapper = wrapper
18 
19  def setup_ex(self, init_net, exit_net):
20  exit_net.CloseBlobsQueue([self._wrapper.queue()], 0)
21 
22  def read_ex(self, local_init_net, local_finish_net):
23  self._wrapper._new_reader(local_init_net)
24  dequeue_net = core.Net('dequeue')
25  fields, status_blob = dequeue(
26  dequeue_net,
27  self._wrapper.queue(),
28  len(self.schema().field_names()),
29  field_names=self.schema().field_names())
30  return [dequeue_net], status_blob, fields
31 
32 
34  def __init__(self, wrapper):
35  self._wrapper = wrapper
36 
37  def setup_ex(self, init_net, exit_net):
38  exit_net.CloseBlobsQueue([self._wrapper.queue()], 0)
39 
40  def write_ex(self, fields, local_init_net, local_finish_net, status):
41  self._wrapper._new_writer(self.schema(), local_init_net)
42  enqueue_net = core.Net('enqueue')
43  enqueue(enqueue_net, self._wrapper.queue(), fields, status)
44  return [enqueue_net]
45 
46 
48  def __init__(self, handler, schema=None):
49  dataio.Pipe.__init__(self, schema, TaskGroup.LOCAL_SETUP)
50  self._queue = handler
51 
52  def reader(self):
53  return _QueueReader(self)
54 
55  def writer(self):
56  return _QueueWriter(self)
57 
58  def queue(self):
59  return self._queue
60 
61 
63  def __init__(self, capacity, schema=None, name='queue'):
64  # find a unique blob name for the queue
65  net = core.Net(name)
66  queue_blob = net.AddExternalInput(net.NextName('handler'))
67  QueueWrapper.__init__(self, queue_blob, schema)
68  self.capacity = capacity
69  self._setup_done = False
70 
71  def setup(self, global_init_net):
72  assert self._schema, 'This queue does not have a schema.'
73  self._setup_done = True
74  global_init_net.CreateBlobsQueue(
75  [],
76  [self._queue],
77  capacity=self.capacity,
78  num_blobs=len(self._schema.field_names()),
79  field_names=self._schema.field_names())
80 
81 
def enqueue(net, queue, data_blobs, status=None):
    """Add a SafeEnqueueBlobs op to `net` that enqueues `data_blobs`.

    Args:
        net: the Net to add the op to.
        queue: the queue handle blob to enqueue into.
        data_blobs: blobs to enqueue. Any iterable is accepted (it is
            materialized into a list); previously only a list worked,
            since it was concatenated directly with `[queue]`.
        status: optional output blob that signals whether the queue was
            closed; a fresh name is generated when not provided.

    Returns:
        The status blob (the op's last output).
    """
    blobs = list(data_blobs)  # generalization: accept tuples/generators too
    if status is None:
        status = net.NextName('status')
    results = net.SafeEnqueueBlobs([queue] + blobs, blobs + [status])
    return results[-1]
87 
88 
def dequeue(net, queue, num_blobs, status=None, field_names=None):
    """Add a SafeDequeueBlobs op to `net` reading `num_blobs` blobs.

    Args:
        net: the Net to add the op to.
        queue: the queue handle blob to dequeue from.
        num_blobs: number of data blobs per record.
        status: optional output blob signalling queue closure; a fresh
            name is generated when not provided.
        field_names: optional per-blob names; its length must match
            `num_blobs`.

    Returns:
        (data_blobs, status_blob)
    """
    if field_names is None:
        data_names = [net.NextName('data', i) for i in range(num_blobs)]
    else:
        assert len(field_names) == num_blobs
        data_names = [net.NextName(name) for name in field_names]
    if status is None:
        status = net.NextName('status')
    outputs = list(net.SafeDequeueBlobs(queue, data_names + [status]))
    # Last output is the status blob; the rest are the data blobs.
    return outputs[:-1], outputs[-1]
101 
102 
def close_queue(step, *queues):
    """Wrap `step` so that all given queues are closed after it runs.

    Args:
        step: the execution step to wrap.
        *queues: queue handle blobs to close once `step` completes.

    Returns:
        An execution step running `step` followed by a queue-closing step.
    """
    close_net = core.Net("close_queue_net")
    for q in queues:
        close_net.CloseBlobsQueue([q], 0)
    close_step = core.execution_step("%s_step" % str(close_net), close_net)
    # NOTE(review): "wraper" spelling kept as-is — step names may be
    # matched by this exact string elsewhere; confirm before fixing.
    return core.execution_step(
        "%s_wraper_step" % str(close_net),
        [step, close_step])
def execution_step(default_name, steps_or_nets, num_iter=None, report_net=None, report_interval=None, concurrent_substeps=None, should_stop_blob=None, only_once=None)
Definition: core.py:2018
def schema(self)
Definition: dataio.py:163
def __init__(self, schema=None)
Definition: dataio.py:28
def __init__(self, schema=None, obj_key=None)
Definition: dataio.py:263
def schema(self)
Definition: dataio.py:33