Caffe2 - Python API
A deep learning, cross platform ML framework
crf.py
1 
3 from __future__ import absolute_import
4 from __future__ import division
5 from __future__ import print_function
6 from __future__ import unicode_literals
7 from caffe2.python import core, recurrent
8 from caffe2.python.cnn import CNNModelHelper
9 import numpy as np
10 
11 '''
12 Due to a limitation in RecurrentNetworkOp, this layer only supports batch_size=1.
13 In order to support batch_size > 1, we will have to implement the CRFUnit
14 and its gradient in C++ and handle the different batches there.
15 '''
16 
17 
class CRFWithLoss(object):
    '''
    Linear-chain CRF loss layer built on top of a Caffe2 model helper.

    The ``num_classes`` user tags are extended with two synthetic tags:
    index ``num_classes`` is beginning-of-sequence (BOS) and index
    ``num_classes + 1`` is end-of-sequence (EOS). Entry ``[i, j]`` of the
    transitions blob is the score of tag ``j`` directly following tag ``i``.

    Only batch_size=1 is supported (see the note at the top of the module).
    '''

    def __init__(self, model, num_classes, transitions_blob=None):
        '''
        model: model helper; operators are added to its ``net`` and
            ``param_init_net`` and the transitions blob is appended to its
            ``params`` list.
        num_classes: number of user tags (BOS and EOS are added on top).
        transitions_blob: optional existing transitions blob. When None, a
            ``crf_transitions`` blob of shape
            (num_classes + 2, num_classes + 2) is created and filled
            uniformly in [-1.0, 1.0]. Either way it is registered as a param.
        '''
        self.model = model
        self.num_classes = num_classes
        self.num_classes_padded = num_classes + 2  # After adding BOS and EOS
        if transitions_blob is None:
            transitions_blob = self.model.param_init_net.UniformFill(
                [],
                [core.ScopedBlobReference('crf_transitions')],
                shape=[self.num_classes_padded, self.num_classes_padded],
                min=-1.0,
                max=1.0
            )
        self.transitions = transitions_blob
        self.model.params.append(self.transitions)

    def crf_loss(self, predictions, labels, seq_lengths=None):
        '''
        Add the CRF loss for a single sequence to ``self.model.net``.

        predictions: unary scores of shape (T, num_classes). The padding in
            ``_pad_predictions`` requires exactly num_classes columns.
        labels: gold tag index for each of the T timesteps.
        seq_lengths: passed on to ``_path_binary_scores``, which ignores it.

        Returns the ``crf_loss`` blob, computed as the log-sum-exp score of
        all paths minus the score of the gold path.
        '''
        # Since the transitions matrix is a shared parameter, take a snapshot
        # of it up front: it can be updated in between the operators that
        # use it when doing parallel updates.
        transitions_snapshot = self.model.net.Copy(
            self.transitions, core.ScopedBlobReference('transitions_snapshot')
        )
        # Gold-path unary score from the (unpadded) logits
        path_unary_score = self._gather_entries_sum(
            predictions, labels, self.num_classes
        )
        # Append BOS and EOS entries to the predictions and labels
        predictions = self._pad_predictions(predictions)
        labels = self._pad_labels(labels)
        # Gold-path binary (transition) score from the transitions matrix
        path_binary_score = self._path_binary_scores(
            labels, transitions_snapshot, seq_lengths
        )
        path_total_score = self.model.net.Add(
            [path_binary_score, path_unary_score],
            core.ScopedBlobReference('path_total')
        )
        # All-paths score: the padded BOS row (row 0) seeds the recurrence,
        # the remaining rows are fed one per timestep.
        zero_index = self.model.param_init_net.ConstantFill(
            [], shape=[1], value=0
        )
        initial_state = self.model.net.Gather(
            [predictions, zero_index],
            core.ScopedBlobReference('rnn_initial'),
            dense_gradient=True
        )
        input_data, _ = self.model.net.RemovePadding(
            [predictions],
            padding_width=1,
            end_padding_width=0,
            outputs=2,
        )
        input_data = self.model.net.ExpandDims(
            [input_data],
            core.ScopedBlobReference('rnn_input_data'),
            dims=[1]
        )
        # Due to a bug in RecurrentNetworkGradientOp, we need to copy the
        # transitions blob before sending it to the recurrent network
        transitions_copy = self.model.net.Copy(
            transitions_snapshot, core.ScopedBlobReference('transitions_copy')
        )
        all_paths_scores = self._crf_forward(
            input_data, initial_state, transitions_copy
        )
        loss = self.model.net.Sub(
            [all_paths_scores, path_total_score],
            core.ScopedBlobReference('crf_loss')
        )
        return loss

    def _pad_predictions(self, predictions):
        '''
        Extend a (T, num_classes) score blob with BOS/EOS tags and timesteps.

        Two columns filled with a very low score are appended (one per new
        tag), then a BOS row (0 on the BOS tag, low elsewhere) is prepended
        and an EOS row (0 on the EOS tag, low elsewhere) is appended. The
        result has shape (T + 2, num_classes + 2).
        '''
        low_score = -1000.0  # An arbitrary very low number
        b_scores = np.array(
            [[low_score] * self.num_classes + [0, low_score]]
        ).astype(np.float32)

        e_scores = np.array(
            [[low_score] * self.num_classes + [low_score, 0]]
        ).astype(np.float32)

        b_scores = self.model.param_init_net.GivenTensorFill(
            [], "b_scores", shape=[1, self.num_classes_padded], values=b_scores
        )
        e_scores = self.model.param_init_net.GivenTensorFill(
            [], "e_scores", shape=[1, self.num_classes_padded], values=e_scores
        )

        zero_index = self.model.param_init_net.ConstantFill(
            [], shape=[1, ], value=0
        )
        # Number of rows (timesteps) of predictions, used to size the
        # low-score padding column.
        length = self.model.net.Gather(
            [self.model.net.Shape([predictions]), zero_index],
        )
        length = self.model.net.Cast(length, to='int32')
        t_range = self.model.net.LengthsRangeFill(length)
        padding = self.model.net.ConstantFill([t_range], value=low_score)
        padding = self.model.net.ExpandDims(padding, dims=[1])
        padded_predictions, _ = self.model.net.Concat(
            [predictions, padding, padding],
            outputs=2,
            axis=1
        )
        padded_predictions_concat, _ = self.model.net.Concat(
            [b_scores, padded_predictions, e_scores],
            outputs=2,
            axis=0
        )
        return padded_predictions_concat

    def _pad_labels(self, labels):
        '''Return int64 labels with the BOS index prepended and EOS appended.'''
        bos_i = self.num_classes
        eos_i = self.num_classes + 1
        bos_i_b = self.model.param_init_net.ConstantFill(
            [], shape=[1], value=bos_i
        )
        eos_i_b = self.model.param_init_net.ConstantFill(
            [], shape=[1], value=eos_i
        )
        labels = self.model.net.Cast([labels], to='int64')
        padded_labels, _ = self.model.net.Concat(
            [bos_i_b, labels, eos_i_b],
            axis=0,
            outputs=2
        )
        return padded_labels

    def _path_binary_scores(self, labels, transitions, seq_lengths=None):
        '''
        Sum transitions[labels[t], labels[t + 1]] over consecutive labels.

        labels: padded label blob (see ``_pad_labels``).
        transitions: transitions blob to read scores from.
        seq_lengths: unused.
        '''
        # Next labels (first entry dropped) and previous labels (last dropped)
        column_ids, _ = self.model.net.RemovePadding(
            [labels],
            outputs=2,
            padding_width=1,
            end_padding_width=0
        )
        row_ids, _ = self.model.net.RemovePadding(
            [labels],
            outputs=2,
            padding_width=0,
            end_padding_width=1
        )
        # Since there is no multi-dimensional gather, flatten the matrix to
        # a 1-d vector, transform the ids to
        # (row_ids * num_columns + column_ids) and gather in 1-d.
        num_columns_blob = self.model.net.ConstantFill(
            [row_ids],
            value=self.num_classes_padded,
        )
        flattened_ids = self.model.net.Mul([row_ids, num_columns_blob])
        flattened_ids = self.model.net.Add([flattened_ids, column_ids])
        flattened_transitions = self.model.net.FlattenToVec([transitions])
        entries = self.model.net.Gather(
            [flattened_transitions, flattened_ids],
            dense_gradient=True
        )
        return self.model.ReduceFrontSum(entries)

    def _gather_entries_sum(self, in_data, indices, index_size):
        '''
        Return a blob holding sum over t of in_data[t, indices[t]].

        The selection is done with a one-hot mask of width ``index_size``
        dotted with the flattened ``in_data``.
        '''
        indices = self.model.net.Cast([indices], to='int64')
        index_size_blob = self.model.param_init_net.ConstantFill(
            [],
            shape=[1],
            value=index_size,
        )
        query_one_hot = self.model.net.OneHot(
            [indices, index_size_blob]
        )
        flattened_query = self.model.net.FlattenToVec(query_one_hot)
        flattened_data = self.model.net.FlattenToVec(in_data)
        query_scores = self.model.net.DotProduct(
            [flattened_query, flattened_data]
        )
        final_sum = self.model.net.ReduceFrontSum([query_scores])
        return final_sum

    def _crf_forward(
        self,
        input_blob,
        initial_state,
        transitions_copy,
        seq_lengths=None
    ):
        '''
        Run the CRF recurrence and return the all-paths score as a scalar.

        seq_lengths is unused.
        '''
        # Build the RNN net and get the last timestep output
        out_last = self.build_crf_net(
            input_blob, initial_state, transitions_copy
        )
        out_last, _ = self.model.net.Reshape(
            [out_last],
            outputs=2,
            shape=(self.num_classes_padded,)
        )
        zero_segment_id = self.model.param_init_net.ConstantFill(
            [],
            value=0,
            shape=[self.num_classes_padded],
            dtype=core.DataType.INT32,
        )

        # Compute the accumulated total score of all the paths: every entry
        # shares segment 0, so this is a log-sum-exp over the whole vector.
        accum_score = self.model.net.SortedSegmentRangeLogSumExp(
            [out_last, zero_segment_id]
        )
        accum_score, _ = self.model.net.Reshape(
            accum_score,
            outputs=2,
            shape=()
        )
        return accum_score

    def build_crf_net(self, input_blob, initial_state, transitions):
        '''
        Adds the crf_net recurrent operator to the model.

        input_blob: the input sequence in a format T x N x D
            where T is sequence size, N - batch size and D - input dimension
            ##Only supports batch-size 1##
        initial_state: initial recurrent state (crf_loss passes the padded
            BOS row of the predictions).
        transitions: transitions blob added to every step's scores.

        Returns the recurrent network's last-timestep output blob.
        '''

        scope = 'crf_net'

        def s(name):
            # We have to manually scope due to our internal/external blob
            # relationships.
            return "{}/{}".format(str(scope), str(name))

        step_model = CNNModelHelper(name='crf_step', param_model=self.model)
        input_t, cell_t_prev, _ = (
            step_model.net.AddExternalInputs(
                'input_t', 'cell_t_prev', transitions
            )
        )
        zero_segment_id = step_model.param_init_net.ConstantFill(
            [],
            [s('zero_segment_id')],
            value=0,
            shape=[self.num_classes_padded],
            dtype=core.DataType.INT32,
        )

        # A hack to bypass model cloning for test
        step_model.param_init_net.AddExternalOutput(zero_segment_id)

        # The CRF step: build the (prev_tag, next_tag) score matrix
        # prev + input + transitions, then log-sum-exp over prev_tag.
        prev_transpose = step_model.Transpose(
            cell_t_prev,
            [s('prev_transpose')],
            axes=(0, 2, 1),
        )
        prev_tiled = step_model.net.Tile(
            prev_transpose,
            [s('prev_tiled')],
            tiles=self.num_classes_padded,
            axis=2,
        )
        input_t_tiled = step_model.net.Tile(
            input_t,
            [s('input_t_tiled')],
            tiles=self.num_classes_padded,
            axis=1,
        )
        input_with_prev = step_model.net.Add(
            [prev_tiled, input_t_tiled],
            [s('input_with_prev')]
        )
        all_with_transitions = step_model.net.Add(
            [input_with_prev, transitions],
            [s('prev_with_transitions')],
            broadcast=1,
            use_grad_hack=1,
        )
        all_with_transitions_reshaped, _ = step_model.net.Reshape(
            all_with_transitions,
            [s('all_with_transitions_reshaped'), s('all_with_transitions_orig')],
            shape=(self.num_classes_padded, self.num_classes_padded)
        )
        cell_t = step_model.net.SortedSegmentRangeLogSumExp(
            [all_with_transitions_reshaped, zero_segment_id],
            [s('cell_t')],
        )
        step_model.net.AddExternalOutputs(cell_t)

        # The recurrent network itself
        cell_input_blob = initial_state
        out_all, out_last = recurrent.recurrent_net(
            net=self.model.net,
            cell_net=step_model.net,
            inputs=[(input_t, input_blob)],
            initial_cell_inputs=[
                (cell_t_prev, cell_input_blob),
            ],
            links={
                cell_t_prev: cell_t,
            },
            scope=scope,
            outputs_with_grads=(1,)
        )
        return out_last

    @staticmethod
    def _viterbi_decode(predictions, transitions):
        '''
        Return the highest-scoring tag sequence as a list of tag indices.

        predictions: (T, C) array of per-timestep tag scores.
        transitions: (C, C) array; transitions[i, j] scores tag j after i.
        '''
        trellis = np.zeros(predictions.shape)
        backpointers = np.zeros(predictions.shape, dtype=np.int32)
        trellis[0] = predictions[0]

        for t in range(1, predictions.shape[0]):
            v = np.expand_dims(trellis[t - 1], 1) + transitions
            trellis[t] = predictions[t] + np.max(v, 0)
            backpointers[t] = np.argmax(v, 0)

        viterbi = [np.argmax(trellis[-1])]
        for bp in reversed(backpointers[1:]):
            viterbi.append(bp[viterbi[-1]])
        viterbi.reverse()
        return viterbi

    @staticmethod
    def _viterbi_predictions(predictions, transitions):
        '''
        Re-rank padded scores so the Viterbi tag is the top tag at each step.

        At every timestep the score of the current arg-max tag is swapped
        with the score of the tag on the Viterbi path. The BOS/EOS rows and
        columns are then removed. ``predictions`` is never modified; a
        float64 array of shape (T - 2, C - 2) is returned.
        '''
        viterbi = CRFWithLoss._viterbi_decode(predictions, transitions)
        # Work on a copy: the caller's array may be a view of an operator
        # input tensor, which must not be mutated.
        new_predictions = np.array(predictions, dtype=np.float64)
        for i, best_tag in enumerate(viterbi):
            row = new_predictions[i]
            old_best = np.argmax(row)
            row[best_tag], row[old_best] = row[old_best], row[best_tag]
        # Remove the BOS and EOS entries from the predictions matrix
        return new_predictions[1:-1, 0:-2]

    def update_predictions(self, classes):
        '''
        Add a Python op producing Viterbi-consistent class scores.

        classes: (T, num_classes) score blob; it is padded with BOS/EOS
            before decoding.

        Returns the ``post_crf_classes`` blob, in which the tag on the best
        CRF path has the highest score at every timestep.
        '''

        def crf_update_predictions_op(inputs, outputs):
            # inputs[0]: padded scores, inputs[1]: transitions
            orig_predictions = self._viterbi_predictions(
                inputs[0].data, inputs[1].data
            )
            outputs[0].reshape(orig_predictions.shape)
            outputs[0].data[...] = orig_predictions

        padded_classes = self._pad_predictions(classes)
        new_classes = self.model.net.Python(crf_update_predictions_op)(
            [padded_classes, self.transitions],
            core.ScopedBlobReference('post_crf_classes')
        )
        return new_classes
num_classes_padded
Definition: crf.py:22
def _path_binary_scores(self, labels, transitions, seq_lengths=None)
Definition: crf.py:150
def build_crf_net(self, input_blob, initial_state, transitions)
Definition: crf.py:231
def _crf_forward(self, input_blob, initial_state, transitions_copy, seq_lengths=None)
Definition: crf.py:203
def _pad_labels(self, labels)
Definition: crf.py:133
def ScopedBlobReference(name, args, kwargs)
Definition: core.py:212
def _pad_predictions(self, predictions)
Definition: crf.py:90
def _gather_entries_sum(self, in_data, indices, index_size)
Definition: crf.py:179
def recurrent_net(net, cell_net, inputs, initial_cell_inputs, links, timestep=None, scope=None, outputs_with_grads=(0,), recompute_blobs_on_backward=None)
Definition: recurrent.py:18