Caffe2 - Python API
A deep learning, cross-platform ML framework
data_workers.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals


'''
This module provides a python-land multithreaded data input mechanism
for Caffe2 nets.

Basic usage is as follows:
   coordinator = data_workers.init_data_input_workers(
      net,
      ["data", "label"],
      my_fetch_fun,
      batch_size=32,
      input_source_name="train"
   )
   ...
   coordinator.start()

The first argument is the Caffe2 net (or model helper), and the second
argument is the list of input blobs that are to be fed.

The argument 'input_source_name' is used to distinguish different sources
of data, such as train or test data. This ensures the data does not get
mixed up even if two nets share blobs.
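
For instance (a hypothetical setup; the net and fetcher names are
illustrative), one could create separate train and test input workers
that feed blobs of the same names:
   init_data_input_workers(train_net, ["data", "label"], my_train_fetcher,
                           batch_size=32, input_source_name="train")
   init_data_input_workers(test_net, ["data", "label"], my_test_fetcher,
                           batch_size=32, input_source_name="test")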

To do the actual data loading, one defines a "fetcher function"
that has the call signature
   my_fetch_fun(worker_id, batch_size)

This function returns a list of numpy arrays corresponding to the different
input blobs. In the example above, it would return two arrays, one for the
data blob and another for the labels. These arrays can have an arbitrary
number of elements (i.e. they do not need to match the batch size). The
batch size is provided to the function as a hint only.

Optionally, one can define an "init function" that is called once before
the threads start, and has the call signature:
   my_init_fun(data_coordinator, global_coordinator)
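
For illustration, a minimal init function (hypothetical; any body will do)
could record the start time as a metric before the workers launch:
   def my_init_fun(data_coordinator, global_coordinator):
       # Runs once in the calling thread, before the worker threads start.
       data_coordinator.put_metric('init_time', time.time())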

For example, the fetcher function could download images from a remote
service or load random images from a directory on a file system; see the
sketch below.
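
A minimal sketch of such a fetcher (hypothetical shapes and dtypes, feeding
the "data" and "label" blobs of the example above with random arrays):
   def my_fetch_fun(worker_id, batch_size):
       # The first dimension is the number of samples; it must agree
       # across the returned arrays, but need not equal the batch_size
       # hint, since the enqueuer re-chunks data into exact batches.
       data = np.random.rand(batch_size, 3, 32, 32).astype(np.float32)
       labels = np.random.randint(0, 10, size=batch_size).astype(np.int32)
       return [data, labels]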

For a dummy example, see the data_workers_test unit test.

Note that for data_parallel_models, init_data_input_workers will be called
for each GPU. Note that the 'coordinator' returned by the function is the
same each time.
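
The optional 'external_loggers' argument takes a list of objects exposing a
log(metrics) method, which is called periodically with a dict of perf
metrics such as 'inputs_per_sec' and 'queue_size'. A minimal sketch of such
a logger (hypothetical) is:
   class PrintLogger(object):
       def log(self, metrics):
           print(metrics)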
'''

try:
    import Queue
except ImportError:
    # Python 3 renamed the module; keep the Python 2 name used throughout.
    import queue as Queue
import logging
import threading
import atexit
import numpy as np
import time
import collections

from caffe2.python import workspace, core, scope
from caffe2.proto import caffe2_pb2

log = logging.getLogger("data_workers")
log.setLevel(logging.INFO)
LOG_INT_SECS = 60


def get_worker_ids(num_workers):
    return range(0, num_workers)


def init_data_input_workers(
    net,
    input_blob_names,
    fetch_fun,
    batch_size,
    num_worker_threads=2,
    input_source_name="train",
    max_buffered_batches=800,
    init_fun=None,
    external_loggers=None,
):
    global global_coordinator
    device_option = scope.CurrentDeviceScope()
    if device_option is None:
        device_option = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU)

    # Create coordinator object
    coordinator = DataInputCoordinator(
        net,
        input_blob_names,
        batch_size,
        device_option,
        scope.CurrentNameScope(),
        input_source_name,
        global_coordinator.get_queue(input_source_name, max_buffered_batches),
        init_fun=init_fun,
        external_loggers=external_loggers,
    )

    # Launch fetch worker threads
    worker_ids = [
        global_coordinator.get_new_worker_id()
        for i in range(num_worker_threads)
    ]
    workers = [
        threading.Thread(
            target=fetcher,
            name="data_workers fetcher id {}".format(worker_id),
            args=[coordinator, worker_id, fetch_fun, batch_size, input_blob_names],
        ) for worker_id in worker_ids
    ]

    workers.append(threading.Thread(
        target=enqueuer,
        name="Enqueuer {} {}".format(input_source_name, scope.CurrentNameScope()),
        args=[coordinator]))
    coordinator._workers = workers
    global_coordinator.add(coordinator)

    return global_coordinator


class DataInputCoordinator(object):
    '''
    Coordinates the fetcher and enqueuer threads for a single input source,
    owning both the python-side queue and the Caffe2-side queues.
    '''
    def __init__(self, net, input_blob_names, batch_size,
                 device_option, namescope, input_source_name, queue,
                 init_fun=None, external_loggers=None):
        self._net = net
        self._counter = 0
        self._input_blob_names = input_blob_names
        self._batch_size = batch_size
        self._internal_queue = queue
        self._queues = []
        self._device_option = device_option
        self._namescope = namescope
        self._active = True
        self._started = False
        self._workers = []
        self._input_source_name = input_source_name
        # Create the Caffe2-side queues and the dequeue ops for each blob
        self._create_caffe2_queues_and_ops()
        self._inputs = 0
        self._prev_seconds = 0
        self._last_warning = time.time()
        self._init_fun = init_fun
        self._metrics = collections.defaultdict(lambda: 0)
        self._external_loggers = external_loggers

    def is_active(self):
        return self._active

    def init(self, global_coordinator):
        if self._init_fun:
            self._init_fun(self, global_coordinator)

    def _start(self):
        if self._started:
            return
        self._active = True
        self._started = True
        self._inputs = 0
        self._prev_seconds = time.time()

        for w in self._workers:
            w.daemon = True
            w.start()

    def _stop(self, reason=None):
        try:
            self._active = False
            if reason is not None:
                log.error("Data input failed due to an error: {}".format(reason))

            for q in self._queues:
                workspace.RunOperatorOnce(
                    core.CreateOperator("CloseBlobsQueue", [q], [])
                )
            self._started = False
        finally:
            self._log_inputs_per_interval(0, force=True)

    def _wait_finish(self):
        print("Wait for workers to die")
        for w in self._workers:
            if w != threading.current_thread():
                w.join(5.0)  # don't wait forever, thread may be blocked in i/o
        success = True
        for w in self._workers:
            if w.isAlive():
                print("Worker {} failed to close while waiting".format(w))
                success = False

        print("All workers terminated: {}".format(success))
        return success

    def _get(self):
        while self.is_active():
            try:
                return self._internal_queue.get(block=True, timeout=0.5)
            except Queue.Empty:
                continue
        return None

    def put(self, chunk):
        if len(chunk) == 0:
            print("Worker provided zero length input")
            return
        while self.is_active():
            try:
                qsize = self._internal_queue.qsize()
                if qsize < 2 and (time.time() - self._last_warning) > LOG_INT_SECS:
                    print("Warning, data loading lagging behind: " +
                          "queue size={}, name={}".format(qsize, self._input_source_name))
                    self._last_warning = time.time()
                self._counter += 1
                self._internal_queue.put(chunk, block=True, timeout=0.5)
                self._log_inputs_per_interval(chunk[0].shape[0])
                return
            except Queue.Full:
                log.debug("Queue full: stalling fetchers...")
                continue

    def _enqueue_batch(self):
        '''
        This pulls data from the python-side queue and collects it
        into batch-sized pieces.
        '''
        cur_batch = [np.array([]) for d in self._input_blob_names]

        # Collect data until we have a full batch size
        while cur_batch[0].shape[0] < self._batch_size and self.is_active():
            chunk = self._get()
            if chunk is None:
                continue

            for j, chunk_elem in enumerate(chunk):
                if cur_batch[j].shape[0] == 0:
                    cur_batch[j] = chunk_elem.copy()
                else:
                    cur_batch[j] = np.append(cur_batch[j], chunk_elem, axis=0)

        start_time = time.time()
        try:
            # Return data over the batch size back to queue
            if cur_batch[0].shape[0] > self._batch_size:
                leftover = [c[self._batch_size:] for c in cur_batch]
                cur_batch = [c[:self._batch_size] for c in cur_batch]
                try:
                    self._internal_queue.put(leftover, block=False)
                except Queue.Full:
                    pass

            assert cur_batch[0].shape[0] == self._batch_size

            if self.is_active():
                for b, q, c in zip(self._input_blob_names, self._queues, cur_batch):
                    self._enqueue(b, q, c)
        finally:
            self.put_metric('enqueue_time', time.time() - start_time)

    def _enqueue(self, blob_name, queue, data_arr):
        '''
        Enqueue the correctly sized batch arrays to Caffe2's queue.
        '''
        scratch_name = self._namescope + blob_name + \
            "_scratch_" + self._input_source_name
        blob = core.BlobReference(scratch_name)
        status = core.BlobReference(scratch_name + "_status")
        workspace.FeedBlob(
            blob,
            data_arr,
            device_option=self._device_option
        )

        op = core.CreateOperator(
            "SafeEnqueueBlobs",
            [queue, blob],
            [blob, status],
            device_option=self._device_option
        )
        workspace.RunOperatorOnce(op)

    def _create_caffe2_queues_and_ops(self):
        '''
        Creates queues on caffe2 side, and respective operators
        to pull (dequeue) blobs from the queues.
        '''
        def create_queue(queue_name, num_blobs, capacity):
            workspace.RunOperatorOnce(
                core.CreateOperator(
                    "CreateBlobsQueue",
                    [], [queue_name],
                    num_blobs=num_blobs,
                    capacity=capacity))
            return core.ScopedBlobReference(queue_name)

        for blob_name in self._input_blob_names:
            qname = blob_name + "_c2queue" + "_" + self._input_source_name
            q = create_queue(qname, num_blobs=1, capacity=4)
            self._queues.append(q)
            print("Created queue: {}".format(q))

            # Add operator to the Caffe2 network to dequeue
            self._net.DequeueBlobs(q, blob_name)

    def _log_inputs_per_interval(self, inputs, force=False):
        self._inputs += inputs
        current_seconds = time.time()
        delta_seconds = current_seconds - self._prev_seconds
        if delta_seconds >= LOG_INT_SECS or force:
            inputs_per_sec = int(self._inputs / delta_seconds)
            qsize = self._internal_queue.qsize()
            print("{}/{}: {} inputs/sec".format(
                self._input_source_name,
                self._namescope,
                inputs_per_sec,
            ))
            print("-- queue: {} batches".format(qsize))
            # log and reset perf metrics
            self.put_metric('inputs_per_sec', inputs_per_sec, False)
            self.put_metric('queue_size', qsize, False)
            self.put_metric('time_elapsed', delta_seconds, False)
            self._log(self._metrics)
            self._reset_metrics()
            self._inputs = 0
            self._prev_seconds = current_seconds

    def _log(self, metrics):
        if not self._external_loggers:
            return
        for logger in self._external_loggers:
            try:
                logger.log(metrics)
            except Exception as e:
                print("Failed to call ExternalLogger: {}".format(e))

    def put_metric(self, key, value, count=True):
        self._metrics[key] += value
        if count:
            count_key = '{}_count'.format(key)
            self._metrics[count_key] += 1

    def _reset_metrics(self):
        self._metrics = collections.defaultdict(lambda: 0)


class GlobalCoordinator(object):
    '''
    Process-wide registry of the per-source coordinators, worker ids and
    shared python-side queues.
    '''
    def __init__(self):
        self._coordinators = []
        self._fetcher_id_seq = 0
        self._worker_ids = []
        self._queues = {}
        self.register_shutdown_handler()

    def add(self, coordinator):
        self._coordinators.append(coordinator)

    def get_new_worker_id(self):
        worker_id = self._fetcher_id_seq
        self._worker_ids.append(worker_id)
        self._fetcher_id_seq += 1
        return worker_id

    def get_worker_ids(self):
        return self._worker_ids

    def get_queue(self, queue_name, max_buffered_batches):
        assert isinstance(max_buffered_batches, int)
        if queue_name not in self._queues:
            self._queues[queue_name] = Queue.Queue(maxsize=max_buffered_batches)
        return self._queues[queue_name]

    def start(self):
        for c in self._coordinators:
            c.init(self)
            c._start()

    def stop(self):
        all_success = True
        for c in self._coordinators:
            c._stop()
        for c in self._coordinators:
            success = c._wait_finish()
            all_success = all_success and success
        self._coordinators = []
        return all_success

    def register_shutdown_handler(self):
        def cleanup():
            self.stop()

        atexit.register(cleanup)


global_coordinator = GlobalCoordinator()


def fetcher(coordinator, worker_id, fetch_fun, batch_size, input_blob_names):
    '''
    Worker thread: repeatedly calls fetch_fun, validates its output and
    pushes it to the coordinator's python-side queue.
    '''
    while coordinator.is_active():
        start_time = time.time()
        try:
            input_data = fetch_fun(worker_id, batch_size)
            if input_data is None:
                print("Fetcher function returned None")
                continue

            assert len(input_data) == len(input_blob_names), \
                "Expecting data blob for each input"
            for d in input_data:
                assert isinstance(d, np.ndarray), \
                    "Fetcher function must return a numpy array"
            for d in input_data[1:]:
                assert d.shape[0] == input_data[0].shape[0], \
                    "Each returned input must have equal number of samples"

            coordinator.put(input_data)
        except Exception as e:
            log.exception("Exception in fetcher")
            coordinator._stop("Exception in fetcher {}: {}".format(
                worker_id, e
            ))
        finally:
            coordinator.put_metric('fetcher_time', time.time() - start_time)


def enqueuer(coordinator):
    '''
    Worker thread: assembles batches from the python-side queue and feeds
    them into the Caffe2-side queues.
    '''
    while coordinator.is_active():
        coordinator._enqueue_batch()