Caffe2 - Python API
A deep learning, cross-platform ML framework
sampling_train.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import schema
9 from caffe2.python.layers.layers import ModelLayer, get_layer_class
10 from caffe2.python.layers.sampling_trainable_mixin import SamplingTrainableMixin
11 
12 
class SamplingTrain(ModelLayer):
    """Train a prediction layer on a sampled subset of its parameters.

    The wrapped layer is looked up by name with ``get_layer_class`` and must
    be a ``SamplingTrainableMixin`` subclass. ``add_ops`` simply delegates to
    the wrapped layer. ``add_train_ops`` does three things in order:

    1. Gathers each of the wrapped layer's parameter blobs at
       ``input_record.indices`` into a companion ``*_sampled`` blob.
    2. Adds the wrapped layer's train ops.
    3. Optionally subtracts ``log(input_record.sampling_prob)`` from the
       output in place.

    ``params`` and ``output_schema`` are those of the wrapped layer.
    """

    def __init__(
        self,
        model,
        input_record,
        prediction_layer,
        output_dims,
        subtract_log_odd=True,
        name='sampling_train',
        **kwargs
    ):
        """Create the wrapper and the wrapped prediction layer.

        Args:
            model: model the layers belong to. ``model.net.NextBlob`` is used
                to name the sampled parameter blobs.
            input_record: schema record. It must contain the fields
                ``indices`` and ``input`` (checked with
                ``schema.is_schema_subset``). It must also contain
                ``sampling_prob`` when ``subtract_log_odd`` is true.
                ``input_record.input`` is fed to the wrapped layer.
            prediction_layer: name of the layer class to wrap, resolved with
                ``get_layer_class``.
            output_dims: forwarded to the wrapped layer's constructor.
            subtract_log_odd: if true, ``add_train_ops`` subtracts
                ``log(sampling_prob)`` from the wrapped layer's output.
            name: name of this wrapper layer.
            **kwargs: forwarded to both ``ModelLayer.__init__`` and the
                wrapped layer's constructor.

        Raises:
            AssertionError: if the resolved class is not a
                ``SamplingTrainableMixin`` subclass, or ``input_record`` lacks
                a required field.
        """
        super(SamplingTrain, self).__init__(
            model, name, input_record, **kwargs
        )

        layer_class = get_layer_class(prediction_layer)
        assert issubclass(layer_class, SamplingTrainableMixin), (
            "prediction_layer must resolve to a SamplingTrainableMixin "
            "subclass"
        )

        # The record must provide the sampled ids and the layer input.
        assert schema.is_schema_subset(
            schema.Struct(
                ('indices', schema.Scalar()),
                ('input', schema.Scalar()),
            ),
            input_record
        ), "input_record must contain 'indices' and 'input' fields"

        self.subtract_log_odd = subtract_log_odd
        if self.subtract_log_odd:
            assert 'sampling_prob' in input_record, (
                "input_record must contain 'sampling_prob' when "
                "subtract_log_odd is set"
            )

        self._prediction_layer = layer_class(
            model,
            input_record.input,
            output_dims=output_dims,
            **kwargs
        )

        # Reserve one fresh blob per parameter to receive the gathered
        # entries in add_train_ops.
        # NOTE(review): presumably SamplingTrainableMixin makes the wrapped
        # layer's train ops read train_param_blobs instead of param_blobs;
        # confirm in that mixin.
        self._prediction_layer.train_param_blobs = [
            model.net.NextBlob(str(blob) + '_sampled')
            for blob in self._prediction_layer.param_blobs
        ]

        # Parameters and output are shared with the wrapped layer.
        self.params = self._prediction_layer.params

        self.output_schema = self._prediction_layer.output_schema

    def add_ops(self, net):
        """Delegate to the wrapped layer's ``add_ops``.

        No sampling or log-probability correction is applied here.
        """
        self._prediction_layer.add_ops(net)

    def add_train_ops(self, net):
        """Add sampled training ops to ``net``.

        Steps:

        1. For every parameter blob of the wrapped layer, gather the entries
           selected by ``input_record.indices`` into the matching
           ``*_sampled`` blob.
        2. Add the wrapped layer's train ops.
        3. If ``subtract_log_odd`` is set, compute ``log(sampling_prob)`` and
           subtract it from the output blob in place.
        """
        for full_blob, sampled_blob in zip(
            self._prediction_layer.param_blobs,
            self._prediction_layer.train_param_blobs
        ):
            net.Gather([full_blob, self.input_record.indices()], sampled_blob)

        self._prediction_layer.add_train_ops(net)

        if not self.subtract_log_odd:
            return

        # NOTE(review): subtracting log q looks like the usual
        # sampled-softmax correction for non-uniform sampling; confirm.
        log_q = net.Log(self.input_record.sampling_prob(),
                        net.NextScopedBlob("log_q"))
        # In-place: the output blob is both the minuend and the destination.
        # broadcast=1 lets log_q broadcast against the output; the meaning of
        # use_grad_hack is defined by the Sub op -- TODO confirm.
        net.Sub([self.output_schema(), log_q], self.output_schema(),
                broadcast=1, use_grad_hack=1)
def is_schema_subset(schema, original_schema)
Definition: schema.py:985