from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
'''
This module provides a python-land multithreaded data input mechanism.

Basic usage is as follows:
   coordinator = data_workers.init_data_input_workers(
      net,
      ["data", "label"],
      my_fetch_fun,
      batch_size=32,
      input_source_name="train",
   )

The first argument is the Caffe2 net (or model helper), and the second
argument is the list of input blobs that are to be fed.

The argument 'input_source_name' is used to distinguish different sources of
data, such as train or test data. This ensures the data does not get mixed up,
even though two nets may share blobs.

To do the actual data loading, one defines a "fetcher function"
that has the call signature
   my_fetch_fun(worker_id, batch_size)

Optionally, one can define an "init function" that is called once before the
threads start, and has the call signature:
   my_init_fun(data_coordinator, global_coordinator)

The fetcher function returns a list of numpy arrays corresponding to the
different input blobs. In the example above, it would return two arrays, one
for the data blob and another for the labels. These arrays can have an
arbitrary number of elements (i.e. they do not need to match the batch size).
The batch size is provided to the function as a hint only.

For example, the fetcher function could download images from a remote service
or load random images from a directory on a file system.

For a dummy example, see the data_workers_test unit test.

Note that for data_parallel_models, init_data_input_workers will be called
for each GPU. Note that the 'coordinator' returned by the function is the
same each time.
'''

import atexit
import collections
import logging
import threading
import time

import numpy as np

try:
    import Queue
except ImportError:
    # Python 3
    import queue as Queue

from caffe2.python import workspace, core, scope
from caffe2.proto import caffe2_pb2
log = logging.getLogger("data_workers")
log.setLevel(logging.INFO)

# Interval (in seconds) between "lagging behind" warnings and throughput logs.
LOG_INT_SECS = 60
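
# A minimal fetcher sketch for the contract described in the docstring above
# (hypothetical names and shapes; a real fetcher might read images from disk
# or a remote service). It would be passed to init_data_input_workers exactly
# as in the docstring example.
#
#   def my_fetch_fun(worker_id, batch_size):
#       # One numpy array per input blob ("data", "label"); the first
#       # dimensions must match each other but need not equal batch_size,
#       # which is only a hint.
#       data = np.random.rand(batch_size, 3, 32, 32).astype(np.float32)
#       labels = np.random.randint(0, 10, size=batch_size).astype(np.int32)
#       return [data, labels]
#
#   def my_init_fun(data_coordinator, global_coordinator):
#       # Optional one-time setup (e.g. opening a dataset handle) before the
#       # worker threads start.
#       pass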


def get_worker_ids(num_workers):
    return range(0, num_workers)


def init_data_input_workers(
    net,
    input_blob_names,
    fetch_fun,
    batch_size,
    num_worker_threads=2,
    input_source_name="train",
    max_buffered_batches=800,
    init_fun=None,
    external_loggers=None,
):
    global global_coordinator
    device_option = scope.CurrentDeviceScope()
    if device_option is None:
        device_option = caffe2_pb2.DeviceOption(device_type=caffe2_pb2.CPU)

    # Create the coordinator that owns the python-side queue for this input
    # source and feeds the Caffe2-side queues of the given net.
    coordinator = DataInputCoordinator(
        net,
        input_blob_names,
        batch_size,
        device_option,
        scope.CurrentNameScope(),
        input_source_name,
        global_coordinator.get_queue(input_source_name, max_buffered_batches),
        init_fun=init_fun,
        external_loggers=external_loggers,
    )

    # One fetcher thread per worker id, each running the user's fetch_fun.
    worker_ids = [
        global_coordinator.get_new_worker_id()
        for i in range(num_worker_threads)
    ]
    workers = [
        threading.Thread(
            target=fetcher,
            name="data_workers fetcher id {}".format(worker_id),
            args=[coordinator, worker_id, fetch_fun, batch_size, input_blob_names],
        ) for worker_id in worker_ids
    ]

    # A single enqueuer thread assembles the fetched chunks into batches.
    workers.append(threading.Thread(
        target=enqueuer,
        args=[coordinator]))
    coordinator._workers = workers
    global_coordinator.add(coordinator)

    return global_coordinator


class DataInputCoordinator(object):
    def __init__(self, net, input_blob_names, batch_size,
                 device_option, namescope, input_source_name, queue,
                 init_fun=None, external_loggers=None):
        self._net = net
        self._input_blob_names = input_blob_names
        self._batch_size = batch_size
        self._device_option = device_option
        self._namescope = namescope
        self._input_source_name = input_source_name
        self._internal_queue = queue
        self._init_fun = init_fun
        self._external_loggers = external_loggers
        self._active = True
        self._inputs = 0
        self._prev_seconds = time.time()
        self._last_warning = time.time()
        self._metrics = collections.defaultdict(lambda: 0)

    def init(self, global_coordinator):
        # Call the optional user-provided init function once before the
        # worker threads start (see the module docstring).
        if self._init_fun is not None:
            self._init_fun(self, global_coordinator)

    def is_active(self):
        return self._active

    def _stop(self, reason=None):
        if reason is not None:
            log.error("Data input failed due to an error: {}".format(reason))
        self._active = False

    def _wait_finish(self):
        log.info("Wait for workers to die")
        for w in self._workers:
            if w != threading.current_thread():
                # Bounded join so shutdown cannot hang on a blocked worker.
                w.join(5.0)

        success = True
        for w in self._workers:
            if w.is_alive():
                log.error("Worker {} failed to close while waiting".format(w))
                success = False

        log.info("All workers terminated: {}".format(success))
        return success

    def put(self, chunk):
        if len(chunk) == 0:
            log.warning("Worker provided zero length input")
            return
        while self.is_active():
            try:
                qsize = self._internal_queue.qsize()
                if qsize < 2 and (time.time() - self._last_warning) > LOG_INT_SECS:
                    log.warning("Warning, data loading lagging behind: " +
                                "queue size={}".format(qsize))
                    self._last_warning = time.time()
                self._internal_queue.put(chunk, block=True, timeout=0.5)
                self._log_inputs_per_interval(chunk[0].shape[0])
                return
            except Queue.Full:
                log.debug("Queue full: stalling fetchers...")

    def _enqueue_batch(self):
        '''
        This pulls data from the python-side queue and collects them
        into batch-sized pieces.
        '''
        cur_batch = [np.array([]) for d in self._input_blob_names]

        # Accumulate chunks from the python-side queue until a full batch
        # is available (or the coordinator is stopped).
        while cur_batch[0].shape[0] < self._batch_size and self.is_active():
            try:
                chunk = self._internal_queue.get(block=True, timeout=0.5)
            except Queue.Empty:
                continue
            for j, chunk_elem in enumerate(chunk):
                if cur_batch[j].shape[0] == 0:
                    cur_batch[j] = chunk_elem.copy()
                else:
                    cur_batch[j] = np.append(cur_batch[j], chunk_elem, axis=0)

        start_time = time.time()

        # Slice off exactly batch_size rows; anything beyond the batch size
        # is returned to the python-side queue for the next batch.
        leftover = [c[self._batch_size:] for c in cur_batch]
        cur_batch = [c[:self._batch_size] for c in cur_batch]
        if leftover[0].shape[0] > 0:
            try:
                self._internal_queue.put(leftover, block=False)
            except Queue.Full:
                pass

        # Feed the batch into the Caffe2-side queues, one per input blob.
        for q, b, b_name in zip(self._queues, cur_batch, self._input_blob_names):
            self._enqueue(b_name, q, b)

        self.put_metric('enqueue_time', time.time() - start_time)
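
    # Worked example of the slicing above (illustrative shapes only): with
    # batch_size=4 and two queued chunks of 3 rows each, cur_batch[j] grows
    # to shape (6, ...); rows [:4] are enqueued to Caffe2 and the leftover
    # rows [4:] go back to the python-side queue for the next batch.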

    def _enqueue(self, blob_name, queue, data_arr):
        '''
        Enqueue the correctly sized batch arrays to Caffe2's queue.
        '''
        scratch_name = self._namescope + blob_name + \
            "_scratch_" + self._input_source_name
        blob = core.BlobReference(scratch_name)
        status = core.BlobReference(scratch_name + "_status")

        # Feed the numpy array into a scratch blob and push it into the
        # Caffe2-side queue with a safe enqueue operator.
        workspace.FeedBlob(blob, data_arr, device_option=self._device_option)
        op = core.CreateOperator(
            "SafeEnqueueBlobs", [queue, blob], [blob, status],
            device_option=self._device_option)
        workspace.RunOperatorOnce(op)

    def _create_caffe2_queues_and_ops(self):
        '''
        Creates queues on caffe2 side, and respective operators
        to pull (dequeue) blobs from the queues.
        '''
        def create_queue(queue_name, num_blobs, capacity):
            workspace.RunOperatorOnce(core.CreateOperator(
                "CreateBlobsQueue", [], [queue_name],
                num_blobs=num_blobs, capacity=capacity))
            return core.ScopedBlobReference(queue_name)

        # One Caffe2-side queue per input blob, read by a DequeueBlobs op.
        self._queues = []
        for blob_name in self._input_blob_names:
            qname = blob_name + "_c2queue_" + self._input_source_name
            q = create_queue(qname, num_blobs=1, capacity=4)
            log.info("Created queue: {}".format(q))
            self._queues.append(q)
            self._net.DequeueBlobs(q, blob_name)
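
    # Net-side result (sketch): with input_blob_names=["data", "label"] and
    # input_source_name="train", two BlobsQueues are created and the net
    # gains one DequeueBlobs op per blob, so every net iteration pulls the
    # next ready batch from its queue. The exact queue blob naming is
    # incidental; only the per-blob queue plus DequeueBlobs pairing matters.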

    def _log_inputs_per_interval(self, inputs, force=False):
        self._inputs += inputs
        current_seconds = time.time()
        delta_seconds = current_seconds - self._prev_seconds
        if delta_seconds >= LOG_INT_SECS or force:
            inputs_per_sec = int(self._inputs / delta_seconds)
            qsize = self._internal_queue.qsize()
            log.info("{}/{}: {} inputs/sec".format(
                self._input_source_name, self._namescope, inputs_per_sec))
            log.info("-- queue: {} batches".format(qsize))
            self.put_metric('inputs_per_sec', inputs_per_sec, False)
            self.put_metric('time_elapsed', delta_seconds, False)
            self._log(self._metrics)
            self._reset_metrics()
            self._inputs = 0
            self._prev_seconds = current_seconds

    def _log(self, metrics):
        # Pass accumulated metrics to each external logger, if any were given.
        if not self._external_loggers:
            return
        for external_logger in self._external_loggers:
            try:
                external_logger.log(metrics)
            except Exception as e:
                log.warning("Failed to call ExternalLogger: {}".format(e))

    def put_metric(self, key, value, count=True):
        self._metrics[key] += value
        if count:
            count_key = '{}_count'.format(key)
            self._metrics[count_key] += 1

    def _reset_metrics(self):
        self._metrics = collections.defaultdict(lambda: 0)
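
    # Metric bookkeeping example (hypothetical values): two calls
    #   put_metric('fetcher_time', 0.12)
    #   put_metric('fetcher_time', 0.08)
    # leave _metrics['fetcher_time'] == 0.2 and
    # _metrics['fetcher_time_count'] == 2, so an external logger can recover
    # the average. Passing count=False (as done for 'inputs_per_sec') skips
    # the *_count counter.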


class GlobalCoordinator(object):
    def __init__(self):
        self._coordinators = []
        self._fetcher_id_seq = 0
        self._worker_ids = []
        self._queues = {}

    def add(self, coordinator):
        self._coordinators.append(coordinator)

    def get_new_worker_id(self):
        worker_id = self._fetcher_id_seq
        self._worker_ids.append(worker_id)
        self._fetcher_id_seq += 1
        return worker_id

    def get_worker_ids(self):
        return self._worker_ids

    def get_queue(self, queue_name, max_buffered_batches):
        assert isinstance(max_buffered_batches, int)
        if queue_name not in self._queues:
            self._queues[queue_name] = Queue.Queue(maxsize=max_buffered_batches)
        return self._queues[queue_name]
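
    # Per-source queue behaviour (sketch): repeated calls with the same name
    # share one queue, while different sources stay separate, e.g.
    #   q_train = global_coordinator.get_queue("train", 800)
    #   q_test = global_coordinator.get_queue("test", 800)
    #   assert q_train is global_coordinator.get_queue("train", 800)
    #   assert q_train is not q_test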

    def stop(self):
        # Stop all coordinators, then wait for their worker threads to exit.
        all_success = True
        for c in self._coordinators:
            c._stop()
        for c in self._coordinators:
            success = c._wait_finish()
            all_success = all_success and success
        return all_success

    def register_shutdown_handler(self):
        def cleanup():
            self.stop()

        atexit.register(cleanup)


global_coordinator = GlobalCoordinator()


def fetcher(coordinator, worker_id, fetch_fun, batch_size, input_blob_names):
    while coordinator.is_active():
        start_time = time.time()
        try:
            input_data = fetch_fun(worker_id, batch_size)
            if input_data is None:
                log.warning("Fetcher function returned None")
                continue

            assert len(input_data) == len(input_blob_names), \
                "Expecting data blob for each input"
            for d in input_data:
                assert isinstance(d, np.ndarray), \
                    "Fetcher function must return a numpy array"
            for d in input_data[1:]:
                assert d.shape[0] == input_data[0].shape[0], \
                    "Each returned input must have equal number of samples"

            coordinator.put(input_data)
        except Exception as e:
            logging.exception("Exception in fetcher %s", e)
            coordinator._stop("Exception in fetcher {}: {}".format(
                worker_id, e
            ))
        finally:
            coordinator.put_metric('fetcher_time', time.time() - start_time)


def enqueuer(coordinator):
    while coordinator.is_active():
        coordinator._enqueue_batch()