Caffe2 - Python API
A deep learning, cross-platform ML framework
batch_sigmoid_cross_entropy_loss.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import schema
9 from caffe2.python.layers.layers import ModelLayer
10 from caffe2.python.layers.tags import Tags
11 import numpy as np
12 
13 
class BatchSigmoidCrossEntropyLoss(ModelLayer):
    """Training-only layer producing an averaged sigmoid cross-entropy loss.

    ``input_record.prediction`` (passed as the logits) and
    ``input_record.label`` are fed to ``SigmoidCrossEntropyWithLogits``; the
    result is reduced with ``AveragedLoss`` into a float32 scalar blob that
    is exposed as ``self.output_schema``.
    """

    def __init__(
        self,
        model,
        input_record,
        name='batch_sigmoid_cross_entropy_loss',
        **kwargs
    ):
        """Validate the input record and declare the scalar loss output.

        Args:
            model: layer model owning this layer; ``model.net`` is used to
                create the scoped ``<name>_loss`` output blob name.
            input_record: schema record that must contain float32 scalar
                fields ``label`` and ``prediction`` with identical shapes.
            name: layer name; also the prefix of the output blob name.
            **kwargs: forwarded unchanged to ``ModelLayer.__init__``.

        Raises:
            AssertionError: if ``input_record`` lacks the required fields or
                the ``prediction`` and ``label`` shapes differ.
        """
        super(BatchSigmoidCrossEntropyLoss, self).__init__(
            model, name, input_record, **kwargs)

        assert schema.is_schema_subset(
            schema.Struct(
                ('label', schema.Scalar(np.float32)),
                ('prediction', schema.Scalar(np.float32)),
            ),
            input_record
        ), "input_record must contain float32 'label' and 'prediction' fields"
        # The logits and targets are paired positionally by the loss op,
        # so their declared shapes must agree.
        assert input_record.prediction.field_type().shape == \
            input_record.label.field_type().shape, \
            "prediction and label must have the same shape"

        # A loss is only needed while training, so tag the layer TRAIN_ONLY.
        self.tags.update({Tags.TRAIN_ONLY})

        # Rank-0 (shape ``()``) float32 blob that will hold the averaged loss.
        self.output_schema = schema.Scalar(
            (np.float32, tuple()), model.net.NextScopedBlob(name + '_loss')
        )

    def add_ops(self, net):
        """Append the loss computation ops to ``net``.

        Args:
            net: net to which the cross-entropy and averaging ops are added.
        """
        # Intermediate blob holding the (presumably per-example) sigmoid
        # cross-entropy values -- confirm against the op's documentation.
        sigmoid_cross_entropy = net.SigmoidCrossEntropyWithLogits(
            [self.input_record.prediction(), self.input_record.label()],
            net.NextScopedBlob('sigmoid_cross_entropy')
        )

        # Reduce to a single averaged value written into the output blob.
        net.AveragedLoss(
            sigmoid_cross_entropy, self.output_schema.field_blobs())
def is_schema_subset(schema, original_schema)
Definition: schema.py:985