Caffe2 - Python API
A deep learning, cross-platform ML framework
dataset.py
"""
Implementation of an in-memory dataset with structured schema.

Use this to store and iterate through datasets with complex schema that
fit in memory.

Iterating through entries of this dataset is very fast since the dataset
is stored as a set of native Caffe2 tensors, thus no type conversion or
deserialization is necessary.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, workspace
from caffe2.python.dataio import Reader, Writer
from caffe2.python.schema import (
    Struct, from_blob_list, Field, from_column_list, InitEmptyRecord)
import numpy as np


class _DatasetReader(Reader):
    def __init__(self, dataset, name, batch_size=1):
        """Don't call this directly. Instead, use dataset.reader()"""
        Reader.__init__(self, dataset.content())
        self.dataset = dataset
        self.name = name or (dataset.name + '_cursor')
        self.batch_size = batch_size
        self.cursor = None

    def setup_ex(self, init_net, exit_net):
        if self.cursor is None:
            self.cursor = init_net.CreateTreeCursor(
                [],
                [self.name],
                fields=self.dataset.fields)

    def read(self, read_net):
        assert self.cursor, 'setup not called.'
        content = self.dataset.content()
        with core.NameScope(read_net.NextName(self.name)):
            fields = read_net.ReadNextBatch(
                [self.cursor] + content.field_blobs(),
                content.field_names(),
                batch_size=self.batch_size)
            if type(fields) is core.BlobReference:
                fields = [fields]
            return (read_net.IsEmpty([fields[0]]), fields)

    def reset(self, net):
        net.ResetCursor([self.cursor], [])

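# Illustrative sketch, not part of the original module: the typical read
# loop driven by a _DatasetReader. `ds` is assumed to be an initialized
# Dataset (defined later in this file) that already contains data.
def _example_sequential_read(ds):
    init_net = core.Net('example_reader_init')
    reader = ds.reader(init_net, batch_size=2)
    read_net = core.Net('example_reader_read')
    should_stop, fields = reader.read(read_net)
    workspace.RunNetOnce(init_net)
    workspace.CreateNet(read_net)
    workspace.RunNet(read_net)
    while not workspace.FetchBlob(str(should_stop)):
        # the field blobs now hold the current batch as numpy arrays
        batch = [workspace.FetchBlob(str(f)) for f in fields]
        print(batch)
        workspace.RunNet(read_net)
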
class _DatasetRandomReader(Reader):
    def __init__(self, dataset, name, indices, batch_size=1):
        """Don't call this directly. Instead, use dataset.random_reader()"""
        Reader.__init__(self, dataset.content())
        self.dataset = dataset
        self.cursor = None
        self.name = name or (dataset.name + '_cursor')
        self.indices = indices
        self.batch_size = batch_size

    def setup_ex(self, init_net, exit_net):
        if self.cursor is None:
            self.cursor = init_net.CreateTreeCursor(
                [],
                [self.name],
                fields=self.dataset.fields)

    def reset(self, net):
        net.ResetCursor([self.cursor], [])

    def computeoffset(self, net):
        self.reset(net)
        offsets = net.ComputeOffset(
            [self.cursor] + self.dataset.content().field_blobs(),
            'offsets')
        self.offsets = offsets

    def sort_and_shuffle(self, net, sort_by_field=None,
                         shuffle_size=1, batch_size=1):
        # no sorting by default
        content = self.dataset.content()
        sort_by_field_idx = -1
        if sort_by_field:
            assert sort_by_field in content.field_names(), (
                'Must be a valid field.')
            sort_by_field_idx = content.field_names().index(sort_by_field)
        self.reset(net)

        indices = net.SortAndShuffle(
            [self.cursor] + content.field_blobs(),
            'indices',
            sort_by_field_idx=sort_by_field_idx,
            shuffle_size=shuffle_size,
            batch_size=batch_size)
        self.indices = indices

    def read(self, read_net):
        with core.NameScope(read_net.NextName(self.name)):
            fields = read_net.ReadRandomBatch(
                [self.cursor, self.indices, self.offsets] + (
                    self.dataset.content().field_blobs()),
                self.dataset.content().field_names(),
                batch_size=self.batch_size)
            return (read_net.IsEmpty([fields[0]]), fields)

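# Illustrative sketch (an assumption, not from the original source): wiring
# a _DatasetRandomReader through Dataset.random_reader(). sort_and_shuffle
# and computeoffset must both run before the first read, since read() uses
# the `indices` and `offsets` blobs they produce. `ds` is assumed to be an
# initialized Dataset that already contains data.
def _example_random_read(ds):
    init_net = core.Net('example_rand_init')
    reader = ds.random_reader(init_net, batch_size=4)
    shuffle_net = core.Net('example_rand_shuffle')
    reader.sort_and_shuffle(shuffle_net, shuffle_size=8, batch_size=4)
    reader.computeoffset(shuffle_net)
    read_net = core.Net('example_rand_read')
    should_stop, fields = reader.read(read_net)
    workspace.RunNetOnce(init_net)
    workspace.RunNetOnce(shuffle_net)
    workspace.CreateNet(read_net)
    workspace.RunNet(read_net)
    while not workspace.FetchBlob(str(should_stop)):
        workspace.RunNet(read_net)
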
class _DatasetWriter(Writer):
    def __init__(self, content):
        """Don't call this directly. Use dataset.writer() instead."""
        self._content = content
        self.mutex = None

    def setup_ex(self, init_net, exit_net):
        if self.mutex is None:
            self.mutex = init_net.CreateMutex([])

    def write(self, writer_net, fields):
        """
        Add operations to `writer_net` that append the blobs in `fields`
        to the end of the dataset. An additional operator will also be
        added that checks the consistency of the data in `fields` against
        the dataset schema.

        Args:
            writer_net: The net that will contain the Append operators.
            fields: A list of BlobReference to be appended to this dataset.
        """
        assert self.mutex is not None, 'setup not called.'
        field_blobs = self._content.field_blobs()
        assert len(fields) == len(field_blobs), (
            'Expected %s fields, got %s.' % (len(field_blobs), len(fields)))
        writer_net.CheckDatasetConsistency(
            fields, [], fields=self._content.field_names())
        writer_net.AtomicAppend(
            [self.mutex] + field_blobs + list(fields),
            field_blobs)

    def commit(self, finish_net):
        """Commit is a no-op for an in-memory dataset."""
        pass

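# Illustrative sketch: appending one batch of blobs through Dataset.writer()
# (Dataset is defined later in this file). `ds` and `input_blobs` are
# assumptions for the example; `input_blobs` must match the schema's field
# order, as checked by CheckDatasetConsistency.
def _example_write(ds, input_blobs):
    init_net = core.Net('example_writer_init')
    writer = ds.writer(init_net)
    write_net = core.Net('example_writer_write')
    writer.write(write_net, input_blobs)
    workspace.RunNetOnce(init_net)
    workspace.RunNetOnce(write_net)
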
def Const(net, value, dtype=None, name=None):
    """
    Create a 'constant' by first creating an external input in the given
    net, and then feeding the corresponding blob with its provided value
    in the current workspace. The name is automatically generated in order
    to avoid clashes with existing blob names.
    """
    assert isinstance(net, core.Net), 'net must be a core.Net instance.'
    value = np.array(value, dtype=dtype)
    blob = net.AddExternalInput(net.NextName(prefix=name))
    workspace.FeedBlob(str(blob), value)
    return blob

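# Illustrative use of Const (the net name and values are assumptions): the
# returned BlobReference behaves like any other external input of `net` and
# is already fed in the current workspace.
def _example_const():
    net = core.Net('example_const')
    ones = Const(net, np.ones(4, dtype=np.float32), name='ones')
    return net, ones
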
def execution_step_with_progress(name, init_net, substeps, rows_read):
    # progress reporter
    # NOTE: `init_net` is accepted but not referenced by this function.
    report_net = core.Net('report_net')
    report_net.Print([rows_read], [])
    return core.execution_step(
        name,
        substeps,
        report_net=report_net,
        concurrent_substeps=True,
        report_interval=5)

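# Illustrative sketch: running a read net inside a progress-reporting step.
# `rows_read` is assumed to be a counter blob updated by the nets in the
# substeps; the report net prints it periodically while the substeps run.
def _example_progress(init_net, read_net, rows_read):
    step = execution_step_with_progress(
        'read_with_progress', init_net, [read_net], rows_read)
    plan = core.Plan('example_plan')
    plan.AddStep(core.execution_step('init', init_net))
    plan.AddStep(step)
    workspace.RunPlan(plan)
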
class Dataset(object):
    """Represents an in-memory dataset with fixed schema.

    Use this to store and iterate through datasets with complex schema that
    fit in memory.

    Iterating through entries of this dataset is very fast since the dataset
    is stored as a set of native Caffe2 tensors, thus no type conversion or
    deserialization is necessary.
    """

    def __init__(self, fields, name=None):
        """Create an un-initialized dataset with schema provided by `fields`.

        Before this dataset can be used, it must be initialized, either by
        `init_empty` or `init_from_dataframe`.

        Args:
            fields: either a schema.Struct or a list of field names in a
                format compatible with the one described in schema.py.
            name: optional name to prepend to blobs that will store the data.
        """
        assert isinstance(fields, list) or isinstance(fields, Struct), (
            'fields must be either a Struct or a list of raw field names.')
        if isinstance(fields, list):
            fields = from_column_list(fields)
        self.schema = fields
        self.fields = fields.field_names()
        # NOTE: this attribute shadows the field_types() method defined
        # below, so instances expose the list of dtypes directly.
        self.field_types = fields.field_types()
        self.name = name or 'dataset'
        self.field_blobs = fields.field_blobs() if fields.has_blobs() else None

    def init_empty(self, init_net):
        """Initialize the blobs for this dataset with empty values.

        Empty arrays will be immediately fed into the current workspace,
        and `init_net` will take those blobs as external inputs.
        """
        self.field_blobs = InitEmptyRecord(
            init_net, self.schema.clone_schema()).field_blobs()

    def init_from_dataframe(self, net, dataframe):
        """Initialize the blobs for this dataset from a Pandas dataframe.

        Each column of the dataframe will be immediately fed into the
        current workspace, and `net` will take these blobs as external
        inputs.
        """
        assert len(self.fields) == len(dataframe.columns)
        self.field_blobs = [
            Const(net, dataframe.as_matrix([col]).flatten(), name=field)
            for col, field in enumerate(self.fields)]

    def get_blobs(self):
        """
        Return the list of BlobReference pointing to the blobs that contain
        the data for this dataset.
        """
        assert self.field_blobs, 'Dataset not initialized.'
        return self.field_blobs

    def content(self):
        """
        Return a Record of BlobReferences pointing to the full content of
        this dataset.
        """
        return from_blob_list(self.schema, self.field_blobs)

    def field_names(self):
        """Return the list of field names for this dataset."""
        return self.fields

    def field_types(self):
        """
        Return the list of field dtypes for this dataset.

        If a list of strings, rather than a schema.Struct, was passed to the
        constructor, this will return a list of dtype(np.void).
        """
        return self.field_types

    def reader(self, init_net=None, cursor_name=None, batch_size=1):
        """Create a Reader object that is used to iterate through the dataset.

        This will append operations to `init_net` that create a TreeCursor,
        used to iterate through the data.

        NOTE: Currently, it is not safe to append to a dataset while reading.

        Args:
            init_net: net that will be run once to create the cursor.
            cursor_name: optional name for the blob containing a pointer
                to the cursor.
            batch_size: how many samples to read per iteration.

        Returns:
            A _DatasetReader that can be used to create operators that will
            iterate through the dataset.
        """
        assert self.field_blobs, 'Dataset not initialized.'
        reader = _DatasetReader(self, cursor_name, batch_size)
        if init_net is not None:
            reader.setup_ex(init_net, None)
        return reader

    def random_reader(self, init_net=None, indices=None, cursor_name=None,
                      batch_size=1):
        """Create a Reader object that is used to iterate through the dataset.

        NOTE: The reading order depends on the order of `indices`.

        Args:
            init_net: net that will be run once to create the cursor.
            indices: blob containing the order in which entries are read.
            cursor_name: optional name for the blob containing a pointer
                to the cursor.
            batch_size: how many samples to read per iteration.

        Returns:
            A _DatasetRandomReader that can be used to create operators that
            will iterate through the dataset according to `indices`.
        """
        assert self.field_blobs, 'Dataset not initialized.'
        reader = _DatasetRandomReader(self, cursor_name, indices, batch_size)
        if init_net is not None:
            reader.setup_ex(init_net, None)
        return reader

    def writer(self, init_net=None):
        """Create a Writer that can be used to append entries into the dataset.

        NOTE: Currently, it is not safe to append to a dataset
        while reading from it.
        NOTE: The current implementation of the writer is not thread safe.
        TODO: fixme

        Args:
            init_net: net that will be run once to set up the writer
                (it creates the mutex used for atomic appends).
        """
        assert self.field_blobs, 'Dataset not initialized.'
        writer = _DatasetWriter(self.content())
        if init_net is not None:
            writer.setup_ex(init_net, None)
        return writer

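# Illustrative end-to-end sketch (field names and values are assumptions):
# build an empty dataset, append one batch through a writer, then stream
# the contents back through a reader.
if __name__ == '__main__':
    init_net = core.Net('demo_init')
    ds = Dataset(['x', 'y'], name='demo')
    ds.init_empty(init_net)
    writer = ds.writer(init_net)
    write_net = core.Net('demo_write')
    x = Const(write_net, np.array([1.0, 2.0], dtype=np.float32), name='x')
    y = Const(write_net, np.array([3.0, 4.0], dtype=np.float32), name='y')
    writer.write(write_net, [x, y])
    reader = ds.reader(init_net, batch_size=1)
    read_net = core.Net('demo_read')
    should_stop, fields = reader.read(read_net)
    workspace.RunNetOnce(init_net)
    workspace.RunNetOnce(write_net)
    workspace.CreateNet(read_net)
    workspace.RunNet(read_net)
    while not workspace.FetchBlob(str(should_stop)):
        print([workspace.FetchBlob(str(f)) for f in fields])
        workspace.RunNet(read_net)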