Caffe2 - Python API
A deep learning, cross platform ML framework
expand_dims.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import core, schema
9 from caffe2.python.layers.layers import (
10  ModelLayer,
11 )
12 
13 
class ExpandDims(ModelLayer):
    """Layer that inserts size-1 dimensions into a ``schema.Scalar`` input.

    ``dims`` are positions in the full blob shape *including* the leading
    batch dimension (position 0). Every entry must therefore be >= 1. The
    layer emits one ``ExpandDims`` operator, and its output schema is a
    ``schema.Scalar`` whose per-example shape has a 1 inserted at each
    requested position.
    """

    def __init__(self, model, input_record, dims,
                 name='expand_dims', **kwargs):
        """Validate ``dims`` and build the output schema.

        Args:
            model: model the layer belongs to. ``model.net`` is used to
                allocate a scoped name for the output blob.
            input_record: layer input; must be a ``schema.Scalar``.
            dims: batch-inclusive positions at which size-1 dimensions are
                inserted. Must be non-empty, and every entry must be >= 1.
                Duplicates are collapsed when computing the output schema.
            name: layer name, also used as the prefix of the output blob name.
            **kwargs: forwarded unchanged to ``ModelLayer.__init__``.

        Raises:
            AssertionError: if ``dims`` is empty, if any entry is < 1, if
                ``input_record`` is not a ``schema.Scalar``, or if a position
                lies beyond the rank of the expanded per-example shape.
        """
        super(ExpandDims, self).__init__(model, name, input_record, **kwargs)
        # Keep the caller's batch-inclusive positions: add_ops() passes them
        # verbatim to the ExpandDims operator, which sees the batch dimension.
        self.dims = dims
        # Assume that first dimension is batch, so actual dims[i] in shape is
        # dims[i] - 1 (the schema shape below excludes the batch dimension).
        dims = [d - 1 for d in dims]
        # Without this check an empty `dims` fails later with an opaque
        # IndexError on `dims[-1]`.
        assert dims, "ExpandDims requires a non-empty `dims`"
        assert all(d >= 0 for d in dims), (
            "ExpandDims `dims` must all be >= 1 (dimension 0 is the batch "
            "dimension), but received: {0}".format(self.dims))
        assert isinstance(input_record, schema.Scalar),\
            "Incorrect input type. Excpected Scalar, but received: {0}".\
            format(input_record)

        # NOTE(review): field_type() presumably returns a numpy dtype whose
        # `.shape` is the per-example shape and `.base` the element type --
        # confirm against schema.Scalar.
        input_dims = list(input_record.field_type().shape)
        dims = sorted(set(dims))
        # list.insert() silently appends when the index exceeds the list
        # length, so out-of-range positions must be rejected explicitly. The
        # largest (last, after sorting) position must fit in the output rank.
        assert len(input_dims) + len(dims) >= dims[-1] + 1, (
            "ExpandDims `dims` {0} out of range for input shape {1}".format(
                self.dims, input_dims))

        # Insert in ascending order, so each position already accounts for
        # the dimensions inserted before it (positions refer to the output).
        output_dims = input_dims[:]
        for dim in dims:
            output_dims.insert(dim, 1)

        self.output_schema = schema.Scalar(
            (input_record.field_type().base, output_dims),
            model.net.NextScopedBlob(name + '_output'))

    def add_ops(self, net):
        """Append this layer's ``ExpandDims`` operator to ``net``.

        The operator reads the input record's blob and writes the output
        schema's blob, using the original batch-inclusive ``self.dims``.
        """
        net.ExpandDims(
            self.input_record(),
            self.output_schema(),
            dims=self.dims,
        )
def input_record(self)
Definition: layers.py:149