"""
Defines a minimal set of data types that allow the representation of datasets
with arbitrary nested structure, including objects of variable length, such
as maps and lists.

This defines a columnar storage format for such datasets on top of caffe2
tensors. In terms of capacity of representation, it can represent most of the
data types supported by the Parquet, ORC, and DWRF file formats.

See comments in operator_test/dataset_ops_test.py for an example and
walkthrough on how to use a schema to store and iterate through a structured
in-memory dataset.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import logging
import numpy as np
from caffe2.python import core
from caffe2.python import workspace
from caffe2.python.core import BlobReference
from collections import OrderedDict, namedtuple

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

FIELD_SEPARATOR = ':'

def _join_field_name(prefix, suffix):
    if prefix and suffix:
        return '{}{}{}'.format(prefix, FIELD_SEPARATOR, suffix)
    elif prefix:
        return prefix
    elif suffix:
        return suffix
    else:
        return ''

def _normalize_field(field_or_type_or_blob, keep_blobs=True):
    """Clones/normalizes a field before adding it to a container."""
    if isinstance(field_or_type_or_blob, Field):
        return field_or_type_or_blob.clone(keep_blobs=keep_blobs)
    elif type(field_or_type_or_blob) in (type, np.dtype):
        return Scalar(dtype=field_or_type_or_blob)
    else:
        return Scalar(blob=field_or_type_or_blob)

FeatureSpec = namedtuple(
    'FeatureSpec',
    [
        'feature_type',
        'feature_names',
        'feature_ids',
        'feature_is_request_only',
    ]
)

FeatureSpec.__new__.__defaults__ = (None, None, None, None)

class Metadata(
    namedtuple(
        'Metadata', ['categorical_limit', 'expected_value', 'feature_specs']
    )
):
    """Represents additional information associated with a scalar in schema.

    `categorical_limit` - for fields of integral type that are guaranteed to
    be non-negative, it specifies the maximum possible value plus one. It's
    often used as the size of an embedding table.

    `expected_value` - anticipated average value of elements in the field.
    Usually makes sense for the length fields of lists.

    `feature_specs` - information about the features contained in this field.
    For example, if a field contains more than one feature, it can carry the
    list of feature names contained in the field.
    """
    __slots__ = ()

Metadata.__new__.__defaults__ = (None, None, None)

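# Illustrative sketch (values invented): Metadata for an id feature whose
# values are known to lie in [0, 10000), e.g. when sizing an embedding table.
#
#   >>> meta = Metadata(categorical_limit=10000)
#   >>> ids = Scalar(np.int64, metadata=meta)
#   >>> ids.metadata.categorical_limit
#   10000
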
class Field(object):
    """Represents an abstract field type in a dataset."""

    def __init__(self, children):
        """Derived classes must call this after their initialization."""
        self._parent = (None, 0)
        offset = 0
        self._field_offsets = []
        for child in children:
            self._field_offsets.append(offset)
            offset += len(child.field_names())
        self._field_offsets.append(offset)

    def clone_schema(self):
        return self.clone(keep_blobs=False)

    def field_names(self):
        """Return the children field names for this field."""
        raise NotImplementedError('Field is an abstract class.')

    def field_types(self):
        """Return the numpy.dtype for each of the children fields."""
        raise NotImplementedError('Field is an abstract class.')

    def field_metadata(self):
        """Return the Metadata for each of the children fields."""
        raise NotImplementedError('Field is an abstract class.')

    def field_blobs(self):
        """Return the list of blobs with contents for this Field.
        Values can either be all numpy.ndarray or BlobReference.
        If any of the fields doesn't have a blob, throws.
        """
        raise NotImplementedError('Field is an abstract class.')

    def all_scalars(self):
        """Return the list of all Scalar instances in the Field.
        The order is the same as for field_names() or field_blobs().
        """
        raise NotImplementedError('Field is an abstract class.')

    def has_blobs(self):
        """Return True if every scalar of this field has blobs."""
        raise NotImplementedError('Field is an abstract class.')

    def clone(self, keep_blobs=True):
        """Clone this Field along with its children."""
        raise NotImplementedError('Field is an abstract class.')

    def _set_parent(self, parent, relative_id):
        self._parent = (parent, relative_id)

    def slice(self):
        """
        Returns a slice representing the range of field ids that belong to
        this field. This slice can be used to index a list of fields.

        E.g.:

        >>> s = Struct(
        >>>     ('a', Scalar()),
        >>>     ('b', Struct(
        >>>         ('b1', Scalar()),
        >>>         ('b2', Scalar()),
        >>>     )),
        >>>     ('c', Scalar()),
        >>> )
        >>> field_data = ['da', 'db1', 'db2', 'dc']
        >>> field_data[s.b.slice()]
        ['db1', 'db2']
        """
        base_id = self._child_base_id()
        return slice(base_id, base_id + len(self.field_names()))

    def _child_base_id(self, child_index=None):
        """Get the base id of the given child."""
        p, i = self._parent
        pos = 0 if child_index is None else self._field_offsets[child_index]
        if p:
            pos += p._child_base_id(i)
        return pos
    def __eq__(self, other):
        """Equivalence of two schemas."""
        return (
            type(self) == type(other) and
            self.field_names() == other.field_names() and
            self.field_types() == other.field_types() and
            self.field_metadata() == other.field_metadata()
        )


class List(Field):
    """Represents a variable-length list.

    Values of a list can also be complex fields such as Lists and Structs.
    In addition to the fields exposed by its `values` field, a List exposes
    an additional `lengths` field, which will contain the size of each list
    under it.
    """

    def __init__(self, values, lengths_blob=None):
        if isinstance(lengths_blob, Field):
            assert isinstance(lengths_blob, Scalar)
            self.lengths = _normalize_field(lengths_blob)
        else:
            self.lengths = Scalar(np.int32, lengths_blob)
        self._items = _normalize_field(values)
        self.lengths._set_parent(self, 0)
        self._items._set_parent(self, 1)
        Field.__init__(self, [self.lengths, self._items])

    def field_names(self):
        value_fields = self._items.field_names()
        return (
            ['lengths'] +
            [_join_field_name('values', v) for v in value_fields]
        )

    def field_types(self):
        return self.lengths.field_types() + self._items.field_types()

    def field_metadata(self):
        return self.lengths.field_metadata() + self._items.field_metadata()

    def field_blobs(self):
        return self.lengths.field_blobs() + self._items.field_blobs()

    def all_scalars(self):
        return self.lengths.all_scalars() + self._items.all_scalars()

    def has_blobs(self):
        return self.lengths.has_blobs() and self._items.has_blobs()

    def clone(self, keep_blobs=True):
        return List(
            _normalize_field(self._items, keep_blobs=keep_blobs),
            _normalize_field(self.lengths, keep_blobs=keep_blobs)
        )

    def __getattr__(self, item):
        """If the value of this list is a struct,
        allow introspecting directly into its fields."""
        if item.startswith('__'):
            raise AttributeError(item)
        if isinstance(self._items, Struct):
            return getattr(self._items, item)
        elif item == 'value' or item == 'items':
            return self._items
        else:
            raise AttributeError('Field not found in list: %s.' % item)

    def __getitem__(self, item):
        if isinstance(self._items, Struct):
            return self._items[item]
        elif item == 'lengths':
            return self.lengths
        elif item == 'value' or item == 'items':
            return self._items
        else:
            raise KeyError('Field not found in list: %s.' % item)

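# Illustrative sketch of the columnar layout produced by List:
#
#   >>> List(Scalar(np.float32)).field_names()
#   ['lengths', 'values']
#   >>> List(Struct(('a', Scalar()), ('b', Scalar()))).field_names()
#   ['lengths', 'values:a', 'values:b']
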
class Struct(Field):
    """Represents a named list of fields sharing the same domain."""

    def __init__(self, *fields):
        """fields is a list of tuples in format of (name, field). The name
        is a string of nested names, e.g., `a`, `a:b`, `a:b:c`. For example

        Struct(
          ('a', Scalar()),
          ('b:c', Scalar()),
          ('b:d:e', Scalar()),
          ('b', Struct(
            ('f', Scalar()),
          )),
        )

        is equal to

        Struct(
          ('a', Scalar()),
          ('b', Struct(
            ('c', Scalar()),
            ('d', Struct(('e', Scalar()))),
            ('f', Scalar()),
          )),
        )
        """
        for field in fields:
            assert len(field) == 2, 'Struct fields are (name, field) pairs.'
            assert field[0], 'Field names cannot be empty.'
            assert field[0] != 'lengths', (
                'Struct cannot contain a field named `lengths`.'
            )
        fields = [(name, _normalize_field(field)) for name, field in fields]
        self.fields = OrderedDict()
        for name, field in fields:
            if FIELD_SEPARATOR in name:
                name, field = self._struct_from_nested_name(name, field)
            if name not in self.fields:
                self.fields[name] = field
                continue
            if (
                not isinstance(field, Struct) or
                not isinstance(self.fields[name], Struct)
            ):
                raise ValueError('Duplicate field name: %s' % name)
            self.fields[name] = self.fields[name] + field
        for id, (_, field) in enumerate(self.fields.items()):
            field._set_parent(self, id)
        Field.__init__(self, self.fields.values())
    def _struct_from_nested_name(self, nested_name, field):
        def create_internal(nested_name, field):
            names = nested_name.split(FIELD_SEPARATOR, 1)
            if len(names) == 1:
                added_field = field
            else:
                added_field = create_internal(names[1], field)
            return Struct((names[0], added_field))

        names = nested_name.split(FIELD_SEPARATOR, 1)
        assert len(names) >= 2
        return names[0], create_internal(names[1], field)
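    # Illustrative sketch of the nested-name expansion handled above:
    #
    #   >>> Struct(('b:c', Scalar())) == Struct(('b', Struct(('c', Scalar()))))
    #   True
    #   >>> Struct(('b:c', Scalar())).field_names()
    #   ['b:c']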
    def get_children(self):
        return self.fields.items()

    def field_names(self):
        names = []
        for name, field in self.fields.items():
            names += [_join_field_name(name, f) for f in field.field_names()]
        return names

    def field_types(self):
        types = []
        for _, field in self.fields.items():
            types += field.field_types()
        return types

    def field_metadata(self):
        metadata = []
        for _, field in self.fields.items():
            metadata += field.field_metadata()
        return metadata

    def field_blobs(self):
        blobs = []
        for _, field in self.fields.items():
            blobs += field.field_blobs()
        return blobs

    def all_scalars(self):
        scalars = []
        for _, field in self.fields.items():
            scalars += field.all_scalars()
        return scalars

    def has_blobs(self):
        return all(field.has_blobs() for field in self.fields.values())

    def clone(self, keep_blobs=True):
        normalized_fields = [
            (k, _normalize_field(v, keep_blobs=keep_blobs))
            for k, v in self.fields.items()
        ]
        return Struct(*normalized_fields)
    def _get_field_by_nested_name(self, nested_name):
        names = nested_name.split(FIELD_SEPARATOR, 1)
        field = self.fields.get(names[0], None)
        if field is None:
            return None
        if len(names) == 1:
            return field
        try:
            return field[names[1]]
        except (KeyError, TypeError):
            return None

    def __contains__(self, item):
        field = self._get_field_by_nested_name(item)
        return field is not None

    def __getitem__(self, item):
        """
        item can be a tuple or list of ints or strings, or a single
        int or string. A string item is a nested field name, e.g., "a",
        "a:b", "a:b:c". An int item is the index of a field at the first
        level of the Struct.
        """
        if isinstance(item, list) or isinstance(item, tuple):
            keys = list(self.fields.keys())
            return Struct(
                * [
                    (keys[k] if isinstance(k, int) else k, self[k])
                    for k in item
                ]
            )
        elif isinstance(item, int):
            return self.fields.values()[item]
        else:
            field = self._get_field_by_nested_name(item)
            if field is None:
                raise KeyError('field "%s" not found' % (item))
            return field
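    # Illustrative indexing sketch (field names invented):
    #
    #   >>> st = Struct(('a', Scalar()), ('b', Struct(('c', Scalar()))))
    #   >>> st['b:c']         # nested field name
    #   >>> st[0]             # first top-level field, i.e. `a`
    #   >>> st[['a', 'b:c']]  # Struct containing only the selected fields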
    def __getattr__(self, item):
        if item.startswith('__'):
            raise AttributeError(item)
        try:
            return self.__dict__['fields'][item]
        except KeyError:
            raise AttributeError(item)

    def __add__(self, other):
        """
        Allows merging the fields of two schema.Struct using the '+'
        operator. If the two Structs have common field names, the merge is
        conducted recursively. Here are examples:

        Example 1:
        s1 = Struct(('a', Scalar()))
        s2 = Struct(('b', Scalar()))
        s1 + s2 == Struct(
            ('a', Scalar()),
            ('b', Scalar()),
        )

        Example 2:
        s1 = Struct(
            ('a', Scalar()),
            ('b', Struct(('c', Scalar()))),
        )
        s2 = Struct(('b', Struct(('d', Scalar()))))
        s1 + s2 == Struct(
            ('a', Scalar()),
            ('b', Struct(
                ('c', Scalar()),
                ('d', Scalar()),
            )),
        )
        """
        if not isinstance(other, Struct):
            return NotImplemented

        children = OrderedDict(self.get_children())
        for name, right_field in other.get_children():
            if name not in children:
                children[name] = right_field
                continue
            left_field = children[name]
            children[name] = left_field + right_field

        return Struct(*(children.items()))

446 return Struct(*(children.items()))
450 """Represents a typed scalar or tensor of fixed shape. 452 A Scalar is a leaf in a schema tree, translating to exactly one tensor in 453 the dataset's underlying storage. 455 Usually, the tensor storing the actual values of this field is a 1D tensor, 456 representing a series of values in its domain. It is possible however to 457 have higher rank values stored as a Scalar, as long as all entries have 464 Scalar field of type float32. Caffe2 will expect readers and 465 datasets to expose it as a 1D tensor of doubles (vector), where 466 the size of the vector is determined by this fields' domain. 468 Scalar((np.int32, 5)) 470 Tensor field of type int32. Caffe2 will expect readers and 471 datasets to implement it as a 2D tensor (matrix) of shape (L, 5), 472 where L is determined by this fields' domain. 474 Scalar((str, (10, 20))) 476 Tensor field of type str. Caffe2 will expect readers and 477 datasets to implement it as a 3D tensor of shape (L, 10, 20), 478 where L is determined by this fields' domain. 480 If the field type is unknown at construction time, call Scalar(), that will 481 default to np.void as its dtype. 483 It is an error to pass a structured dtype to Scalar, since it would contain 484 more than one field. Instead, use from_dtype, which will construct 485 a nested `Struct` field reflecting the given dtype's structure. 487 A Scalar can also contain a blob, which represents the value of this 488 Scalar. A blob can be either a numpy.ndarray, in which case it contain the 489 actual contents of the Scalar, or a BlobReference, which represents a 490 blob living in a caffe2 Workspace. If blob of different types are passed, 491 a conversion to numpy.ndarray is attempted. 494 def __init__(self, dtype=None, blob=None, metadata=None):
496 self.
set(dtype, blob, metadata)
497 Field.__init__(self, [])
    def field_names(self):
        return ['']

    def field_type(self):
        return self.dtype

    def field_types(self):
        return [self.dtype]

    def field_metadata(self):
        return [self._metadata]

    def has_blobs(self):
        return self._blob is not None

    def field_blobs(self):
        assert self._blob is not None, 'Value is not set for this field.'
        return [self._blob]

    def all_scalars(self):
        return [self]

    def clone(self, keep_blobs=True):
        return Scalar(
            dtype=self._original_dtype,
            blob=self._blob if keep_blobs else None,
            metadata=self._metadata
        )

    def get(self):
        """Gets the current blob of this Scalar field."""
        assert self._blob is not None, 'Value is not set for this field.'
        return self._blob

    def __call__(self):
        """Shortcut for self.get()"""
        return self.get()

    @property
    def metadata(self):
        return self._metadata

    def set_metadata(self, value):
        assert isinstance(value, Metadata), \
            'metadata must be Metadata, got {}'.format(type(value))
        self._metadata = value
        self._validate_metadata()

    def _validate_metadata(self):
        if self._metadata is None:
            return
        if (self._metadata.categorical_limit is not None and
                self.dtype is not None):
            assert np.issubdtype(self.dtype, np.integer), \
                "`categorical_limit` can be specified only in integral " + \
                "fields but got {}".format(self.dtype)
    def set_value(self, blob):
        """Sets only the blob field, still validating the existing dtype."""
        self.set(dtype=self._original_dtype, blob=blob)

    def set(self, dtype=None, blob=None, metadata=None):
        """Set the type and/or blob of this scalar. See __init__ for details.

        Args:
            dtype: can be any numpy type. If not provided and `blob` is
                   provided, it will be inferred. If no argument is provided,
                   this Scalar will be of type np.void.
            blob:  if provided, can be either a BlobReference or a
                   numpy.ndarray. If a value of a different type is passed,
                   a conversion to numpy.ndarray is attempted. Strings aren't
                   accepted, since they can be ambiguous. If you want to pass
                   a string, use either BlobReference(blob) or np.array(blob).
            metadata: optional instance of Metadata; if provided, overrides
                   the metadata information of the scalar.
        """
        if blob is not None and isinstance(blob, core.basestring):
            raise ValueError(
                'Passing str blob to Scalar.set() is ambiguous. '
                'Do either set(blob=np.array(blob)) or '
                'set(blob=BlobReference(blob))'
            )

        self._original_dtype = dtype
        if dtype is not None:
            dtype = np.dtype(dtype)
        # If blob is not a BlobReference, assume it holds actual tensor data
        # and try to cast it to a numpy array of the requested dtype.
        if blob is not None and not isinstance(blob, BlobReference):
            if dtype is not None and dtype != np.void:
                blob = np.array(blob, dtype=dtype.base)
                # if the array is empty we may need to reshape a little
                if blob.size == 0:
                    blob = blob.reshape((0, ) + dtype.shape)
            else:
                assert isinstance(blob, np.ndarray), (
                    'Invalid blob type: %s' % str(type(blob)))
            # reshape scalars into 1D arrays
            if len(blob.shape) == 0:
                blob = blob.reshape((1, ))
            # infer the inner shape from the given blob
            if (len(blob.shape) > 1 and dtype is not None and
                    dtype.base != np.void):
                dtype = np.dtype((dtype.base, blob.shape[1:]))
        # if we were still unable to infer the dtype
        if dtype is None:
            dtype = np.dtype(np.void)
        assert not dtype.fields, (
            'Cannot create Scalar with a structured dtype. ' +
            'Use from_dtype instead.'
        )
        self.dtype = dtype
        self._blob = blob
        if metadata is not None:
            self.set_metadata(metadata)
        self._validate_metadata()
    def set_type(self, dtype):
        self._original_dtype = dtype
        if dtype is not None:
            self.dtype = np.dtype(dtype)
        else:
            self.dtype = np.dtype(np.void)
        self._validate_metadata()

    def id(self):
        """
        Return the zero-indexed position of this scalar field in its schema.
        Used in order to index into the field_blob list returned by readers
        or accepted by writers.
        """
        return self._child_base_id()


def Map(
    keys,
    values,
    keys_name='keys',
    values_name='values',
    lengths_blob=None
):
    """A Map is a List of Structs containing the keys and values fields.
    Optionally, you can provide custom names for the key and value fields.
    """
    return List(
        Struct((keys_name, keys), (values_name, values)),
        lengths_blob=lengths_blob
    )

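# Illustrative sketch of the columns backing a Map:
#
#   >>> m = Map(Scalar(np.int32), Scalar(np.float32))
#   >>> m.field_names()
#   ['lengths', 'values:keys', 'values:values']
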
def NamedTuple(name_prefix, *fields):
    return Struct(* [('%s_%d' % (name_prefix, i), field)
                     for i, field in enumerate(fields)])


def Tuple(*fields):
    """
    Creates a Struct with default, sequential, field names of given types.
    """
    return NamedTuple('field', *fields)


def RawTuple(num_fields, name_prefix='field'):
    """
    Creates a tuple of `num_fields` untyped scalars.
    """
    assert isinstance(num_fields, int)
    assert num_fields >= 0
    return NamedTuple(name_prefix, *([np.void] * num_fields))

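# Illustrative naming sketch for the tuple helpers:
#
#   >>> Tuple(np.int32, np.float32).field_names()
#   ['field_0', 'field_1']
#   >>> RawTuple(2, name_prefix='raw').field_names()
#   ['raw_0', 'raw_1']
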
def from_dtype(dtype, _outer_shape=()):
    """Constructs a Caffe2 schema from the given numpy dtype.

    Numpy supports scalar, array-like and structured datatypes, as long as
    all the shapes are fixed. This function breaks down the given dtype into
    a Caffe2 schema containing `Struct` and `Scalar` types.

    Fields containing byte offsets are not currently supported.
    """
    if not isinstance(dtype, np.dtype):
        # wrap into a np.dtype
        shape = _outer_shape
        dtype = np.dtype((dtype, _outer_shape))
    else:
        # concatenate shapes if necessary
        shape = _outer_shape + dtype.shape
        if shape != dtype.shape:
            dtype = np.dtype((dtype.base, shape))

    if not dtype.fields:
        return Scalar(dtype)

    struct_fields = []
    for name, (fdtype, offset) in dtype.fields.items():
        assert offset == 0, ('Fields with byte offsets are not supported.')
        struct_fields.append((name, from_dtype(fdtype, _outer_shape=shape)))
    return Struct(*struct_fields)

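# Illustrative sketch: a structured dtype becomes a Struct of Scalars.
#
#   >>> dt = np.dtype([('float', np.float32), ('int', np.int32)])
#   >>> record = from_dtype(dt)
#   >>> sorted(record.field_names())
#   ['float', 'int']
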
class _SchemaNode(object):
    """This is a private class used to represent a Schema Node."""

    def __init__(self, name, type_str=''):
        self.name = name
        self.children = []
        self.type_str = type_str
        self.field = None
        self.col_blob = None

    def add_child(self, name, type_str=''):
        for child in self.children:
            if child.name == name and child.type_str == type_str:
                return child
        child = _SchemaNode(name, type_str)
        self.children.append(child)
        return child

    def get_field(self):
        list_names = ['lengths', 'values']
        map_names = ['lengths', 'keys', 'values']

        if len(self.children) == 0 or self.field is not None:
            assert self.field is not None
            return self.field

        child_names = []
        for child in self.children:
            child_names.append(child.name)

        if (set(child_names) == set(list_names)):
            for child in self.children:
                if child.name == 'values':
                    values_field = child.get_field()
            self.field = List(
                values_field,
                lengths_blob=self.children[0].col_blob
            )
            return self.field
        elif (set(child_names) == set(map_names)):
            for child in self.children:
                if child.name == 'keys':
                    key_field = child.get_field()
                elif child.name == 'values':
                    values_field = child.get_field()
            self.field = Map(
                key_field,
                values_field,
                lengths_blob=self.children[0].col_blob
            )
            return self.field
        else:
            struct_fields = []
            for child in self.children:
                if child.field is not None:
                    struct_fields.append((child.name, child.field))
                else:
                    struct_fields.append((child.name, child.get_field()))
            self.field = Struct(*struct_fields)
            return self.field

    def print_recursively(self):
        for child in self.children:
            child.print_recursively()
        logger.info("Printing node: Name and type")
        logger.info(self.name)
        logger.info(self.type_str)

def from_column_list(
    col_names, col_types=None,
    col_blobs=None, col_metadata=None
):
    """
    Given a list of names, types, and optionally values, construct a Schema.
    """
    if col_types is None:
        col_types = [None] * len(col_names)
    if col_metadata is None:
        col_metadata = [None] * len(col_names)
    if col_blobs is None:
        col_blobs = [None] * len(col_names)
    assert len(col_names) == len(col_types), (
        'col_names and col_types must have the same length.'
    )
    assert len(col_names) == len(col_metadata), (
        'col_names and col_metadata must have the same length.'
    )
    assert len(col_names) == len(col_blobs), (
        'col_names and col_blobs must have the same length.'
    )
    root = _SchemaNode('root', 'Struct')
    for col_name, col_type, col_blob, col_metadata in zip(
        col_names, col_types, col_blobs, col_metadata
    ):
        columns = col_name.split(FIELD_SEPARATOR)
        current = root
        for i in range(len(columns)):
            name = columns[i]
            type_str = ''
            field = None
            if i == len(columns) - 1:
                type_str = col_type
                field = Scalar(
                    dtype=col_type,
                    blob=col_blob,
                    metadata=col_metadata
                )
            next = current.add_child(name, type_str)
            if field is not None:
                next.field = field
                next.col_blob = col_blob
            current = next

    return root.get_field()

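# Illustrative sketch: nested column names (using the `:` separator) are
# folded back into Lists/Maps/Structs (names invented).
#
#   >>> f = from_column_list(
#   ...     ['a:lengths', 'a:values', 'b'],
#   ...     col_types=[np.int32, np.float32, np.str])
#   >>> f.field_names()
#   ['a:lengths', 'a:values', 'b']
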
def from_blob_list(schema, values):
    """
    Create a schema that clones the given schema, but containing the given
    list of values.
    """
    assert isinstance(schema, Field), 'Argument `schema` must be a Field.'
    if isinstance(values, BlobReference):
        values = [values]
    record = schema.clone_schema()
    scalars = record.all_scalars()
    assert len(scalars) == len(values), (
        'Values must have %d elements, got %d.' % (len(scalars), len(values))
    )
    for scalar, value in zip(scalars, values):
        scalar.set_value(value)
    return record

def as_record(value):
    if isinstance(value, Field):
        return value
    elif isinstance(value, list) or isinstance(value, tuple):
        is_field_list = all(
            type(f) is tuple and len(f) == 2 and
            isinstance(f[0], core.basestring)
            for f in value
        )
        if is_field_list:
            return Struct(* [(k, as_record(v)) for k, v in value])
        else:
            return Tuple(* [as_record(f) for f in value])
    elif isinstance(value, dict):
        return Struct(* [(k, as_record(v)) for k, v in value.items()])
    else:
        return _normalize_field(value)

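# Illustrative conversion sketch: plain Python containers become schema
# types (field order follows the container's iteration order).
#
#   >>> as_record({'a': np.int32, 'b': [np.float32, np.float32]})
#   # is equivalent to:
#   # Struct(('a', Scalar(np.int32)),
#   #        ('b', Tuple(np.float32, np.float32)))
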
def FetchRecord(blob_record, ws=None):
    """
    Given a record containing BlobReferences, return a new record with the
    same schema, containing numpy arrays, fetched from the current active
    workspace.
    """
    def fetch(v):
        if ws is None:
            return workspace.FetchBlob(str(v))
        else:
            return ws.blobs[str(v)].fetch()

    assert isinstance(blob_record, Field)
    field_blobs = blob_record.field_blobs()
    assert all(isinstance(v, BlobReference) for v in field_blobs)
    field_arrays = [fetch(value) for value in field_blobs]
    return from_blob_list(blob_record, field_arrays)

def FeedRecord(blob_record, arrays, ws=None):
    """
    Given a Record containing blob_references and arrays, which is either
    a list of numpy arrays or a Record containing numpy arrays, feeds the
    record to the current workspace.
    """
    def feed(b, v):
        if ws is None:
            workspace.FeedBlob(str(b), v)
        else:
            ws.create_blob(str(b))
            ws.blobs[str(b)].feed(v)

    assert isinstance(blob_record, Field)
    field_blobs = blob_record.field_blobs()
    assert all(isinstance(v, BlobReference) for v in field_blobs)
    if isinstance(arrays, Field):
        arrays = arrays.field_blobs()
    assert len(arrays) == len(field_blobs), (
        'Values must contain exactly %d ndarrays.' % len(field_blobs)
    )
    for blob, array in zip(field_blobs, arrays):
        feed(blob, array)

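# Illustrative round trip through the active workspace (the net name and
# field name are invented; NewRecord is defined below):
#
#   >>> net = core.Net('example')
#   >>> record = NewRecord(net, Struct(('x', Scalar(np.float32))))
#   >>> FeedRecord(record, [np.array([1.0], dtype=np.float32)])
#   >>> arrays = FetchRecord(record)  # same schema, now holding ndarrays
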
def NewRecord(net, schema):
    """
    Given a record of np.arrays, create a BlobReference for each one of
    them, returning a record containing BlobReferences. The name of each
    returned blob is NextScopedBlob(field_name), which guarantees a unique
    name in the current net. Use NameScope explicitly to avoid name
    conflicts between different nets.
    """
    if isinstance(schema, Scalar):
        result = schema.clone()
        result.set_value(blob=net.NextScopedBlob('unnamed_scalar'))
        return result

    assert isinstance(schema, Field), (
        'Record must be a schema.Field instance.')
    blob_refs = [
        net.NextScopedBlob(prefix=name)
        for name in schema.field_names()
    ]
    return from_blob_list(schema, blob_refs)

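# Illustrative scoping sketch: blob names derive from field names, so
# records created under different NameScopes do not collide (names
# invented):
#
#   >>> with core.NameScope('train'):
#   ...     train_record = NewRecord(net, schema)  # blobs named 'train/...'
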
def ConstRecord(net, array_record):
    """
    Given a record of arrays, returns a record of blobs,
    initialized with net.Const.
    """
    blob_record = NewRecord(net, array_record)
    for blob, array in zip(
        blob_record.field_blobs(), array_record.field_blobs()
    ):
        net.Const(array, blob)
    return blob_record

def InitEmptyRecord(net, schema_or_record, enforce_types=False):
    if not schema_or_record.has_blobs():
        record = NewRecord(net, schema_or_record)
    else:
        record = schema_or_record

    for blob_type, blob in zip(record.field_types(), record.field_blobs()):
        try:
            data_type = data_type_for_dtype(blob_type)
            shape = [0] + list(blob_type.shape)
            net.ConstantFill([], blob, shape=shape, dtype=data_type)
        except TypeError:
            if enforce_types:
                raise
            # If we don't enforce types for all items, create the blob with
            # the default ConstantFill (FLOAT, no shape).
            net.ConstantFill([], blob, shape=[0])

    return record

_DATA_TYPE_FOR_DTYPE = [
    (np.str, core.DataType.STRING),
    (np.float32, core.DataType.FLOAT),
    (np.float64, core.DataType.DOUBLE),
    (np.bool, core.DataType.BOOL),
    (np.int8, core.DataType.INT8),
    (np.int16, core.DataType.INT16),
    (np.int32, core.DataType.INT32),
    (np.int64, core.DataType.INT64),
    (np.uint8, core.DataType.UINT8),
    (np.uint16, core.DataType.UINT16),
]

def is_schema_subset(schema, original_schema):
    return set(schema.field_names()).issubset(
        set(original_schema.field_names()))

def equal_schemas(schema, original_schema):
    assert isinstance(schema, Field)
    assert isinstance(original_schema, Field)
    return schema.field_names() == original_schema.field_names() and \
        schema.field_types() == original_schema.field_types()

def schema_check(schema, previous=None):
    record = as_record(schema)
    if previous is not None:
        assert equal_schemas(schema, previous)
    return record

def data_type_for_dtype(dtype):
    for np_type, dt in _DATA_TYPE_FOR_DTYPE:
        if dtype.base == np_type:
            return dt
    raise TypeError('Unknown dtype: ' + str(dtype.base))

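# Illustrative mapping sketch (`.base` strips any inner shape):
#
#   >>> data_type_for_dtype(np.dtype(np.float32)) == core.DataType.FLOAT
#   True
#   >>> data_type_for_dtype(np.dtype((np.int32, 5))) == core.DataType.INT32
#   True
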
def attach_metadata_to_scalars(field, metadata):
    for f in field.all_scalars():
        f.set_metadata(metadata)