3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
8 from caffe2.python
import core, schema
9 from caffe2.python.layers.layers
import (
10 get_categorical_limit,
24 _supported_reducers = [
'PositionWeighted',
'LogMeanExp',
'LogSumExp',
'Max',
25 'Mean',
'Sum',
'Sqrt']
27 def __init__(self, model, input_record, inner_shape, reducer,
28 weight_init=None, weight_optim=None,
29 name='sparse_lookup', **kwargs):
30 super(SparseLookup, self).__init__(model, name, input_record, **kwargs)
32 if isinstance(inner_shape, int):
33 inner_shape = [inner_shape]
34 assert isinstance(inner_shape, list)
or isinstance(inner_shape, tuple),\
35 "Unexpected type for inner_shape, expected list or tuple, got {0}".\
36 format(type(inner_shape))
43 input_dim = get_categorical_limit(input_record)
45 assert input_dim
is not None,
"Unbounded features are not supported" 48 (np.float32, inner_shape),
49 model.net.NextScopedBlob(name +
'_output'),
56 categorical_limit=
None,
59 feature_is_request_only=
True 63 scale = math.sqrt(1.0 / input_dim)
64 self.
shape = [input_dim] + inner_shape
65 self.
weight_init = weight_init
if weight_init
else (
66 'UniformFill', {
'min': -scale,
'max': scale})
68 self.
w = model.net.NextScopedBlob(name +
"_w")
74 raise NotImplementedError()
77 avg_length = self.
input_record.lengths.metadata.expected_value
89 optimizer=weight_optim,
90 ps_param=LayerPsParam(
91 sparse_key=sparse_key,
92 average_length=avg_length
96 if reducer ==
'PositionWeighted':
97 self.
pos_w = model.net.NextScopedBlob(name +
"_pos_w")
100 parameter=self.
pos_w,
107 optimizer=weight_optim
def get_memory_usage(self):
    """Estimate the byte size of the lookup table described by ``self.shape``.

    The element count is the product of every dimension in ``self.shape``,
    and each element is charged 4 bytes (presumably float32 storage --
    NOTE(review): confirm against the parameter's actual dtype).

    Returns:
        int: number of elements multiplied by 4.
    """
    # Fold all dimensions together; like the original, no initial value is
    # supplied, so an empty shape still raises TypeError.
    element_count = functools.reduce(operator.mul, self.shape)
    bytes_per_element = 4
    return element_count * bytes_per_element
113 def get_fp16_compatible_parameters(self):
116 def add_ops(self, net):
119 net.SparseLengthsSum(
128 elif self.
reducer ==
'PositionWeighted':
129 inc_seq = net.LengthsRangeFill(
133 gather_pos_w = net.Gather(
134 [self.
pos_w, inc_seq], self.
pos_w +
'_gather')
136 net.SparseLengthsWeightedSum(
148 sqrt_weight = net.LengthsToWeights(
153 net.SparseLengthsWeightedSum(
165 segment_ids = net.LengthsToSegmentIds(
168 net.__getattr__(
'SortedSegmentRange' + self.
reducer)(
169 [table_rows, segment_ids],
175 net.SparseLengthsWeightedSum(
186 raise "Only Sum is supported for IdScoreList input." +\
187 "Trying to create with {}".format(self.
reducer)
189 raise "Unsupported input type {0}".format(self.
input_record)
def attach_metadata_to_scalars(field, metadata)
def equal_schemas(schema, original_schema)
def CreateOperator(operator_type, inputs, outputs, name='', control_input=None, device_option=None, arg=None, engine=None, kwargs)