Caffe2 - Python API
A deep learning, cross-platform ML framework
cnn.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 
8 from caffe2.python import scope, model_helpers
9 from caffe2.python.model_helper import ModelHelperBase
10 from caffe2.proto import caffe2_pb2
11 
12 
class CNNModelHelper(ModelHelperBase):
    """A helper model so we can write CNN models more easily, without having to
    manually define parameter initializations and operators separately.

    Most methods are thin wrappers that delegate to the function of the same
    name in ``model_helpers``, passing this model as the first argument and
    filling in the model-wide settings (``order``, ``use_cudnn``, ...) that
    were chosen at construction time.
    """

    def __init__(self, order="NCHW", name=None,
                 use_cudnn=True, cudnn_exhaustive_search=False,
                 ws_nbytes_limit=None, init_params=True,
                 skip_sparse_optim=False,
                 param_model=None):
        """Create the helper and record the model-wide settings.

        Args:
            order: storage order, either "NCHW" or "NHWC".
            name: model name handed to the base class; "CNN" when None.
            use_cudnn: forwarded as ``use_cudnn`` to the wrappers that
                accept it (convolutions, pooling, Relu, Softmax, Transpose).
            cudnn_exhaustive_search: forwarded to the convolution wrappers.
            ws_nbytes_limit: forwarded to the convolution wrappers.
            init_params: passed through to ``ModelHelperBase.__init__``.
            skip_sparse_optim: passed through to
                ``ModelHelperBase.__init__``.
            param_model: passed through to ``ModelHelperBase.__init__``.

        Raises:
            ValueError: if ``order`` is neither "NHWC" nor "NCHW". The base
                class has already been initialized when this is raised.
        """
        super(CNNModelHelper, self).__init__(
            skip_sparse_optim=skip_sparse_optim,
            name="CNN" if name is None else name,
            init_params=init_params,
            param_model=param_model,
        )

        self.order = order
        self.use_cudnn = use_cudnn
        self.cudnn_exhaustive_search = cudnn_exhaustive_search
        self.ws_nbytes_limit = ws_nbytes_limit
        if self.order not in ("NHWC", "NCHW"):
            raise ValueError(
                "Cannot understand the CNN storage order %s." % self.order
            )

    def GetWeights(self, namescope=None):
        """Return the weights that belong to ``namescope``.

        Args:
            namescope: name scope to filter on; defaults to
                ``scope.CurrentNameScope()`` when None.

        Returns:
            A copy of ``self.weights`` when the name scope is the empty
            string, otherwise a new list holding only the weights whose
            ``GetNameScope()`` equals ``namescope``.
        """
        if namescope is None:
            namescope = scope.CurrentNameScope()

        if namescope == '':
            return self.weights[:]
        return [w for w in self.weights if w.GetNameScope() == namescope]

    def GetBiases(self, namescope=None):
        """Return the biases that belong to ``namescope``.

        Args:
            namescope: name scope to filter on; defaults to
                ``scope.CurrentNameScope()`` when None.

        Returns:
            A copy of ``self.biases`` when the name scope is the empty
            string, otherwise a new list holding only the biases whose
            ``GetNameScope()`` equals ``namescope``.
        """
        if namescope is None:
            namescope = scope.CurrentNameScope()

        if namescope == '':
            return self.biases[:]
        return [b for b in self.biases if b.GetNameScope() == namescope]

    def ImageInput(
        self, blob_in, blob_out, use_gpu_transform=False, **kwargs
    ):
        """Image Input.

        Adds an ``ImageInput`` operator to ``self.net``.

        Args:
            blob_in: input blob(s) handed to the operator.
            blob_out: pair of output blobs, data first and label second.
            use_gpu_transform: only consulted when the model order is
                "NCHW". When truthy, ``use_gpu_transform=1`` is passed to
                the operator and no explicit ``NHWC2NCHW`` op is added.
            **kwargs: extra arguments forwarded to the operator.

        Returns:
            The ``(data, label)`` pair.
        """
        if self.order == "NCHW":
            if use_gpu_transform:
                # GPU transform will handle NHWC -> NCHW
                kwargs['use_gpu_transform'] = 1
                data, label = self.net.ImageInput(
                    blob_in, [blob_out[0], blob_out[1]], **kwargs)
            else:
                # Write the operator output to a temporary "<name>_nhwc"
                # blob and convert it into the requested data blob.
                data, label = self.net.ImageInput(
                    blob_in, [blob_out[0] + '_nhwc', blob_out[1]], **kwargs)
                data = self.net.NHWC2NCHW(data, blob_out[0])
        else:
            data, label = self.net.ImageInput(
                blob_in, blob_out, **kwargs)
        return data, label

    def PadImage(self, blob_in, blob_out, **kwargs):
        """Add a ``PadImage`` operator to ``self.net``.

        Returns:
            Whatever ``self.net.PadImage`` returns. Earlier versions of this
            wrapper dropped that value and returned None; it is now handed
            back so the wrapper can be chained like the other operators.
        """
        return self.net.PadImage(blob_in, blob_out, **kwargs)

    def ConvNd(self, *args, **kwargs):
        """Delegate to ``model_helpers.ConvNd`` with the model's cuDNN and
        storage-order settings."""
        return model_helpers.ConvNd(
            self, *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs)

    def Conv(self, *args, **kwargs):
        """Delegate to ``model_helpers.Conv`` with the model's cuDNN and
        storage-order settings."""
        return model_helpers.Conv(
            self, *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs)

    def ConvTranspose(self, *args, **kwargs):
        """Delegate to ``model_helpers.ConvTranspose`` with the model's cuDNN
        and storage-order settings."""
        return model_helpers.ConvTranspose(
            self, *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs)

    def GroupConv(self, *args, **kwargs):
        """Delegate to ``model_helpers.GroupConv`` with the model's cuDNN and
        storage-order settings."""
        return model_helpers.GroupConv(
            self, *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs)

    def GroupConv_Deprecated(self, *args, **kwargs):
        """Delegate to ``model_helpers.GroupConv_Deprecated`` with the
        model's cuDNN and storage-order settings."""
        return model_helpers.GroupConv_Deprecated(
            self, *args,
            use_cudnn=self.use_cudnn,
            order=self.order,
            cudnn_exhaustive_search=self.cudnn_exhaustive_search,
            ws_nbytes_limit=self.ws_nbytes_limit,
            **kwargs)

    def FC(self, *args, **kwargs):
        """Delegate to ``model_helpers.FC``."""
        return model_helpers.FC(self, *args, **kwargs)

    def PackedFC(self, *args, **kwargs):
        """Delegate to ``model_helpers.PackedFC``."""
        return model_helpers.PackedFC(self, *args, **kwargs)

    def FC_Prune(self, *args, **kwargs):
        """Delegate to ``model_helpers.FC_Prune``."""
        return model_helpers.FC_Prune(self, *args, **kwargs)

    def FC_Decomp(self, *args, **kwargs):
        """Delegate to ``model_helpers.FC_Decomp``."""
        return model_helpers.FC_Decomp(self, *args, **kwargs)

    def FC_Sparse(self, *args, **kwargs):
        """Delegate to ``model_helpers.FC_Sparse``."""
        return model_helpers.FC_Sparse(self, *args, **kwargs)

    def Dropout(self, *args, **kwargs):
        """Delegate to ``model_helpers.Dropout``."""
        return model_helpers.Dropout(self, *args, **kwargs)

    def LRN(self, *args, **kwargs):
        """Delegate to ``model_helpers.LRN``."""
        return model_helpers.LRN(self, *args, **kwargs)

    def Softmax(self, *args, **kwargs):
        """Delegate to ``model_helpers.Softmax`` with the model's
        ``use_cudnn`` setting."""
        return model_helpers.Softmax(self, *args, use_cudnn=self.use_cudnn,
                                     **kwargs)

    def SpatialBN(self, *args, **kwargs):
        """Delegate to ``model_helpers.SpatialBN`` with the model's storage
        order."""
        return model_helpers.SpatialBN(self, *args, order=self.order, **kwargs)

    def InstanceNorm(self, *args, **kwargs):
        """Delegate to ``model_helpers.InstanceNorm`` with the model's
        storage order."""
        return model_helpers.InstanceNorm(self, *args, order=self.order,
                                          **kwargs)

    def Relu(self, *args, **kwargs):
        """Delegate to ``model_helpers.Relu`` with the model's storage order
        and ``use_cudnn`` setting."""
        return model_helpers.Relu(self, *args, order=self.order,
                                  use_cudnn=self.use_cudnn, **kwargs)

    def PRelu(self, *args, **kwargs):
        """Delegate to ``model_helpers.PRelu``."""
        return model_helpers.PRelu(self, *args, **kwargs)

    def Concat(self, *args, **kwargs):
        """Delegate to ``model_helpers.Concat`` with the model's storage
        order."""
        return model_helpers.Concat(self, *args, order=self.order, **kwargs)

    def DepthConcat(self, *args, **kwargs):
        """The old depth concat function - we should move to use concat."""
        print("DepthConcat is deprecated. use Concat instead.")
        return self.Concat(*args, **kwargs)

    def Sum(self, *args, **kwargs):
        """Delegate to ``model_helpers.Sum``."""
        return model_helpers.Sum(self, *args, **kwargs)

    def Transpose(self, *args, **kwargs):
        """Delegate to ``model_helpers.Transpose`` with the model's
        ``use_cudnn`` setting."""
        return model_helpers.Transpose(self, *args, use_cudnn=self.use_cudnn,
                                       **kwargs)

    def Iter(self, *args, **kwargs):
        """Delegate to ``model_helpers.Iter``."""
        return model_helpers.Iter(self, *args, **kwargs)

    def Accuracy(self, *args, **kwargs):
        """Delegate to ``model_helpers.Accuracy``."""
        return model_helpers.Accuracy(self, *args, **kwargs)

    def MaxPool(self, *args, **kwargs):
        """Delegate to ``model_helpers.MaxPool`` with the model's
        ``use_cudnn`` setting and storage order."""
        return model_helpers.MaxPool(self, *args, use_cudnn=self.use_cudnn,
                                     order=self.order, **kwargs)

    def AveragePool(self, *args, **kwargs):
        """Delegate to ``model_helpers.AveragePool`` with the model's
        ``use_cudnn`` setting and storage order."""
        return model_helpers.AveragePool(self, *args, use_cudnn=self.use_cudnn,
                                         order=self.order, **kwargs)

    @property
    def XavierInit(self):
        """Initializer spec: the ``('XavierFill', {})`` pair."""
        return ('XavierFill', {})

    def ConstantInit(self, value):
        """Initializer spec: ``'ConstantFill'`` paired with ``value``."""
        return ('ConstantFill', dict(value=value))

    @property
    def MSRAInit(self):
        """Initializer spec: the ``('MSRAFill', {})`` pair."""
        return ('MSRAFill', {})

    @property
    def ZeroInit(self):
        """Initializer spec: ``'ConstantFill'`` with no explicit value."""
        return ('ConstantFill', {})

    def AddWeightDecay(self, weight_decay):
        """Adds a decay to weights in the model.

        This is a form of L2 regularization.

        Args:
            weight_decay: strength of the regularization. Nothing is added
                when it is zero or negative.
        """
        if weight_decay <= 0.0:
            return
        wd = self.param_init_net.ConstantFill([], 'wd', shape=[1],
                                              value=weight_decay)
        ONE = self.param_init_net.ConstantFill([], "ONE", shape=[1], value=1.0)
        for param in self.GetWeights():
            # Equivalent to: grad += wd * param
            grad = self.param_to_grad[param]
            self.net.WeightedSum(
                [grad, ONE, param, wd],
                grad,
            )

    @property
    def CPU(self):
        """A new ``DeviceOption`` whose device type is ``caffe2_pb2.CPU``."""
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.CPU
        return device_option

    @property
    def GPU(self, gpu_id=0):
        """A new ``DeviceOption`` whose device type is ``caffe2_pb2.CUDA``.

        NOTE(review): because this is a property, plain attribute access
        (``model.GPU``) can never supply ``gpu_id``, so the returned option
        always targets GPU 0. The parameter is kept only so the getter's
        signature stays unchanged.
        """
        device_option = caffe2_pb2.DeviceOption()
        device_option.device_type = caffe2_pb2.CUDA
        device_option.cuda_gpu_id = gpu_id
        return device_option
def ImageInput(self, blob_in, blob_out, use_gpu_transform=False, kwargs)
Definition: cnn.py:60
def CurrentNameScope()
Definition: scope.py:26
cudnn_exhaustive_search
Definition: cnn.py:33
def Concat(self, args, kwargs)
Definition: cnn.py:156
def AddWeightDecay(self, weight_decay)
Definition: cnn.py:200
def GetWeights(self, namescope=None)
Definition: cnn.py:40
def DepthConcat(self, args, kwargs)
Definition: cnn.py:159