Caffe2 - Python API
A deep learning, cross-platform ML framework
schema.py
"""
Defines a minimal set of data types that allow representing datasets with
arbitrary nested structure, including objects of variable length, such as
maps and lists.

This defines a columnar storage format for such datasets on top of caffe2
tensors. In terms of capacity of representation, it can represent most of
the data types supported by the Parquet, ORC, and DWRF file formats.

See the comments in operator_test/dataset_ops_test.py for an example and
walkthrough on how to use schema to store and iterate through a structured
in-memory dataset.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import logging
import numpy as np
from caffe2.python import core
from caffe2.python import workspace
from caffe2.python.core import BlobReference
from collections import OrderedDict, namedtuple

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

FIELD_SEPARATOR = ':'


def _join_field_name(prefix, suffix):
    if prefix and suffix:
        return '{}{}{}'.format(prefix, FIELD_SEPARATOR, suffix)
    elif prefix:
        return prefix
    elif suffix:
        return suffix
    else:
        return ''


def _normalize_field(field_or_type_or_blob, keep_blobs=True):
    """Clones/normalizes a field before adding it to a container."""
    if isinstance(field_or_type_or_blob, Field):
        return field_or_type_or_blob.clone(keep_blobs=keep_blobs)
    elif type(field_or_type_or_blob) in (type, np.dtype):
        return Scalar(dtype=field_or_type_or_blob)
    else:
        return Scalar(blob=field_or_type_or_blob)


FeatureSpec = namedtuple(
    'FeatureSpec',
    ['feature_type', 'feature_names', 'feature_ids', 'feature_is_request_only']
)
FeatureSpec.__new__.__defaults__ = (None, None, None, None)


class Metadata(
    namedtuple(
        'Metadata', ['categorical_limit', 'expected_value', 'feature_specs']
    )
):
    """Represents additional information associated with a scalar in schema.

    `categorical_limit` - for fields of integral type that are guaranteed to
    be non-negative, specifies the maximum possible value plus one. It is
    often used as the size of an embedding table.

    `expected_value` - anticipated average value of elements in the field.
    Usually makes sense for the length fields of lists.

    `feature_specs` - information about the features contained in this field.
    For example, if the field holds more than one feature, it can carry the
    list of feature names contained in it."""
    __slots__ = ()


Metadata.__new__.__defaults__ = (None, None, None)
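
# A minimal usage sketch: attaching Metadata to an integral Scalar. The
# `categorical_limit` is validated against the dtype (see
# Scalar._validate_metadata below), so it may only be set on integer fields.
#
#   >>> meta = Metadata(categorical_limit=1000)
#   >>> ids = Scalar(np.int64, metadata=meta)
#   >>> ids.metadata.categorical_limit
#   1000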


class Field(object):
    """Represents an abstract field type in a dataset.
    """

    def __init__(self, children):
        """Derived classes must call this after their initialization."""
        self._parent = (None, 0)
        offset = 0
        self._field_offsets = []
        for child in children:
            self._field_offsets.append(offset)
            offset += len(child.field_names())
        self._field_offsets.append(offset)

    def clone_schema(self):
        return self.clone(keep_blobs=False)

    def field_names(self):
        """Return the children field names for this field."""
        raise NotImplementedError('Field is an abstract class.')

    def field_types(self):
        """Return the numpy.dtype for each of the children fields."""
        raise NotImplementedError('Field is an abstract class.')

    def field_metadata(self):
        """Return the Metadata for each of the children fields."""
        raise NotImplementedError('Field is an abstract class.')

    def field_blobs(self):
        """Return the list of blobs with contents for this Field.
        Values can either be all numpy.ndarray or BlobReference.
        If any of the fields doesn't have a blob, throws.
        """
        raise NotImplementedError('Field is an abstract class.')

    def all_scalars(self):
        """Return the list of all Scalar instances in the Field.
        The order is the same as for field_names() or field_blobs()."""
        raise NotImplementedError('Field is an abstract class.')

    def has_blobs(self):
        """Return True if every scalar of this field has blobs."""
        raise NotImplementedError('Field is an abstract class.')

    def clone(self, keep_blobs=True):
        """Clone this Field along with its children."""
        raise NotImplementedError('Field is an abstract class.')

    def _set_parent(self, parent, relative_id):
        self._parent = (parent, relative_id)

    def slice(self):
        """
        Returns a slice representing the range of field ids that belong to
        this field. This slice can be used to index a list of fields.

        E.g.:

        >>> s = Struct(
        >>>     ('a', Scalar()),
        >>>     ('b', Struct(
        >>>         ('b1', Scalar()),
        >>>         ('b2', Scalar()),
        >>>     )),
        >>>     ('c', Scalar()),
        >>> )
        >>> field_data = ['da', 'db1', 'db2', 'dc']
        >>> field_data[s.b.slice()]
        ['db1', 'db2']
        """
        base_id = self._child_base_id()
        return slice(base_id, base_id + len(self.field_names()))

    def _child_base_id(self, child_index=None):
        """Get the base id of the given child."""
        p, i = self._parent
        pos = 0 if child_index is None else self._field_offsets[child_index]
        if p:
            pos += p._child_base_id(i)
        return pos

    def __eq__(self, other):
        """Equivalence of two schemas."""
        return (
            (self.field_names() == other.field_names()) and
            (self.field_types() == other.field_types()) and
            (self.field_metadata() == other.field_metadata())
        )


class List(Field):
    """Represents a variable-length list.

    Values of a list can also be complex fields such as Lists and Structs.
    In addition to the fields exposed by its `values` field, a List exposes
    an additional `lengths` field, which will contain the size of each list
    under the parent domain.
    """

    def __init__(self, values, lengths_blob=None):
        if isinstance(lengths_blob, Field):
            assert isinstance(lengths_blob, Scalar)
            self.lengths = _normalize_field(lengths_blob)
        else:
            self.lengths = Scalar(np.int32, lengths_blob)
        self._items = _normalize_field(values)
        self.lengths._set_parent(self, 0)
        self._items._set_parent(self, 1)
        Field.__init__(self, [self.lengths, self._items])

    def field_names(self):
        value_fields = self._items.field_names()
        return (
            ['lengths'] + [_join_field_name('values', v) for v in value_fields]
        )

    def field_types(self):
        return self.lengths.field_types() + self._items.field_types()

    def field_metadata(self):
        return self.lengths.field_metadata() + self._items.field_metadata()

    def field_blobs(self):
        return self.lengths.field_blobs() + self._items.field_blobs()

    def all_scalars(self):
        return self.lengths.all_scalars() + self._items.all_scalars()

    def has_blobs(self):
        return self.lengths.has_blobs() and self._items.has_blobs()

    def clone(self, keep_blobs=True):
        return List(
            _normalize_field(self._items, keep_blobs=keep_blobs),
            _normalize_field(self.lengths, keep_blobs=keep_blobs)
        )

    def __getattr__(self, item):
        """If the value of this list is a struct,
        allow introspecting directly into its fields."""
        if item.startswith('__'):
            raise AttributeError(item)
        if isinstance(self._items, Struct):
            return getattr(self._items, item)
        elif item == 'value' or item == 'items':
            return self._items
        else:
            raise AttributeError('Field not found in list: %s.' % item)

    def __getitem__(self, item):
        if isinstance(self._items, Struct):
            return self._items[item]
        elif item == 'lengths':
            return self.lengths
        elif item == 'value' or item == 'items':
            return self._items
        else:
            raise KeyError('Field not found in list: %s.' % item)
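
# A minimal usage sketch: a variable-length list of int64 values flattens
# into a `lengths` column (int32) plus a `values` column.
#
#   >>> l = List(Scalar(np.int64))
#   >>> l.field_names()
#   ['lengths', 'values']
#   >>> l.field_types()
#   [dtype('int32'), dtype('int64')]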


class Struct(Field):
    """Represents a named list of fields sharing the same domain.
    """

    def __init__(self, *fields):
        """fields is a list of tuples in format of (name, field). The name is
        a string of nested name, e.g., `a`, `a:b`, `a:b:c`. For example

        Struct(
            ('a', Scalar()),
            ('b:c', Scalar()),
            ('b:d:e', Scalar()),
            ('b', Struct(
                ('f', Scalar()),
            )),
        )

        is equal to

        Struct(
            ('a', Scalar()),
            ('b', Struct(
                ('c', Scalar()),
                ('d', Struct(('e', Scalar()))),
                ('f', Scalar()),
            )),
        )
        """
        for field in fields:
            assert len(field) == 2
            assert field[0], 'Field names cannot be empty'
            assert field[0] != 'lengths', (
                'Struct cannot contain a field named `lengths`.'
            )
        fields = [(name, _normalize_field(field)) for name, field in fields]
        self.fields = OrderedDict()
        for name, field in fields:
            if FIELD_SEPARATOR in name:
                name, field = self._struct_from_nested_name(name, field)
            if name not in self.fields:
                self.fields[name] = field
                continue
            if (
                not isinstance(field, Struct) or
                not isinstance(self.fields[name], Struct)
            ):
                raise ValueError('Duplicate field name: %s' % name)
            self.fields[name] = self.fields[name] + field
        for id, (_, field) in enumerate(self.fields.items()):
            field._set_parent(self, id)
        Field.__init__(self, self.fields.values())

    def _struct_from_nested_name(self, nested_name, field):
        def create_internal(nested_name, field):
            names = nested_name.split(FIELD_SEPARATOR, 1)
            if len(names) == 1:
                added_field = field
            else:
                added_field = create_internal(names[1], field)
            return Struct((names[0], added_field))

        names = nested_name.split(FIELD_SEPARATOR, 1)
        assert len(names) >= 2
        return names[0], create_internal(names[1], field)

    def get_children(self):
        return self.fields.items()

    def field_names(self):
        names = []
        for name, field in self.fields.items():
            names += [_join_field_name(name, f) for f in field.field_names()]
        return names

    def field_types(self):
        types = []
        for _, field in self.fields.items():
            types += field.field_types()
        return types

    def field_metadata(self):
        metadata = []
        for _, field in self.fields.items():
            metadata += field.field_metadata()
        return metadata

    def field_blobs(self):
        blobs = []
        for _, field in self.fields.items():
            blobs += field.field_blobs()
        return blobs

    def all_scalars(self):
        scalars = []
        for _, field in self.fields.items():
            scalars += field.all_scalars()
        return scalars

    def has_blobs(self):
        return all(field.has_blobs() for field in self.fields.values())

    def clone(self, keep_blobs=True):
        normalized_fields = [
            (k, _normalize_field(v, keep_blobs=keep_blobs))
            for k, v in self.fields.items()
        ]
        return Struct(*normalized_fields)

    def _get_field_by_nested_name(self, nested_name):
        names = nested_name.split(FIELD_SEPARATOR, 1)
        field = self.fields.get(names[0], None)

        if field is None:
            return None

        if len(names) == 1:
            return field

        try:
            return field[names[1]]
        except (KeyError, TypeError):
            return None

    def __contains__(self, item):
        field = self._get_field_by_nested_name(item)
        return field is not None

    def __len__(self):
        return len(self.fields)

    def __getitem__(self, item):
        """
        item can be a tuple or list of ints or strings, or a single
        int or string. A string item is a nested field name, e.g., "a",
        "a:b", "a:b:c". An int item is the index of a field at the first
        level of the Struct.
        """
        if isinstance(item, list) or isinstance(item, tuple):
            return Struct(
                * [
                    (
                        list(self.fields.keys())[k]
                        if isinstance(k, int) else k, self[k]
                    ) for k in item
                ]
            )
        elif isinstance(item, int):
            return list(self.fields.values())[item]
        else:
            field = self._get_field_by_nested_name(item)
            if not field:
                raise KeyError('field "%s" not found' % (item))
            return field

    def __getattr__(self, item):
        if item.startswith('__'):
            raise AttributeError(item)
        try:
            return self.__dict__['fields'][item]
        except KeyError:
            raise AttributeError(item)

    def __add__(self, other):
        """
        Allows merging fields of two schema.Struct instances using the '+'
        operator. If the two Structs have common field names, the merge is
        conducted recursively. Examples:

        Example 1
        s1 = Struct(('a', Scalar()))
        s2 = Struct(('b', Scalar()))
        s1 + s2 == Struct(
            ('a', Scalar()),
            ('b', Scalar()),
        )

        Example 2
        s1 = Struct(
            ('a', Scalar()),
            ('b', Struct(('c', Scalar()))),
        )
        s2 = Struct(('b', Struct(('d', Scalar()))))
        s1 + s2 == Struct(
            ('a', Scalar()),
            ('b', Struct(
                ('c', Scalar()),
                ('d', Scalar()),
            )),
        )
        """
        if not isinstance(other, Struct):
            return NotImplemented

        children = OrderedDict(self.get_children())
        for name, right_field in other.get_children():
            if name not in children:
                children[name] = right_field
                continue
            left_field = children[name]
            children[name] = left_field + right_field

        return Struct(*(children.items()))
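
# A minimal usage sketch: nested constructor names expand into sub-Structs,
# lookups accept the same colon-separated paths, and `+` merges recursively.
#
#   >>> s = Struct(('a:b', Scalar()), ('a:c', Scalar()))
#   >>> len(s)              # only one top-level field, 'a'
#   1
#   >>> 'a:c' in s
#   True
#   >>> (s + Struct(('a', Struct(('d', Scalar()))))).field_names()
#   ['a:b', 'a:c', 'a:d']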


class Scalar(Field):
    """Represents a typed scalar or tensor of fixed shape.

    A Scalar is a leaf in a schema tree, translating to exactly one tensor in
    the dataset's underlying storage.

    Usually, the tensor storing the actual values of this field is a 1D
    tensor, representing a series of values in its domain. It is possible,
    however, to have higher-rank values stored as a Scalar, as long as all
    entries have the same shape.

    E.g.:

        Scalar(np.float64)

            Scalar field of type float64. Caffe2 will expect readers and
            datasets to expose it as a 1D tensor of doubles (vector), where
            the size of the vector is determined by this field's domain.

        Scalar((np.int32, 5))

            Tensor field of type int32. Caffe2 will expect readers and
            datasets to implement it as a 2D tensor (matrix) of shape (L, 5),
            where L is determined by this field's domain.

        Scalar((str, (10, 20)))

            Tensor field of type str. Caffe2 will expect readers and
            datasets to implement it as a 3D tensor of shape (L, 10, 20),
            where L is determined by this field's domain.

    If the field type is unknown at construction time, call Scalar(), which
    will default to np.void as its dtype.

    It is an error to pass a structured dtype to Scalar, since it would
    contain more than one field. Instead, use from_dtype, which will
    construct a nested `Struct` field reflecting the given dtype's structure.

    A Scalar can also contain a blob, which represents the value of this
    Scalar. A blob can be either a numpy.ndarray, in which case it contains
    the actual contents of the Scalar, or a BlobReference, which represents a
    blob living in a caffe2 Workspace. If a blob of a different type is
    passed, a conversion to numpy.ndarray is attempted.
    """

    def __init__(self, dtype=None, blob=None, metadata=None):
        self._metadata = None
        self.set(dtype, blob, metadata)
        Field.__init__(self, [])

    def field_names(self):
        return ['']

    def field_type(self):
        return self.dtype

    def field_types(self):
        return [self.dtype]

    def field_metadata(self):
        return [self._metadata]

    def has_blobs(self):
        return self._blob is not None

    def field_blobs(self):
        assert self._blob is not None, 'Value is not set for this field.'
        return [self._blob]

    def all_scalars(self):
        return [self]

    def clone(self, keep_blobs=True):
        return Scalar(
            dtype=self._original_dtype,
            blob=self._blob if keep_blobs else None,
            metadata=self._metadata
        )

    def get(self):
        """Gets the current blob of this Scalar field."""
        assert self._blob is not None, 'Value is not set for this field.'
        return self._blob

    def __call__(self):
        """Shortcut for self.get()"""
        return self.get()

    @property
    def metadata(self):
        return self._metadata

    def set_metadata(self, value):
        assert isinstance(value, Metadata), \
            'metadata must be Metadata, got {}'.format(type(value))
        self._metadata = value
        self._validate_metadata()

    def _validate_metadata(self):
        if self._metadata is None:
            return
        if (self._metadata.categorical_limit is not None and
                self.dtype is not None):
            assert np.issubdtype(self.dtype, np.integer), \
                "`categorical_limit` can be specified only in integral " + \
                "fields but got {}".format(self.dtype)

    def set_value(self, blob):
        """Sets only the blob field, still validating the existing dtype."""
        self.set(dtype=self._original_dtype, blob=blob)

    def set(self, dtype=None, blob=None, metadata=None):
        """Set the type and/or blob of this scalar. See __init__ for details.

        Args:
            dtype: can be any numpy type. If not provided and `blob` is
                   provided, it will be inferred. If no argument is provided,
                   this Scalar will be of type np.void.
            blob:  if provided, can be either a BlobReference or a
                   numpy.ndarray. If a value of a different type is passed,
                   a conversion to numpy.ndarray is attempted. Strings aren't
                   accepted, since they can be ambiguous. If you want to pass
                   a string, use either BlobReference(blob) or np.array(blob).
            metadata: optional instance of Metadata; if provided, overrides
                   the metadata information of the scalar.
        """
        if blob is not None and isinstance(blob, core.basestring):
            raise ValueError(
                'Passing str blob to Scalar.set() is ambiguous. '
                'Do either set(blob=np.array(blob)) or '
                'set(blob=BlobReference(blob))'
            )

        self._original_dtype = dtype
        if dtype is not None:
            dtype = np.dtype(dtype)
        # If blob is not None and it is not a BlobReference, we assume that
        # it is actual tensor data, so we will try to cast it to a numpy
        # array.
        if blob is not None and not isinstance(blob, BlobReference):
            if dtype is not None and dtype != np.void:
                blob = np.array(blob, dtype=dtype.base)
                # if array is empty we may need to reshape a little
                if blob.size == 0:
                    blob = blob.reshape((0, ) + dtype.shape)
            else:
                assert isinstance(blob, np.ndarray), (
                    'Invalid blob type: %s' % str(type(blob)))

            # reshape scalars into 1D arrays
            # TODO(azzolini): figure out better way of representing this
            if len(blob.shape) == 0:
                blob = blob.reshape((1, ))

            # infer inner shape from the blob given
            # TODO(dzhulgakov): tweak this to make it work with PackedStruct
            if (len(blob.shape) > 1 and dtype is not None and
                    dtype.base != np.void):
                dtype = np.dtype((dtype.base, blob.shape[1:]))
        # if we were still unable to infer the dtype
        if dtype is None:
            dtype = np.dtype(np.void)
        assert not dtype.fields, (
            'Cannot create Scalar with a structured dtype. ' +
            'Use from_dtype instead.'
        )
        self.dtype = dtype
        self._blob = blob
        if metadata is not None:
            self.set_metadata(metadata)
        self._validate_metadata()

    def set_type(self, dtype):
        self._original_dtype = dtype
        if dtype is not None:
            self.dtype = np.dtype(dtype)
        else:
            self.dtype = np.dtype(np.void)
        self._validate_metadata()

    def id(self):
        """
        Return the zero-indexed position of this scalar field in its schema.
        Used in order to index into the field_blob list returned by readers
        or accepted by writers.
        """
        return self._child_base_id()
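
# A minimal usage sketch of Scalar.set()'s dtype inference: a 2D array
# folds its inner shape into the dtype, and 0-d inputs are reshaped to 1D.
#
#   >>> s = Scalar(np.float32, np.zeros((4, 5), dtype=np.float32))
#   >>> s.dtype
#   dtype(('<f4', (5,)))
#   >>> t = Scalar(blob=np.array(7))
#   >>> t.get().shape
#   (1,)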


def Map(
    keys,
    values,
    keys_name='keys',
    values_name='values',
    lengths_blob=None
):
    """A map is a List of a Struct containing keys and values fields.
    Optionally, you can provide custom names for the keys and values fields.
    """
    return List(
        Struct((keys_name, keys), (values_name, values)),
        lengths_blob=lengths_blob
    )
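
# A minimal usage sketch: a Map is sugar for List(Struct(keys, values)),
# so it flattens into three columns.
#
#   >>> m = Map(Scalar(np.int32), Scalar(np.float32))
#   >>> m.field_names()
#   ['lengths', 'values:keys', 'values:values']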


def NamedTuple(name_prefix, *fields):
    return Struct(* [('%s_%d' % (name_prefix, i), field)
                     for i, field in enumerate(fields)])


def Tuple(*fields):
    """
    Creates a Struct with default, sequential field names of the given types.
    """
    return NamedTuple('field', *fields)


def RawTuple(num_fields, name_prefix='field'):
    """
    Creates a tuple of `num_fields` untyped scalars.
    """
    assert isinstance(num_fields, int)
    assert num_fields >= 0
    return NamedTuple(name_prefix, *([np.void] * num_fields))
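
# A minimal usage sketch: Tuple assigns sequential default names to its
# fields.
#
#   >>> Tuple(np.int32, np.float32).field_names()
#   ['field_0', 'field_1']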


def from_dtype(dtype, _outer_shape=()):
    """Constructs a Caffe2 schema from the given numpy dtype.

    Numpy supports scalar, array-like and structured datatypes, as long as
    all the shapes are fixed. This function breaks down the given dtype into
    a Caffe2 schema containing `Struct` and `Scalar` types.

    Fields containing byte offsets are not currently supported.
    """
    if not isinstance(dtype, np.dtype):
        # wrap into a np.dtype
        shape = _outer_shape
        dtype = np.dtype((dtype, _outer_shape))
    else:
        # concatenate shapes if necessary
        shape = _outer_shape + dtype.shape
        if shape != dtype.shape:
            dtype = np.dtype((dtype.base, shape))

    if not dtype.fields:
        return Scalar(dtype)

    struct_fields = []
    for name, (fdtype, offset) in dtype.fields.items():
        assert offset == 0, ('Fields with byte offsets are not supported.')
        struct_fields.append((name, from_dtype(fdtype, _outer_shape=shape)))
    return Struct(*struct_fields)
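
# A minimal usage sketch (using a single-field structured dtype, since
# fields at non-zero byte offsets are rejected): the subarray shape is
# folded into the resulting Scalar's dtype.
#
#   >>> s = from_dtype(np.dtype([('y', (np.float32, 2))]))
#   >>> s.field_names()
#   ['y']
#   >>> s.y.field_type()
#   dtype(('<f4', (2,)))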


class _SchemaNode(object):
    """This is a private class used to represent a Schema Node."""

    def __init__(self, name, type_str=''):
        self.name = name
        self.children = []
        self.type_str = type_str
        self.field = None
        self.col_blob = None

    def add_child(self, name, type_str=''):
        for child in self.children:
            if child.name == name and child.type_str == type_str:
                return child
        child = _SchemaNode(name, type_str)
        self.children.append(child)
        return child

    def get_field(self):

        list_names = ['lengths', 'values']
        map_names = ['lengths', 'keys', 'values']

        if len(self.children) == 0 or self.field is not None:
            assert self.field is not None
            return self.field

        child_names = []
        for child in self.children:
            child_names.append(child.name)

        if (set(child_names) == set(list_names)):
            for child in self.children:
                if child.name == 'values':
                    self.field = List(
                        child.get_field(),
                        lengths_blob=self.children[0].col_blob
                    )
                    self.type_str = "List"
                    return self.field
        elif (set(child_names) == set(map_names)):
            for child in self.children:
                if child.name == 'keys':
                    key_field = child.get_field()
                elif child.name == 'values':
                    values_field = child.get_field()
            self.field = Map(
                key_field,
                values_field,
                lengths_blob=self.children[0].col_blob
            )
            self.type_str = "Map"
            return self.field

        else:
            struct_fields = []
            for child in self.children:
                if child.field is not None:
                    struct_fields.append((child.name, child.field))
                else:
                    struct_fields.append((child.name, child.get_field()))

            self.field = Struct(*struct_fields)
            self.type_str = "Struct"
            return self.field

    def print_recursively(self):
        for child in self.children:
            child.print_recursively()
        logger.info("Printing node: Name and type")
        logger.info(self.name)
        logger.info(self.type_str)


def from_column_list(
    col_names, col_types=None,
    col_blobs=None, col_metadata=None
):
    """
    Given a list of names, types, and optionally values, construct a Schema.
    """
    if col_types is None:
        col_types = [None] * len(col_names)
    if col_metadata is None:
        col_metadata = [None] * len(col_names)
    if col_blobs is None:
        col_blobs = [None] * len(col_names)
    assert len(col_names) == len(col_types), (
        'col_names and col_types must have the same length.'
    )
    assert len(col_names) == len(col_metadata), (
        'col_names and col_metadata must have the same length.'
    )
    assert len(col_names) == len(col_blobs), (
        'col_names and col_blobs must have the same length.'
    )
    root = _SchemaNode('root', 'Struct')
    for col_name, col_type, col_blob, col_metadata in zip(
        col_names, col_types, col_blobs, col_metadata
    ):
        columns = col_name.split(FIELD_SEPARATOR)
        current = root
        for i in range(len(columns)):
            name = columns[i]
            type_str = ''
            field = None
            if i == len(columns) - 1:
                type_str = col_type
                field = Scalar(
                    dtype=col_type,
                    blob=col_blob,
                    metadata=col_metadata
                )
            next = current.add_child(name, type_str)
            if field is not None:
                next.field = field
                next.col_blob = col_blob
            current = next

    return root.get_field()
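
# A minimal usage sketch: rebuilding a schema from flattened column names,
# the inverse of field_names().
#
#   >>> s = from_column_list(['a', 'b:lengths', 'b:values'])
#   >>> isinstance(s.b, List)
#   True
#   >>> s.field_names()
#   ['a', 'b:lengths', 'b:values']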


def from_blob_list(schema, values):
    """
    Create a schema that clones the given schema, but containing the given
    list of values.
    """
    assert isinstance(schema, Field), 'Argument `schema` must be a Field.'
    if isinstance(values, BlobReference):
        values = [values]
    record = schema.clone_schema()
    scalars = record.all_scalars()
    assert len(scalars) == len(values), (
        'Values must have %d elements, got %d.' % (len(scalars), len(values))
    )
    for scalar, value in zip(scalars, values):
        scalar.set_value(value)
    return record
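
# A minimal usage sketch: pairing a schema with concrete values, given in
# field_names() order.
#
#   >>> s = Struct(('x', Scalar(np.float32)), ('y', Scalar(np.int32)))
#   >>> r = from_blob_list(s, [np.array([1.0], dtype=np.float32),
#   ...                        np.array([2], dtype=np.int32)])
#   >>> r.x.get()
#   array([1.], dtype=float32)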


def as_record(value):
    if isinstance(value, Field):
        return value
    elif isinstance(value, list) or isinstance(value, tuple):
        is_field_list = all(
            isinstance(f, tuple) and len(f) == 2 and
            isinstance(f[0], core.basestring)
            for f in value
        )
        if is_field_list:
            return Struct(* [(k, as_record(v)) for k, v in value])
        else:
            return Tuple(* [as_record(f) for f in value])
    elif isinstance(value, dict):
        return Struct(* [(k, as_record(v)) for k, v in value.items()])
    else:
        return _normalize_field(value)
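
# A minimal usage sketch: as_record normalizes plain Python containers
# into schema fields.
#
#   >>> r = as_record({'a': np.int32, 'b': np.float32})
#   >>> sorted(r.field_names())
#   ['a', 'b']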


def FetchRecord(blob_record, ws=None):
    """
    Given a record containing BlobReferences, return a new record with the
    same schema, containing numpy arrays fetched from the current active
    workspace.
    """

    def fetch(v):
        if ws is None:
            return workspace.FetchBlob(str(v))
        else:
            return ws.blobs[str(v)].fetch()

    assert isinstance(blob_record, Field)
    field_blobs = blob_record.field_blobs()
    assert all(isinstance(v, BlobReference) for v in field_blobs)
    field_arrays = [fetch(value) for value in field_blobs]
    return from_blob_list(blob_record, field_arrays)


def FeedRecord(blob_record, arrays, ws=None):
    """
    Given a record of BlobReferences and `arrays`, which is either a list
    of numpy arrays or a record containing numpy arrays, feeds the record
    to the current workspace.
    """

    def feed(b, v):
        if ws is None:
            workspace.FeedBlob(str(b), v)
        else:
            ws.create_blob(str(b))
            ws.blobs[str(b)].feed(v)

    assert isinstance(blob_record, Field)
    field_blobs = blob_record.field_blobs()
    assert all(isinstance(v, BlobReference) for v in field_blobs)
    if isinstance(arrays, Field):
        # TODO: check schema
        arrays = arrays.field_blobs()
    assert len(arrays) == len(field_blobs), (
        'Values must contain exactly %d ndarrays.' % len(field_blobs)
    )
    for blob, array in zip(field_blobs, arrays):
        feed(blob, array)
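
# A minimal round-trip sketch, assuming a default global workspace: feed
# numpy arrays into the blobs of a record, then fetch them back.
#
#   >>> s = Struct(('x', Scalar(np.float32, BlobReference('x'))))
#   >>> FeedRecord(s, [np.array([1.0], dtype=np.float32)])
#   >>> FetchRecord(s).x.get()
#   array([1.], dtype=float32)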


def NewRecord(net, schema):
    """
    Given a schema, create a BlobReference for each of its fields, returning
    a record containing those BlobReferences. The name of each returned blob
    is NextScopedBlob(field_name), which guarantees a unique name in the
    current net. Use NameScope explicitly to avoid name conflicts between
    different nets.
    """
    if isinstance(schema, Scalar):
        result = schema.clone()
        result.set_value(
            blob=net.NextScopedBlob('unnamed_scalar')
        )
        return result

    assert isinstance(schema, Field), 'Record must be a schema.Field instance.'
    blob_refs = [
        net.NextScopedBlob(prefix=name)
        for name in schema.field_names()
    ]
    return from_blob_list(schema, blob_refs)
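
# A minimal usage sketch: materializing uniquely named blobs for a schema
# inside a net, then feeding them.
#
#   >>> net = core.Net('example')
#   >>> record = NewRecord(net, Struct(('x', Scalar(np.float32))))
#   >>> FeedRecord(record, [np.array([3.0], dtype=np.float32)])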


def ConstRecord(net, array_record):
    """
    Given a record of arrays, returns a record of blobs,
    initialized with net.Const.
    """
    blob_record = NewRecord(net, array_record)
    for blob, array in zip(
        blob_record.field_blobs(), array_record.field_blobs()
    ):
        net.Const(array, blob)
    return blob_record


def InitEmptyRecord(net, schema_or_record, enforce_types=False):
    if not schema_or_record.has_blobs():
        record = NewRecord(net, schema_or_record)
    else:
        record = schema_or_record

    for blob_type, blob in zip(record.field_types(), record.field_blobs()):
        try:
            data_type = data_type_for_dtype(blob_type)
            shape = [0] + list(blob_type.shape)
            net.ConstantFill([], blob, shape=shape, dtype=data_type)
        except TypeError:
            # If data_type_for_dtype doesn't know how to resolve the given
            # numpy type to core.DataType, it can throw a TypeError (for
            # example, that would happen for unknown types such as np.void).
            # This is not a problem when the record is going to be
            # overwritten by some operator later, though it might be an
            # issue for type/shape inference.
            if enforce_types:
                raise
            # If we don't enforce types for all items we'll create a blob
            # with the default ConstantFill (FLOAT, no shape)
            net.ConstantFill([], blob, shape=[0])

    return record
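
# A minimal usage sketch: emitting ConstantFill ops that create empty,
# correctly typed blobs for a schema.
#
#   >>> net = core.Net('init')
#   >>> r = InitEmptyRecord(net, Struct(('x', Scalar((np.float32, 4)))))
#   >>> r.field_types()[0].shape
#   (4,)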


_DATA_TYPE_FOR_DTYPE = [
    (np.str, core.DataType.STRING),
    (np.float32, core.DataType.FLOAT),
    (np.float64, core.DataType.DOUBLE),
    (np.bool, core.DataType.BOOL),
    (np.int8, core.DataType.INT8),
    (np.int16, core.DataType.INT16),
    (np.int32, core.DataType.INT32),
    (np.int64, core.DataType.INT64),
    (np.uint8, core.DataType.UINT8),
    (np.uint16, core.DataType.UINT16),
]


def is_schema_subset(schema, original_schema):
    # TODO add more checks
    return set(schema.field_names()).issubset(
        set(original_schema.field_names()))


def equal_schemas(schema, original_schema):
    assert isinstance(schema, Field)
    assert isinstance(original_schema, Field)
    # TODO allow for more compatibility
    return (schema.field_names() == original_schema.field_names() and
            schema.field_types() == original_schema.field_types())


def schema_check(schema, previous=None):
    record = as_record(schema)
    if previous is not None:
        assert equal_schemas(schema, previous)
    return record


def data_type_for_dtype(dtype):
    for np_type, dt in _DATA_TYPE_FOR_DTYPE:
        if dtype.base == np_type:
            return dt
    raise TypeError('Unknown dtype: ' + str(dtype.base))
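
# A minimal usage sketch: resolving a numpy dtype against the table above.
#
#   >>> data_type_for_dtype(np.dtype(np.float32)) == core.DataType.FLOAT
#   True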


def attach_metadata_to_scalars(field, metadata):
    for f in field.all_scalars():
        f.set_metadata(metadata)