3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
8 from collections
import namedtuple
10 from caffe2.python
import (
12 layer_model_instantiator,
21 OpSpec = namedtuple(
"OpSpec",
"type input output")
27 super(LayersTestCase, self).setUp()
35 input_feature_schema=input_feature_schema,
36 trainer_extra_schema=trainer_extra_schema)
38 def new_record(self, schema_obj):
44 layer_model_instantiator.generate_training_nets_forward_only() 45 here because it includes initialization of global constants, which make 49 train_init_net =
core.Net(
'train_init_net')
50 for layer
in self.
model.layers:
51 layer.add_operators(train_net, train_init_net)
52 return train_init_net, train_net
54 def get_predict_net(self):
57 def run_train_net(self):
59 train_init_net, train_net = \
66 spec_blobs can either be None or a list of blob names. If it's None, 67 then no assertion is performed. The elements of the list can be None, 68 in that case, it means that position will not be checked. 70 if spec_blobs
is None:
72 self.assertEqual(len(spec_blobs), len(op_blobs))
73 for spec_blob, op_blob
in zip(spec_blobs, op_blobs):
76 self.assertEqual(spec_blob, op_blob)
80 Given a net and a list of OpSpec's, check that the net match the spec 83 self.assertEqual(len(op_specs), len(ops))
84 for op, op_spec
in zip(ops, op_specs):
85 self.assertEqual(op_spec.type, op.type)
def generate_training_nets(model, include_tags=None)
def NewRecord(net, schema)
def get_training_nets(self)
def assertBlobsEqual(self, spec_blobs, op_blobs)
def assertNetContainOps(self, net, op_specs)
def generate_predict_net(model, include_tags=None)