3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
8 from caffe2.python
import scope, model_helpers
9 from caffe2.python.model_helper
import ModelHelperBase
10 from caffe2.proto
import caffe2_pb2
14 """A helper model so we can write CNN models more easily, without having to 15 manually define parameter initializations and operators separately. 18 def __init__(self, order="NCHW", name=None,
19 use_cudnn=True, cudnn_exhaustive_search=False,
20 ws_nbytes_limit=None, init_params=True,
21 skip_sparse_optim=False,
24 super(CNNModelHelper, self).__init__(
25 skip_sparse_optim=skip_sparse_optim,
26 name=
"CNN" if name
is None else name,
27 init_params=init_params,
28 param_model=param_model,
35 if self.
order !=
"NHWC" and self.
order !=
"NCHW":
37 "Cannot understand the CNN storage order %s." % self.
order 40 def GetWeights(self, namescope=None):
45 return self.weights[:]
47 return [w
for w
in self.weights
if w.GetNameScope() == namescope]
49 def GetBiases(self, namescope=None):
56 return [b
for b
in self.biases
if b.GetNameScope() == namescope]
59 self, blob_in, blob_out, use_gpu_transform=False, **kwargs
62 if self.
order ==
"NCHW":
63 if (use_gpu_transform):
64 kwargs[
'use_gpu_transform'] = 1
if use_gpu_transform
else 0
67 blob_in, [blob_out[0], blob_out[1]], **kwargs)
72 blob_in, [blob_out[0] +
'_nhwc', blob_out[1]], **kwargs)
73 data = self.net.NHWC2NCHW(data, blob_out[0])
76 blob_in, blob_out, **kwargs)
def PadImage(self, blob_in, blob_out, **kwargs):
    """Add a PadImage operator to this model's net.

    Args:
        blob_in: input blob, passed straight to ``self.net.PadImage``.
        blob_out: output blob, passed straight to ``self.net.PadImage``.
        **kwargs: extra operator arguments, forwarded unchanged.

    Returns:
        Whatever ``self.net.PadImage`` returns. Earlier versions dropped
        this value and returned ``None``; it is returned now so the call
        composes like the sibling helper methods, which all return their
        delegate's result.
    """
    return self.net.PadImage(blob_in, blob_out, **kwargs)
82 def ConvNd(self, *args, **kwargs):
83 return model_helpers.ConvNd(self, *args, use_cudnn=self.
use_cudnn,
89 def Conv(self, *args, **kwargs):
90 return model_helpers.Conv(self, *args, use_cudnn=self.
use_cudnn,
96 def ConvTranspose(self, *args, **kwargs):
97 return model_helpers.ConvTranspose(self, *args, use_cudnn=self.
use_cudnn,
103 def GroupConv(self, *args, **kwargs):
104 return model_helpers.GroupConv(self, *args, use_cudnn=self.
use_cudnn,
110 def GroupConv_Deprecated(self, *args, **kwargs):
111 return model_helpers.GroupConv_Deprecated(self, *args, use_cudnn=self.
use_cudnn,
def FC(self, *args, **kwargs):
    """Delegate to ``model_helpers.FC`` with this model as first argument.

    Args:
        *args: positional arguments forwarded unchanged.
        **kwargs: keyword arguments forwarded unchanged.

    Returns:
        The value returned by ``model_helpers.FC``.
    """
    return model_helpers.FC(self, *args, **kwargs)
def PackedFC(self, *args, **kwargs):
    """Delegate to ``model_helpers.PackedFC`` with this model as first argument.

    Args:
        *args: positional arguments forwarded unchanged.
        **kwargs: keyword arguments forwarded unchanged.

    Returns:
        The value returned by ``model_helpers.PackedFC``.
    """
    return model_helpers.PackedFC(self, *args, **kwargs)
def FC_Prune(self, *args, **kwargs):
    """Delegate to ``model_helpers.FC_Prune`` with this model as first argument.

    Args:
        *args: positional arguments forwarded unchanged.
        **kwargs: keyword arguments forwarded unchanged.

    Returns:
        The value returned by ``model_helpers.FC_Prune``.
    """
    return model_helpers.FC_Prune(self, *args, **kwargs)
def FC_Decomp(self, *args, **kwargs):
    """Delegate to ``model_helpers.FC_Decomp`` with this model as first argument.

    Args:
        *args: positional arguments forwarded unchanged.
        **kwargs: keyword arguments forwarded unchanged.

    Returns:
        The value returned by ``model_helpers.FC_Decomp``.
    """
    return model_helpers.FC_Decomp(self, *args, **kwargs)
def FC_Sparse(self, *args, **kwargs):
    """Delegate to ``model_helpers.FC_Sparse`` with this model as first argument.

    Args:
        *args: positional arguments forwarded unchanged.
        **kwargs: keyword arguments forwarded unchanged.

    Returns:
        The value returned by ``model_helpers.FC_Sparse``.
    """
    return model_helpers.FC_Sparse(self, *args, **kwargs)
def Dropout(self, *args, **kwargs):
    """Delegate to ``model_helpers.Dropout`` with this model as first argument.

    Args:
        *args: positional arguments forwarded unchanged.
        **kwargs: keyword arguments forwarded unchanged.

    Returns:
        The value returned by ``model_helpers.Dropout``.
    """
    return model_helpers.Dropout(self, *args, **kwargs)
def LRN(self, *args, **kwargs):
    """Delegate to ``model_helpers.LRN`` with this model as first argument.

    Args:
        *args: positional arguments forwarded unchanged.
        **kwargs: keyword arguments forwarded unchanged.

    Returns:
        The value returned by ``model_helpers.LRN``.
    """
    return model_helpers.LRN(self, *args, **kwargs)
138 def Softmax(self, *args, **kwargs):
139 return model_helpers.Softmax(self, *args, use_cudnn=self.
use_cudnn,
def SpatialBN(self, *args, **kwargs):
    """Delegate to ``model_helpers.SpatialBN`` using this model's ``order``.

    Args:
        *args: positional arguments forwarded unchanged.
        **kwargs: keyword arguments forwarded unchanged. Must not contain
            ``order``: it is always supplied from ``self.order``, so passing
            it again raises ``TypeError`` (duplicate keyword argument).

    Returns:
        The value returned by ``model_helpers.SpatialBN``.
    """
    return model_helpers.SpatialBN(self, *args, order=self.order, **kwargs)
145 def InstanceNorm(self, *args, **kwargs):
146 return model_helpers.InstanceNorm(self, *args, order=self.
order,
149 def Relu(self, *args, **kwargs):
150 return model_helpers.Relu(self, *args, order=self.
order,
def PRelu(self, *args, **kwargs):
    """Delegate to ``model_helpers.PRelu`` with this model as first argument.

    Args:
        *args: positional arguments forwarded unchanged.
        **kwargs: keyword arguments forwarded unchanged.

    Returns:
        The value returned by ``model_helpers.PRelu``.
    """
    return model_helpers.PRelu(self, *args, **kwargs)
def Concat(self, *args, **kwargs):
    """Delegate to ``model_helpers.Concat`` using this model's ``order``.

    Args:
        *args: positional arguments forwarded unchanged.
        **kwargs: keyword arguments forwarded unchanged. Must not contain
            ``order``: it is always supplied from ``self.order``, so passing
            it again raises ``TypeError`` (duplicate keyword argument).

    Returns:
        The value returned by ``model_helpers.Concat``.
    """
    return model_helpers.Concat(self, *args, order=self.order, **kwargs)
160 """The old depth concat function - we should move to use concat.""" 161 print(
"DepthConcat is deprecated. use Concat instead.")
162 return self.
Concat(*args, **kwargs)
def Sum(self, *args, **kwargs):
    """Delegate to ``model_helpers.Sum`` with this model as first argument.

    Args:
        *args: positional arguments forwarded unchanged.
        **kwargs: keyword arguments forwarded unchanged.

    Returns:
        The value returned by ``model_helpers.Sum``.
    """
    return model_helpers.Sum(self, *args, **kwargs)
167 def Transpose(self, *args, **kwargs):
168 return model_helpers.Transpose(self, *args, use_cudnn=self.
use_cudnn,
def Iter(self, *args, **kwargs):
    """Delegate to ``model_helpers.Iter`` with this model as first argument.

    Args:
        *args: positional arguments forwarded unchanged.
        **kwargs: keyword arguments forwarded unchanged.

    Returns:
        The value returned by ``model_helpers.Iter``.
    """
    return model_helpers.Iter(self, *args, **kwargs)
def Accuracy(self, *args, **kwargs):
    """Delegate to ``model_helpers.Accuracy`` with this model as first argument.

    Args:
        *args: positional arguments forwarded unchanged.
        **kwargs: keyword arguments forwarded unchanged.

    Returns:
        The value returned by ``model_helpers.Accuracy``.
    """
    return model_helpers.Accuracy(self, *args, **kwargs)
def MaxPool(self, *args, **kwargs):
    """Delegate to ``model_helpers.MaxPool`` with this model's settings.

    Args:
        *args: positional arguments forwarded unchanged.
        **kwargs: keyword arguments forwarded unchanged. Must not contain
            ``use_cudnn`` or ``order``: both are always supplied from
            ``self.use_cudnn`` / ``self.order``, so passing either again
            raises ``TypeError`` (duplicate keyword argument).

    Returns:
        The value returned by ``model_helpers.MaxPool``.
    """
    return model_helpers.MaxPool(
        self,
        *args,
        use_cudnn=self.use_cudnn,
        order=self.order,
        **kwargs
    )
def AveragePool(self, *args, **kwargs):
    """Delegate to ``model_helpers.AveragePool`` with this model's settings.

    Args:
        *args: positional arguments forwarded unchanged.
        **kwargs: keyword arguments forwarded unchanged. Must not contain
            ``use_cudnn`` or ``order``: both are always supplied from
            ``self.use_cudnn`` / ``self.order``, so passing either again
            raises ``TypeError`` (duplicate keyword argument).

    Returns:
        The value returned by ``model_helpers.AveragePool``.
    """
    return model_helpers.AveragePool(
        self,
        *args,
        use_cudnn=self.use_cudnn,
        order=self.order,
        **kwargs
    )
def XavierInit(self):
    """Return the pair ``('XavierFill', {})``.

    Does not read any state from ``self``; a fresh empty dict is created on
    every call, so callers may mutate the result safely.

    NOTE(review): presumably consumed as (fill-operator name, operator
    kwargs) by parameter-initialization code -- confirm against callers.
    """
    return ('XavierFill', {})
def ConstantInit(self, value):
    """Return the pair ``('ConstantFill', {'value': value})``.

    Args:
        value: stored unchanged under the ``'value'`` key of the returned
            dict.

    Does not read any state from ``self``.

    NOTE(review): presumably consumed as (fill-operator name, operator
    kwargs) by parameter-initialization code -- confirm against callers.
    """
    return ('ConstantFill', dict(value=value))
194 return (
'MSRAFill', {})
198 return (
'ConstantFill', {})
201 """Adds a decay to weights in the model. 203 This is a form of L2 regularization. 206 weight_decay: strength of the regularization 208 if weight_decay <= 0.0:
210 wd = self.param_init_net.ConstantFill([],
'wd', shape=[1],
212 ONE = self.param_init_net.ConstantFill([],
"ONE", shape=[1], value=1.0)
215 grad = self.param_to_grad[param]
216 self.net.WeightedSum(
217 [grad, ONE, param, wd],
223 device_option = caffe2_pb2.DeviceOption()
224 device_option.device_type = caffe2_pb2.CPU
228 def GPU(self, gpu_id=0):
229 device_option = caffe2_pb2.DeviceOption()
230 device_option.device_type = caffe2_pb2.CUDA
231 device_option.cuda_gpu_id = gpu_id
def ImageInput(self, blob_in, blob_out, use_gpu_transform=False, kwargs)
def Concat(self, args, kwargs)
def AddWeightDecay(self, weight_decay)
def GetWeights(self, namescope=None)
def DepthConcat(self, args, kwargs)