3 from __future__
import absolute_import
4 from __future__
import division
5 from __future__
import print_function
6 from __future__
import unicode_literals
7 from caffe2.python
import core, recurrent
8 from caffe2.python.cnn
import CNNModelHelper
12 Due to a limitation in ReccurentNetworkOp, this layer only supports batch_size=1 13 In order to support batch_size > 1, we will have to implement the CRFUnit 14 and its gradient in C++ and handle the different batches there. 19 def __init__(self, model, num_classes, transitions_blob=None):
23 if not transitions_blob:
24 transitions_blob = self.
model.param_init_net.UniformFill(
34 def crf_loss(self, predictions, labels, seq_lengths=None):
38 transitions_snapshot = self.
model.net.Copy(
50 labels, transitions_snapshot, seq_lengths
52 path_total_score = self.
model.net.Add(
53 [path_binary_score, path_unary_score],
57 zero_index = self.
model.param_init_net.ConstantFill(
58 [], shape=[1], value=0
60 initial_state = self.
model.net.Gather(
61 [predictions, zero_index],
65 input_data, _ = self.
model.net.RemovePadding(
71 input_data = self.
model.net.ExpandDims(
78 transitions_copy = self.
model.net.Copy(
82 input_data, initial_state, transitions_copy
84 loss = self.
model.net.Sub(
85 [all_paths_scores, path_total_score],
90 def _pad_predictions(self, predictions):
104 b_scores = self.
model.param_init_net.GivenTensorFill(
107 e_scores = self.
model.param_init_net.GivenTensorFill(
111 zero_index = self.
model.param_init_net.ConstantFill(
112 [], shape=[1, ], value=0
114 length = self.
model.net.Gather(
115 [self.
model.net.Shape([predictions]), zero_index],
117 length = self.
model.net.Cast(length, to=
'int32')
118 t_range = self.
model.net.LengthsRangeFill(length)
119 padding = self.
model.net.ConstantFill([t_range], value=low_score)
120 padding = self.
model.net.ExpandDims(padding, dims=[1])
121 padded_predictions, _ = self.
model.net.Concat(
122 [predictions, padding, padding],
126 padded_predictions_concat, _ = self.
model.net.Concat(
127 [b_scores, padded_predictions, e_scores],
131 return padded_predictions_concat
133 def _pad_labels(self, labels):
136 bos_i_b = self.
model.param_init_net.ConstantFill(
137 [], shape=[1], value=bos_i
139 eos_i_b = self.
model.param_init_net.ConstantFill(
140 [], shape=[1], value=eos_i
142 labels = self.
model.net.Cast([labels], to=
'int64')
143 padded_labels, _ = self.
model.net.Concat(
144 [bos_i_b, labels, eos_i_b],
150 def _path_binary_scores(self, labels, transitions, seq_lengths=None):
151 column_ids, _ = self.
model.net.RemovePadding(
157 row_ids, _ = self.
model.net.RemovePadding(
166 num_columns_blob = self.
model.net.ConstantFill(
170 flattened_ids = self.
model.net.Mul([row_ids, num_columns_blob])
171 flattened_ids = self.
model.net.Add([flattened_ids, column_ids])
172 flattened_transitions = self.
model.net.FlattenToVec([transitions])
173 entries = self.
model.net.Gather(
174 [flattened_transitions, flattened_ids],
177 return self.
model.ReduceFrontSum(entries)
179 def _gather_entries_sum(self, in_data, indices, index_size):
180 indices = self.
model.net.Cast([indices], to=
'int64')
181 index_size_blob = self.
model.param_init_net.ConstantFill(
186 query_one_hot = self.
model.net.OneHot(
187 [indices, index_size_blob]
189 flattend_query = self.
model.net.FlattenToVec(query_one_hot)
190 flattend_data = self.
model.net.FlattenToVec(in_data)
191 query_scores = self.
model.net.DotProduct(
192 [flattend_query, flattend_data]
194 final_sum = self.
model.net.ReduceFrontSum([query_scores])
206 input_blob, initial_state, transitions_copy
208 out_last, _ = self.
model.net.Reshape(
213 zero_segment_id = self.
model.param_init_net.ConstantFill(
217 dtype=core.DataType.INT32,
221 accum_score = self.
model.net.SortedSegmentRangeLogSumExp(
222 [out_last, zero_segment_id]
224 accum_score, _ = self.
model.net.Reshape(
233 Adds the crf_net recurrent operator to the model. 235 model: CNNModelHelper object new operators would be added to 237 input_blob: the input sequence in a format T x N x D 238 where T is sequence size, N - batch size and D - input dimention 239 ##Only supports batch-size 1## 241 seq_lengths: blob containing sequence lengths (unused) 250 return "{}/{}".format(str(scope), str(name))
252 step_model = CNNModelHelper(name=
'crf_step', param_model=self.
model)
253 input_t, cell_t_prev, _ = (
254 step_model.net.AddExternalInputs(
255 'input_t',
'cell_t_prev', transitions
258 zero_segment_id = step_model.param_init_net.ConstantFill(
260 [s(
'zero_segment_id')],
263 dtype=core.DataType.INT32,
267 step_model.param_init_net.AddExternalOutput(zero_segment_id)
270 prev_transpose = step_model.Transpose(
272 [s(
'prev_transpose')],
275 prev_tiled = step_model.net.Tile(
281 input_t_tiled = step_model.net.Tile(
283 [s(
'input_t_tiled')],
287 input_with_prev = step_model.net.Add(
288 [prev_tiled, input_t_tiled],
289 [s(
'input_with_prev')]
291 all_with_transitions = step_model.net.Add(
292 [input_with_prev, transitions],
293 [s(
'prev_with_transitions')],
297 all_with_transitions_reshaped, _ = step_model.net.Reshape(
298 all_with_transitions,
299 [s(
'all_with_transitions_reshaped'), s(
'all_with_transitions_orig')],
302 cell_t = step_model.net.SortedSegmentRangeLogSumExp(
303 [all_with_transitions_reshaped, zero_segment_id],
306 step_model.net.AddExternalOutputs(cell_t)
307 """ recurrent network """ 308 cell_input_blob = initial_state
311 cell_net=step_model.net,
312 inputs=[(input_t, input_blob)],
313 initial_cell_inputs=[
314 (cell_t_prev, cell_input_blob),
320 outputs_with_grads=(1,)
324 def update_predictions(self, classes):
326 def crf_update_predictions_op(inputs, outputs):
330 predictions = inputs[0].data
331 transitions = inputs[1].data
332 predictions = inputs[0].data
333 predictions_shape = inputs[0].shape
334 outputs[0].reshape(predictions_shape)
336 trellis = np.zeros(predictions_shape)
337 backpointers = np.zeros(predictions_shape, dtype=np.int32)
338 trellis[0] = predictions[0]
340 for t
in range(1, predictions_shape[0]):
341 v = np.expand_dims(trellis[t - 1], 1) + transitions
342 trellis[t] = predictions[t] + np.max(v, 0)
343 backpointers[t] = np.argmax(v, 0)
345 viterbi = [np.argmax(trellis[-1])]
346 for bp
in reversed(backpointers[1:]):
347 viterbi.append(bp[viterbi[-1]])
350 new_predictions = np.zeros(predictions_shape)
352 for i, w_predictions
in enumerate(predictions):
354 new_predictions[i] = predictions[i]
355 old_best = np.argmax(w_predictions)
356 old_bests.append(old_best)
359 w_predictions[viterbi[i]], w_predictions[old_best] = \
360 w_predictions[old_best], w_predictions[viterbi[i]]
361 new_predictions[i] = w_predictions
363 orig_predictions = new_predictions[1:-1, 0:-2]
364 outputs[0].reshape(orig_predictions.shape)
365 outputs[0].data[...] = orig_predictions
367 new_classes = self.
model.net.Python(crf_update_predictions_op)(
def _path_binary_scores(self, labels, transitions, seq_lengths=None)
def build_crf_net(self, input_blob, initial_state, transitions)
def _crf_forward(self, input_blob, initial_state, transitions_copy, seq_lengths=None)
def _pad_labels(self, labels)
def ScopedBlobReference(name, args, kwargs)
def _pad_predictions(self, predictions)
def _gather_entries_sum(self, in_data, indices, index_size)
def recurrent_net(net, cell_net, inputs, initial_cell_inputs, links, timestep=None, scope=None, outputs_with_grads=(0,), recompute_blobs_on_backward=None)