3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
8 from caffe2.python
import schema
9 from caffe2.python.layers.layers
import ModelLayer, get_layer_class
10 from caffe2.python.layers.sampling_trainable_mixin
import SamplingTrainableMixin
20 subtract_log_odd=True,
21 name='sampling_train',
24 super(SamplingTrain, self).__init__(
25 model, name, input_record, **kwargs
28 layer_class = get_layer_class(prediction_layer)
29 assert issubclass(layer_class, SamplingTrainableMixin)
40 assert 'sampling_prob' in input_record
45 output_dims=output_dims,
50 model.net.NextBlob(str(blob) +
'_sampled')
58 def add_ops(self, net):
61 def add_train_ops(self, net):
62 for full_blob, sampled_blob
in zip(
66 net.Gather([full_blob, self.input_record.indices()], sampled_blob)
70 log_q = net.Log(self.input_record.sampling_prob(),
71 net.NextScopedBlob(
"log_q"))
73 broadcast=1, use_grad_hack=1)
def is_schema_subset(schema, original_schema)