from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import core
from caffe2.python.layers.layers import InstantiationContext
from caffe2.python.layers.tags import Tags


13 def _filter_layers(layers, include_tags):
14 if include_tags
is None:
16 include_tags = set(include_tags)
17 return filter(
lambda l:
not include_tags.isdisjoint(l.tags), layers)
def generate_predict_net(model, include_tags=None):
    """Build the prediction net from ``model``'s layers.

    Every selected layer that is not tagged ``Tags.TRAIN_ONLY`` adds its
    operators to the net using ``InstantiationContext.PREDICTION``.

    Args:
        model: Layer model exposing a ``layers`` sequence.
        include_tags: Optional iterable of tags; when given, only layers
            carrying at least one of these tags are considered.

    Returns:
        The populated ``core.Net`` named ``'predict_net'``.
    """
    predict_net = core.Net('predict_net')

    for layer in _filter_layers(model.layers, include_tags):
        # Training-only layers have no place in the prediction graph.
        if Tags.TRAIN_ONLY not in layer.tags:
            layer.add_operators(
                predict_net, context=InstantiationContext.PREDICTION)
    return predict_net


def generate_eval_net(model, include_tags=None):
    """Build the evaluation net from ``model``'s layers.

    Unlike ``generate_predict_net``, this does not skip ``Tags.TRAIN_ONLY``
    layers: every selected layer adds its operators using
    ``InstantiationContext.EVAL``.

    Args:
        model: Layer model exposing ``layers``, ``input_feature_schema``,
            ``trainer_extra_schema``, ``output_schema`` and
            ``metrics_schema``.
        include_tags: Optional iterable of tags; when given, only layers
            carrying at least one of these tags are used.

    Returns:
        The populated ``core.Net`` named ``'eval_net'``, with its input and
        output records set.
    """
    eval_net = core.Net('eval_net')

    for layer in _filter_layers(model.layers, include_tags):
        layer.add_operators(eval_net, context=InstantiationContext.EVAL)

    # Inputs are the features plus trainer-only extras; outputs are the model
    # outputs plus metrics. Schemas are combined with ``+``.
    input_schema = model.input_feature_schema + model.trainer_extra_schema
    output_schema = model.output_schema + model.metrics_schema
    eval_net.set_input_record(input_schema)
    eval_net.set_output_record(output_schema)
    return eval_net


def _generate_training_net_only(model, include_tags=None):
    """Build the forward training net and its init net (no gradients).

    Args:
        model: Layer model exposing ``layers``, ``create_init_net``,
            ``input_feature_schema``, ``trainer_extra_schema``,
            ``output_schema`` and ``metrics_schema``.
        include_tags: Optional iterable of tags; when given, only layers
            carrying at least one of these tags are used.

    Returns:
        A ``(train_init_net, train_net)`` tuple. ``train_net`` is a
        ``core.Net`` named ``'train_net'`` with its input and output records
        set.
    """
    train_net = core.Net('train_net')
    train_init_net = model.create_init_net('train_init_net')

    for layer in _filter_layers(model.layers, include_tags):
        # The init net is passed as the second positional argument; what each
        # layer adds to it is defined by the layer's ``add_operators``.
        layer.add_operators(train_net, train_init_net)

    input_schema = model.input_feature_schema + model.trainer_extra_schema
    output_schema = model.output_schema + model.metrics_schema
    train_net.set_input_record(input_schema)
    train_net.set_output_record(output_schema)
    return train_init_net, train_net


def generate_training_nets_forward_only(model, include_tags=None):
    """Return the forward-only training nets for ``model``.

    No gradient or optimizer operators are added.

    Args:
        model: Layer model to instantiate.
        include_tags: Optional iterable of tags restricting which layers are
            used; ``None`` uses every layer.

    Returns:
        A ``(train_init_net, train_net)`` tuple.
    """
    nets = _generate_training_net_only(model, include_tags)
    init_net, forward_net = nets
    return init_net, forward_net


def generate_training_nets(model, include_tags=None):
    """Build the full training nets: forward pass, gradients and optimizers.

    Args:
        model: Layer model to instantiate; must also provide the loss record
            and ``apply_optimizers``.
        include_tags: Optional iterable of tags restricting which layers are
            used; ``None`` uses every layer.

    Returns:
        A ``(train_init_net, train_net)`` tuple. ``train_net`` includes the
        gradient operators for the loss, and optimizer operators are applied
        via ``model.apply_optimizers``.
    """
    train_init_net, train_net = _generate_training_net_only(model, include_tags)

    # NOTE(review): ``loss`` was referenced here without being defined.
    # ``model.loss`` is assumed to be the model's loss record exposing
    # ``field_blobs()`` -- confirm against the model class.
    loss = model.loss
    grad_map = train_net.AddGradientOperators(loss.field_blobs())
    model.apply_optimizers(train_net, train_init_net, grad_map)
    return train_init_net, train_net