3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
8 from caffe2.python
import core, workspace, cnn
13 This is an abstract base class. 14 Don't inherit from unittest.TestCase, and don't name it 'Test*'. 15 Do, however, do these things in classes which inherit from this. 19 perfect_model = np.array([2, 6, 5, 0, 1]).astype(np.float32)
21 data = np.random.randint(
23 size=(20, perfect_model.size)).astype(np.float32)
24 label = np.dot(data, perfect_model)[:, np.newaxis]
28 'data',
'fc', perfect_model.size, 1, (
'ConstantFill', {}),
29 (
'ConstantFill', {}), axis=0
31 sq = model.SquaredL2Distance([out,
'label'])
32 loss = model.AveragedLoss(sq,
"avg_loss")
33 grad_map = model.AddGradientOperators([loss])
35 optimizer = self.build_optimizer(model)
42 idx = np.random.randint(data.shape[0])
47 np.testing.assert_allclose(
48 perfect_model[np.newaxis, :],
52 self.check_optimizer(optimizer)
58 perfect_model = np.array([2, 6, 5, 0, 1]).astype(np.float32)
60 data = np.random.randint(
62 size=(20, perfect_model.size * DUPLICATION)).astype(np.float32)
63 label = np.dot(data, np.repeat(perfect_model, DUPLICATION))
67 w = model.param_init_net.ConstantFill(
68 [],
'w', shape=[perfect_model.size], value=0.0)
69 model.params.append(w)
70 picked = model.net.Gather([w,
'indices'],
'gather')
71 out = model.ReduceFrontSum(picked,
'sum')
73 sq = model.SquaredL2Distance([out,
'label'])
74 loss = model.AveragedLoss(sq,
"avg_loss")
75 grad_map = model.AddGradientOperators([loss])
76 self.assertIsInstance(grad_map[
'w'], core.GradientSlice)
77 optimizer = self.build_optimizer(model)
82 for indices_type
in [np.int32, np.int64]:
86 idx = np.random.randint(data.shape[0])
88 indices = np.repeat(np.arange(perfect_model.size),
89 DUPLICATION)[data[idx] == 1]
94 indices.reshape((indices.size,)).astype(indices_type)
97 np.array(label[idx]).astype(np.float32))
100 np.testing.assert_allclose(
105 self.check_optimizer(optimizer)
def RunNet(name, num_iter=1)
def FeedBlob(name, arr, device_option=None)
def CreateNet(net, overwrite=False, input_blobs=None)