Caffe2 - Python API
A deep learning, cross-platform ML framework
sparse_lookup.py
## @package sparse_lookup
# Module caffe2.python.layers.sparse_lookup
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core, schema
from caffe2.python.layers.layers import (
    get_categorical_limit,
    IdList,
    IdScoreList,
    LayerParameter,
    LayerPsParam,
    ModelLayer,
)
import functools
import math
import numpy as np
import operator

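# SparseLookup owns an [input_dim] + inner_shape embedding table and, in
# add_ops, emits the sparse-segment operators that gather rows by id and pool
# them per example according to the configured reducer.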
class SparseLookup(ModelLayer):
    _supported_reducers = ['PositionWeighted', 'LogMeanExp', 'LogSumExp', 'Max',
                           'Mean', 'Sum', 'Sqrt']

    def __init__(self, model, input_record, inner_shape, reducer,
                 weight_init=None, weight_optim=None,
                 name='sparse_lookup', **kwargs):
        super(SparseLookup, self).__init__(model, name, input_record, **kwargs)

        if isinstance(inner_shape, int):
            inner_shape = [inner_shape]
        assert isinstance(inner_shape, (list, tuple)),\
            "Unexpected type for inner_shape, expected list or tuple, got {0}".\
            format(type(inner_shape))

        # TODO Add some asserts about input type
        assert reducer in self._supported_reducers, "Unsupported reducer: {}".\
            format(reducer)
        self.reducer = reducer

        input_dim = get_categorical_limit(input_record)

        assert input_dim is not None, "Unbounded features are not supported"

        self.output_schema = schema.Scalar(
            (np.float32, inner_shape),
            model.net.NextScopedBlob(name + '_output'),
        )

        if self.request_only:
            schema.attach_metadata_to_scalars(
                self.output_schema,
                schema.Metadata(
                    categorical_limit=None,
                    expected_value=None,
                    feature_specs=schema.FeatureSpec(
                        feature_is_request_only=True
                    )
                )
            )
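        # Default weight init: Uniform(-scale, scale) with
        # scale = 1 / sqrt(input_dim), i.e. scaled by the table's row count.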
        scale = math.sqrt(1.0 / input_dim)
        self.shape = [input_dim] + inner_shape
        self.weight_init = weight_init if weight_init else (
            'UniformFill', {'min': -scale, 'max': scale})

        self.w = model.net.NextScopedBlob(name + "_w")
        if schema.equal_schemas(self.input_record, IdList):
            sparse_key = self.input_record.items()
        elif schema.equal_schemas(self.input_record, IdScoreList):
            sparse_key = self.input_record.keys()
        else:
            raise NotImplementedError()

        if self.input_record.lengths.metadata:
            avg_length = self.input_record.lengths.metadata.expected_value
        else:
            avg_length = None
        self.params.append(
            LayerParameter(
                parameter=self.w,
                initializer=core.CreateOperator(self.weight_init[0],
                                                [],
                                                self.w,
                                                shape=self.shape,
                                                **self.weight_init[1]
                                                ),
                optimizer=weight_optim,
                ps_param=LayerPsParam(
                    sparse_key=sparse_key,
                    average_length=avg_length
                )
            ))

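        # PositionWeighted keeps one trainable scalar per within-list position
        # (input_dim serves as an upper bound on list length); initializing to
        # 1.0 makes the initial pooling an ordinary sum.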
        if reducer == 'PositionWeighted':
            self.pos_w = model.net.NextScopedBlob(name + "_pos_w")
            self.params.append(
                LayerParameter(
                    parameter=self.pos_w,
                    initializer=core.CreateOperator('ConstantFill',
                                                    [],
                                                    self.pos_w,
                                                    shape=[input_dim, ],
                                                    value=1.0
                                                    ),
                    optimizer=weight_optim
                ))

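    # Memory estimate: 4 bytes per float32 entry of the weight table.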
    def get_memory_usage(self):
        return functools.reduce(operator.mul, self.shape) * 4

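    # Parameters that may be stored in half precision; the lookup ops in
    # add_ops request engine='fp16' to match.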
    def get_fp16_compatible_parameters(self):
        return [self.w]

    def add_ops(self, net):
        if schema.equal_schemas(self.input_record, IdList):
            if self.reducer == 'Sum':
                net.SparseLengthsSum(
                    [
                        self.w,
                        self.input_record.items(),
                        self.input_record.lengths()
                    ],
                    self.output_schema.field_blobs(),
                    engine='fp16'
                )
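            # PositionWeighted: LengthsRangeFill emits 0..len-1 within each
            # segment; gathering pos_w with those indices yields per-position
            # weights for the weighted sum below.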
            elif self.reducer == 'PositionWeighted':
                inc_seq = net.LengthsRangeFill(
                    [self.input_record.lengths()],
                    self.input_record.lengths() + '_seq'
                )
                gather_pos_w = net.Gather(
                    [self.pos_w, inc_seq], self.pos_w + '_gather')

                net.SparseLengthsWeightedSum(
                    [
                        self.w,
                        gather_pos_w,
                        self.input_record.items(),
                        self.input_record.lengths()
                    ],
                    self.output_schema.field_blobs(),
                    grad_on_weights=1,
                    engine='fp16'
                )
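            # Sqrt: LengthsToWeights with power=0.5 assigns each element of a
            # length-n segment the weight n**-0.5, so the pooled sum is
            # normalized by sqrt(n).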
            elif self.reducer == 'Sqrt':
                sqrt_weight = net.LengthsToWeights(
                    [self.input_record.lengths()],
                    [self.input_record.lengths() + '_sqrt'],
                    power=0.5
                )
                net.SparseLengthsWeightedSum(
                    [
                        self.w,
                        sqrt_weight,
                        self.input_record.items(),
                        self.input_record.lengths()
                    ],
                    self.output_schema.field_blobs(),
                    engine='fp16'
                )
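            # Remaining reducers (LogMeanExp, LogSumExp, Max, Mean) gather the
            # rows explicitly and dispatch by name to the corresponding
            # SortedSegmentRange<Reducer> operator.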
            else:
                table_rows = net.Gather([self.w, self.input_record.items()])
                segment_ids = net.LengthsToSegmentIds(
                    self.input_record.lengths(),
                    self.input_record.lengths() + '_sid')
                net.__getattr__('SortedSegmentRange' + self.reducer)(
                    [table_rows, segment_ids],
                    self.output_schema.field_blobs(),
                    engine='fp16'
                )
        elif schema.equal_schemas(self.input_record, IdScoreList):
            if self.reducer == 'Sum':
                net.SparseLengthsWeightedSum(
                    [
                        self.w,
                        self.input_record.values(),
                        self.input_record.keys(),
                        self.input_record.lengths()
                    ],
                    self.output_schema.field_blobs(),
                    engine='fp16'
                )
            else:
                raise ValueError(
                    "Only Sum is supported for IdScoreList input. "
                    "Trying to create with {}".format(self.reducer))
        else:
            raise ValueError(
                "Unsupported input type {0}".format(self.input_record))
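For orientation, here is a minimal sketch of how this layer is typically constructed, modeled on Caffe2's layer-framework tests. The LayerModelHelper wiring, the feature name id_list_feature, and the concrete sizes are illustrative assumptions, not part of this file:

import numpy as np
from caffe2.python import schema
from caffe2.python.layer_model_helper import LayerModelHelper

# One id-list input feature with a bounded id space, so that
# get_categorical_limit() can size the embedding table.
input_schema = schema.Struct(
    ('id_list_feature', schema.List(
        schema.Scalar(np.int64,
                      metadata=schema.Metadata(categorical_limit=1000)))),
)
model = LayerModelHelper('example', input_schema, schema.Struct())

# 64-dimensional embeddings, sum-pooled per example; layer calls are
# dispatched by name, so this instantiates SparseLookup and returns its
# output schema.
pooled = model.SparseLookup(
    model.input_feature_schema.id_list_feature, [64], 'Sum')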