Caffe2 - Python API
A deep learning, cross-platform ML framework
rnn_cell.py
## @package rnn_cell
# Module caffe2.python.rnn_cell
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np
import random

from caffe2.python.attention import (
    AttentionType,
    apply_regular_attention,
    apply_recurrent_attention,
)
from caffe2.python import core, recurrent, workspace
from caffe2.python.cnn import CNNModelHelper
class RNNCell(object):
    '''
    Base class for writing recurrent / stateful operations.

    One needs to implement 3 methods: _apply, prepare_input and
    get_state_names. In return, the base class provides the
    apply_over_sequence method, which applies the recurrent operation over a
    sequence of any length.
    '''
    def __init__(self, name):
        self.name = name
        self.recompute_blobs = []

    def scope(self, name):
        return self.name + '/' + name if self.name is not None else name

    def apply_over_sequence(
        self,
        model,
        inputs,
        seq_lengths,
        initial_states,
        outputs_with_grads=None,
    ):
        preprocessed_inputs = self.prepare_input(model, inputs)
        step_model = CNNModelHelper(name=self.name, param_model=model)
        input_t, timestep = step_model.net.AddScopedExternalInputs(
            'input_t',
            'timestep',
        )
        states_prev = step_model.net.AddScopedExternalInputs(*[
            s + '_prev' for s in self.get_state_names()
        ])
        states = self._apply(
            model=step_model,
            input_t=input_t,
            seq_lengths=seq_lengths,
            states=states_prev,
            timestep=timestep,
        )
        return recurrent.recurrent_net(
            net=model.net,
            cell_net=step_model.net,
            inputs=[(input_t, preprocessed_inputs)],
            initial_cell_inputs=zip(states_prev, initial_states),
            links=dict(zip(states_prev, states)),
            timestep=timestep,
            scope=self.name,
            outputs_with_grads=(
                outputs_with_grads
                if outputs_with_grads is not None
                else self.get_outputs_with_grads()
            ),
            recompute_blobs_on_backward=self.recompute_blobs,
        )

    def apply(self, model, input_t, seq_lengths, states, timestep):
        input_t = self.prepare_input(model, input_t)
        return self._apply(model, input_t, seq_lengths, states, timestep)

    def _apply(self, model, input_t, seq_lengths, states, timestep):
        '''
        A single step of a recurrent network.

        model: CNNModelHelper object to which new operators are added

        input_t: single input with shape (1, batch_size, input_dim)

        seq_lengths: blob containing sequence lengths which would be passed
        to the LSTMUnit operator

        states: previous recurrent states

        timestep: current recurrent iteration. Can be used together with
        seq_lengths to determine whether shorter sequences in the batch have
        already ended.
        '''
        raise NotImplementedError('Abstract method')

    def prepare_input(self, model, input_blob):
        '''
        If some operations in the _apply method depend only on the input,
        not on the recurrent states, they can be computed in advance.

        model: CNNModelHelper object to which new operators are added

        input_blob: either the whole input sequence with shape
        (sequence_length, batch_size, input_dim) or a single input with shape
        (1, batch_size, input_dim).
        '''
        raise NotImplementedError('Abstract method')

    def get_state_names(self):
        '''
        Returns the names of the recurrent states.
        Required by the apply_over_sequence method in order to allocate
        recurrent states for all steps under meaningful names.
        '''
        raise NotImplementedError('Abstract method')

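# To make the contract above concrete, here is a minimal sketch of a custom
# cell -- a plain Elman RNN with a single hidden state. `ElmanCell` is a
# hypothetical example, not part of this module; it reuses only ops already
# used in this file and, for brevity, skips the seq_lengths masking that
# LSTMUnit performs for LSTMCell below.
class ElmanCell(RNNCell):

    def __init__(self, input_size, hidden_size, name):
        super(ElmanCell, self).__init__(name)
        self.input_size = input_size
        self.hidden_size = hidden_size

    def prepare_input(self, model, input_blob):
        # x W^T + b depends only on the input, so it can be precomputed
        # for the whole sequence at once.
        return model.FC(
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=self.hidden_size,
            axis=2,
        )

    def _apply(self, model, input_t, seq_lengths, states, timestep):
        hidden_t_prev = states[0]
        # h_prev U^T, combined with the precomputed input projection.
        h2h = model.FC(
            hidden_t_prev,
            self.scope('h2h'),
            dim_in=self.hidden_size,
            dim_out=self.hidden_size,
            axis=2,
        )
        pre_activation = model.net.Sum(
            [h2h, input_t], self.scope('pre_activation'))
        hidden_t = model.net.Tanh(pre_activation, self.scope('hidden_t'))
        model.net.AddExternalOutputs(hidden_t)
        return (hidden_t,)

    def get_state_names(self):
        return (self.scope('hidden_t'),)

    def get_outputs_with_grads(self):
        return [0]
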

class LSTMCell(RNNCell):

    def __init__(
        self,
        input_size,
        hidden_size,
        forget_bias,
        memory_optimization,
        name,
    ):
        super(LSTMCell, self).__init__(name)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.forget_bias = float(forget_bias)
        self.memory_optimization = memory_optimization

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
    ):
        hidden_t_prev, cell_t_prev = states
        gates_t = model.FC(
            hidden_t_prev,
            self.scope('gates_t'),
            dim_in=self.hidden_size,
            dim_out=4 * self.hidden_size,
            axis=2,
        )
        model.net.Sum([gates_t, input_t], gates_t)
        hidden_t, cell_t = model.net.LSTMUnit(
            [
                hidden_t_prev,
                cell_t_prev,
                gates_t,
                seq_lengths,
                timestep,
            ],
            list(self.get_state_names()),
            forget_bias=self.forget_bias,
        )
        model.net.AddExternalOutputs(hidden_t, cell_t)
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]
        return hidden_t, cell_t

    def get_input_params(self):
        return {
            'weights': self.scope('i2h') + '_w',
            'biases': self.scope('i2h') + '_b',
        }

    def get_recurrent_params(self):
        return {
            'weights': self.scope('gates_t') + '_w',
            'biases': self.scope('gates_t') + '_b',
        }

    def prepare_input(self, model, input_blob):
        return model.FC(
            input_blob,
            self.scope('i2h'),
            dim_in=self.input_size,
            dim_out=4 * self.hidden_size,
            axis=2,
        )

    def get_state_names(self):
        return (self.scope('hidden_t'), self.scope('cell_t'))

    def get_outputs_with_grads(self):
        return [0]

    def get_output_size(self):
        return self.hidden_size


def LSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
         scope, outputs_with_grads=(0,), return_params=False,
         memory_optimization=False, forget_bias=0.0):
    '''
    Adds a standard LSTM recurrent network operator to a model.

    model: CNNModelHelper object to which new operators are added

    input_blob: the input sequence in a format T x N x D
    where T is sequence size, N - batch size and D - input dimension

    seq_lengths: blob containing sequence lengths which would be passed to
    the LSTMUnit operator

    initial_states: a tuple of (hidden_input_blob, cell_input_blob)
    which are going to be inputs to the cell net on the first iteration

    dim_in: input dimension

    dim_out: output dimension

    outputs_with_grads: position indices of output blobs which will receive
    an external error gradient during backpropagation

    return_params: if True, will return a dictionary of parameters of the LSTM

    memory_optimization: if enabled, the LSTM step is recomputed on the
    backward pass, so that forward activations do not need to be stored for
    every timestep. Saves memory at the cost of extra computation.
    '''
    cell = LSTMCell(
        input_size=dim_in,
        hidden_size=dim_out,
        forget_bias=forget_bias,
        memory_optimization=memory_optimization,
        name=scope,
    )
    result = cell.apply_over_sequence(
        model=model,
        inputs=input_blob,
        seq_lengths=seq_lengths,
        initial_states=initial_states,
        outputs_with_grads=outputs_with_grads,
    )
    if return_params:
        result = list(result) + [{
            'input': cell.get_input_params(),
            'recurrent': cell.get_recurrent_params(),
        }]
    return tuple(result)


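# A hedged usage sketch for LSTM(). The blob names and sizes, and the assumed
# output layout (all-step and last-step blobs for each recurrent state), are
# illustrative assumptions, not guarantees of the API.
def _example_lstm_usage():
    T, N, D, H = 5, 2, 8, 16  # sequence length, batch, input dim, hidden dim
    model = CNNModelHelper(name='lstm_example')
    workspace.FeedBlob('input', np.random.randn(T, N, D).astype(np.float32))
    workspace.FeedBlob('seq_lengths', np.full((N,), T, dtype=np.int32))
    workspace.FeedBlob('h0', np.zeros((1, N, H), dtype=np.float32))
    workspace.FeedBlob('c0', np.zeros((1, N, H), dtype=np.float32))
    outputs = LSTM(
        model, 'input', 'seq_lengths', ('h0', 'c0'),
        dim_in=D, dim_out=H, scope='ex',
    )
    workspace.RunNetOnce(model.param_init_net)
    workspace.RunNetOnce(model.net)
    # outputs[0] is the per-step hidden-state sequence, shape (T, N, H).
    return workspace.FetchBlob(outputs[0])
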
def GetLSTMParamNames():
    weight_params = ["input_gate_w", "forget_gate_w", "output_gate_w", "cell_w"]
    bias_params = ["input_gate_b", "forget_gate_b", "output_gate_b", "cell_b"]
    return {'weights': weight_params, 'biases': bias_params}


def InitFromLSTMParams(lstm_pblobs, param_values):
    '''
    Set the parameters of an LSTM based on predefined values.
    '''
    weight_params = GetLSTMParamNames()['weights']
    bias_params = GetLSTMParamNames()['biases']
    for input_type in param_values.keys():
        weight_values = [
            param_values[input_type][w].flatten() for w in weight_params
        ]
        wmat = np.array([])
        for w in weight_values:
            wmat = np.append(wmat, w)
        bias_values = [
            param_values[input_type][b].flatten() for b in bias_params
        ]
        bm = np.array([])
        for b in bias_values:
            bm = np.append(bm, b)

        weights_blob = lstm_pblobs[input_type]['weights']
        bias_blob = lstm_pblobs[input_type]['biases']
        cur_weight = workspace.FetchBlob(weights_blob)
        cur_biases = workspace.FetchBlob(bias_blob)

        workspace.FeedBlob(
            weights_blob,
            wmat.reshape(cur_weight.shape).astype(np.float32))
        workspace.FeedBlob(
            bias_blob,
            bm.reshape(cur_biases.shape).astype(np.float32))


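# Sketch of the nested dict layout InitFromLSTMParams() consumes, assuming
# input size D and hidden size H. Each gate array is flattened, and the four
# gates are concatenated in the order returned by GetLSTMParamNames(), so the
# fused blob must have a matching total shape.
def _example_lstm_param_values(D=8, H=16):
    def gate_w(d):
        return np.random.randn(H, d).astype(np.float32)

    return {
        'input': {
            'input_gate_w': gate_w(D),
            'forget_gate_w': gate_w(D),
            'output_gate_w': gate_w(D),
            'cell_w': gate_w(D),
            'input_gate_b': np.zeros(H, dtype=np.float32),
            'forget_gate_b': np.ones(H, dtype=np.float32),  # common init
            'output_gate_b': np.zeros(H, dtype=np.float32),
            'cell_b': np.zeros(H, dtype=np.float32),
        },
        # The 'recurrent' entry follows the same layout with D replaced by H.
    }
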
def cudnn_LSTM(model, input_blob, initial_states, dim_in, dim_out,
               scope, recurrent_params=None, input_params=None,
               num_layers=1, return_params=False):
    '''
    CuDNN version of LSTM for GPUs.

    input_blob          Blob containing the input. Will need to be available
                        when param_init_net is run, because the sequence
                        lengths and batch sizes will be inferred from the
                        size of this blob.
    initial_states      tuple of (hidden_init, cell_init) blobs
    dim_in              input dimension
    dim_out             output/hidden dimension
    scope               namescope to apply
    recurrent_params    dict of blobs containing values for recurrent
                        gate weights, biases (if None, use random init values)
                        See GetLSTMParamNames() for format.
    input_params        dict of blobs containing values for input
                        gate weights, biases (if None, use random init values)
                        See GetLSTMParamNames() for format.
    num_layers          number of LSTM layers
    return_params       if True, returns (param_extract_net, param_mapping)
                        where param_extract_net is a net that, when run,
                        populates the blobs specified in param_mapping with
                        the current gate weights and biases (input/recurrent).
                        Useful for assigning the values back to a non-cuDNN
                        LSTM.
    '''
    with core.NameScope(scope):
        weight_params = GetLSTMParamNames()['weights']
        bias_params = GetLSTMParamNames()['biases']

        input_weight_size = dim_out * dim_in
        recurrent_weight_size = dim_out * dim_out
        input_bias_size = dim_out
        recurrent_bias_size = dim_out

        def init(layer, pname, input_type):
            if pname in weight_params:
                sz = input_weight_size if input_type == 'input' \
                    else recurrent_weight_size
            elif pname in bias_params:
                sz = input_bias_size if input_type == 'input' \
                    else recurrent_bias_size
            else:
                assert False, "unknown parameter type {}".format(pname)
            return model.param_init_net.UniformFill(
                [],
                "lstm_init_{}_{}_{}".format(input_type, pname, layer),
                shape=[sz])

        # Multiply by 4 since we have 4 gates per LSTM unit
        total_sz = 4 * num_layers * (
            input_weight_size + recurrent_weight_size + input_bias_size +
            recurrent_bias_size
        )

        weights = model.param_init_net.UniformFill(
            [], "lstm_weight", shape=[total_sz])

        model.params.append(weights)
        model.weights.append(weights)

        lstm_args = {
            'hidden_size': dim_out,
            'rnn_mode': 'lstm',
            'bidirectional': 0,  # TODO
            'dropout': 1.0,  # TODO
            'input_mode': 'linear',  # TODO
            'num_layers': num_layers,
            'engine': 'CUDNN'
        }

        param_extract_net = core.Net("lstm_param_extractor")
        param_extract_net.AddExternalInputs([input_blob, weights])
        param_extract_mapping = {}

        # Populate the weights blob from blobs containing parameters for
        # the individual components of the LSTM, such as forget/input gate
        # weights and biases. Also, create a special param_extract_net that
        # can be used to grab those individual params from the black-box
        # weights blob. These results can then be fed to InitFromLSTMParams().
        for input_type in ['input', 'recurrent']:
            param_extract_mapping[input_type] = {}
            p = recurrent_params if input_type == 'recurrent' else input_params
            if p is None:
                p = {}
            for pname in weight_params + bias_params:
                for j in range(0, num_layers):
                    values = p[pname] if pname in p else init(j, pname, input_type)
                    model.param_init_net.RecurrentParamSet(
                        [input_blob, weights, values],
                        weights,
                        layer=j,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )
                    if pname not in param_extract_mapping[input_type]:
                        param_extract_mapping[input_type][pname] = {}
                    b = param_extract_net.RecurrentParamGet(
                        [input_blob, weights],
                        ["lstm_{}_{}_{}".format(input_type, pname, j)],
                        layer=j,
                        input_type=input_type,
                        param_type=pname,
                        **lstm_args
                    )
                    param_extract_mapping[input_type][pname][j] = b

        (hidden_input_blob, cell_input_blob) = initial_states
        output, hidden_output, cell_output, rnn_scratch, dropout_states = \
            model.net.Recurrent(
                # Pass the hidden and cell initial states in order.
                [input_blob, hidden_input_blob, cell_input_blob, weights],
                ["lstm_output", "lstm_hidden_output", "lstm_cell_output",
                 "lstm_rnn_scratch", "lstm_dropout_states"],
                seed=random.randint(0, 100000),  # TODO: dropout seed
                **lstm_args
            )
        model.net.AddExternalOutputs(
            hidden_output, cell_output, rnn_scratch, dropout_states)

        if return_params:
            param_extract = param_extract_net, param_extract_mapping
            return output, hidden_output, cell_output, param_extract
        else:
            return output, hidden_output, cell_output

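# Hedged sketch of the return_params round trip: extract per-gate values from
# the fused cuDNN weight blob and feed them into a non-cuDNN LSTM created
# with return_params=True. `lstm_params` denotes the params dict that LSTM()
# appends to its outputs; it assumes 'input', 'h0', and 'c0' have already
# been fed into the workspace. All names here are illustrative assumptions.
def _example_cudnn_param_roundtrip(model, lstm_params, D, H):
    _, _, _, (param_extract_net, param_mapping) = cudnn_LSTM(
        model, 'input', ('h0', 'c0'), dim_in=D, dim_out=H,
        scope='cudnn_ex', return_params=True,
    )
    workspace.RunNetOnce(model.param_init_net)
    workspace.RunNetOnce(param_extract_net)
    # param_mapping is {input_type: {param_name: {layer: blob}}}; collapse the
    # single-layer case into the {input_type: {param_name: ndarray}} format
    # that InitFromLSTMParams() expects.
    param_values = {
        input_type: {
            pname: workspace.FetchBlob(by_layer[0])
            for pname, by_layer in by_name.items()
        }
        for input_type, by_name in param_mapping.items()
    }
    InitFromLSTMParams(lstm_params, param_values)
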

class LSTMWithAttentionCell(RNNCell):

    def __init__(
        self,
        encoder_output_dim,
        encoder_outputs,
        decoder_input_dim,
        decoder_state_dim,
        name,
        attention_type,
        weighted_encoder_outputs,
        forget_bias,
        lstm_memory_optimization,
        attention_memory_optimization,
    ):
        super(LSTMWithAttentionCell, self).__init__(name)
        self.encoder_output_dim = encoder_output_dim
        self.encoder_outputs = encoder_outputs
        self.decoder_input_dim = decoder_input_dim
        self.decoder_state_dim = decoder_state_dim
        self.weighted_encoder_outputs = weighted_encoder_outputs
        self.encoder_outputs_transposed = None
        assert attention_type in [
            AttentionType.Regular,
            AttentionType.Recurrent,
        ]
        self.attention_type = attention_type
        # Store forget_bias so the LSTMUnit call below can use it instead of
        # silently dropping the argument.
        self.forget_bias = float(forget_bias)
        self.lstm_memory_optimization = lstm_memory_optimization
        self.attention_memory_optimization = attention_memory_optimization

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
    ):
        (
            hidden_t_prev,
            cell_t_prev,
            attention_weighted_encoder_context_t_prev,
        ) = states

        gates_concatenated_input_t, _ = model.net.Concat(
            [hidden_t_prev, attention_weighted_encoder_context_t_prev],
            [
                self.scope('gates_concatenated_input_t'),
                self.scope('_gates_concatenated_input_t_concat_dims'),
            ],
            axis=2,
        )
        gates_t = model.FC(
            gates_concatenated_input_t,
            self.scope('gates_t'),
            dim_in=self.decoder_state_dim + self.encoder_output_dim,
            dim_out=4 * self.decoder_state_dim,
            axis=2,
        )
        model.net.Sum([gates_t, input_t], gates_t)

        hidden_t_intermediate, cell_t = model.net.LSTMUnit(
            [
                hidden_t_prev,
                cell_t_prev,
                gates_t,
                seq_lengths,
                timestep,
            ],
            ['hidden_t_intermediate', self.scope('cell_t')],
            forget_bias=self.forget_bias,
        )
        if self.attention_type == AttentionType.Recurrent:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_recurrent_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                attention_weighted_encoder_context_t_prev=(
                    attention_weighted_encoder_context_t_prev
                ),
            )
        else:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                attention_blobs,
            ) = apply_regular_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
            )
        hidden_t = model.Copy(hidden_t_intermediate, self.scope('hidden_t'))
        model.net.AddExternalOutputs(
            cell_t,
            hidden_t,
            attention_weighted_encoder_context_t,
        )
        if self.attention_memory_optimization:
            self.recompute_blobs.extend(attention_blobs)
        if self.lstm_memory_optimization:
            self.recompute_blobs.append(gates_t)

        return hidden_t, cell_t, attention_weighted_encoder_context_t

    def get_attention_weights(self):
        # [batch_size, encoder_length, 1]
        return self.attention_weights_3d

    def prepare_input(self, model, input_blob):
        if self.encoder_outputs_transposed is None:
            self.encoder_outputs_transposed = model.Transpose(
                self.encoder_outputs,
                self.scope('encoder_outputs_transposed'),
                axes=[1, 2, 0],
            )
        if self.weighted_encoder_outputs is None:
            self.weighted_encoder_outputs = model.FC(
                self.encoder_outputs,
                self.scope('weighted_encoder_outputs'),
                dim_in=self.encoder_output_dim,
                dim_out=self.encoder_output_dim,
                axis=2,
            )

        return model.FC(
            input_blob,
            self.scope('i2h'),
            dim_in=self.decoder_input_dim,
            dim_out=4 * self.decoder_state_dim,
            axis=2,
        )

    def get_state_names(self):
        return (
            self.scope('hidden_t'),
            self.scope('cell_t'),
            self.scope('attention_weighted_encoder_context_t'),
        )

    def get_outputs_with_grads(self):
        return [0, 4]

    def get_output_size(self):
        return self.decoder_state_dim + self.encoder_output_dim


def LSTMWithAttention(
    model,
    decoder_inputs,
    decoder_input_lengths,
    initial_decoder_hidden_state,
    initial_decoder_cell_state,
    initial_attention_weighted_encoder_context,
    encoder_output_dim,
    encoder_outputs,
    decoder_input_dim,
    decoder_state_dim,
    scope,
    attention_type=AttentionType.Regular,
    outputs_with_grads=(0, 4),
    weighted_encoder_outputs=None,
    lstm_memory_optimization=False,
    attention_memory_optimization=False,
    forget_bias=0.0,
):
    '''
    Adds an LSTM with an attention mechanism to a model.

    The implementation is based on https://arxiv.org/abs/1409.0473, with
    a small difference in the order in which the new attention context and
    the new hidden state are computed, similarly to
    https://arxiv.org/abs/1508.04025.

    The model uses encoder-decoder naming conventions, where the decoder is
    the sequence the op iterates over, while computing the attention context
    over the encoder.

    model: CNNModelHelper object to which new operators are added

    decoder_inputs: the input sequence in a format T x N x D
    where T is sequence size, N - batch size and D - input dimension

    decoder_input_lengths: blob containing sequence lengths
    which would be passed to the LSTMUnit operator

    initial_decoder_hidden_state: initial hidden state of LSTM

    initial_decoder_cell_state: initial cell state of LSTM

    initial_attention_weighted_encoder_context: initial attention context

    encoder_output_dim: dimension of encoder outputs

    encoder_outputs: the sequence over which the attention context is
    computed at every iteration

    decoder_input_dim: input dimension (last dimension of decoder_inputs)

    decoder_state_dim: size of hidden states of LSTM

    attention_type: one of AttentionType.Regular, AttentionType.Recurrent.
    Determines which type of attention mechanism to use.

    outputs_with_grads: position indices of output blobs which will receive
    an external error gradient during backpropagation

    weighted_encoder_outputs: encoder outputs to be used to compute attention
    weights. In the basic case it's just a linear transformation of the
    encoder outputs (that's the default, when weighted_encoder_outputs is
    None). However, it can be something more complicated, e.g. a separate
    encoder network (say, in the case of a convolutional encoder).

    lstm_memory_optimization: recompute LSTM activations on the backward
    pass, so their values do not need to be stored during the forward pass

    attention_memory_optimization: recompute attention for the backward pass
    '''
    cell = LSTMWithAttentionCell(
        encoder_output_dim=encoder_output_dim,
        encoder_outputs=encoder_outputs,
        decoder_input_dim=decoder_input_dim,
        decoder_state_dim=decoder_state_dim,
        name=scope,
        attention_type=attention_type,
        weighted_encoder_outputs=weighted_encoder_outputs,
        forget_bias=forget_bias,
        lstm_memory_optimization=lstm_memory_optimization,
        attention_memory_optimization=attention_memory_optimization,
    )
    return cell.apply_over_sequence(
        model=model,
        inputs=decoder_inputs,
        seq_lengths=decoder_input_lengths,
        initial_states=(
            initial_decoder_hidden_state,
            initial_decoder_cell_state,
            initial_attention_weighted_encoder_context,
        ),
        outputs_with_grads=outputs_with_grads,
    )

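# Hedged usage sketch for LSTMWithAttention(). With three recurrent states,
# apply_over_sequence presumably yields six outputs (all-step and last-step
# blobs per state); indices 0 and 4 below match the default
# outputs_with_grads=(0, 4). Blob names and shapes are illustrative
# assumptions.
def _example_lstm_with_attention(model, E=32, D=8, H=16):
    outputs = LSTMWithAttention(
        model=model,
        decoder_inputs='decoder_inputs',           # (T_dec, N, D)
        decoder_input_lengths='decoder_lengths',   # (N,)
        initial_decoder_hidden_state='dec_h0',
        initial_decoder_cell_state='dec_c0',
        initial_attention_weighted_encoder_context='ctx0',
        encoder_output_dim=E,
        encoder_outputs='encoder_outputs',         # (T_enc, N, E)
        decoder_input_dim=D,
        decoder_state_dim=H,
        scope='attn_ex',
    )
    decoder_hidden_seq = outputs[0]       # (T_dec, N, H)
    attention_context_seq = outputs[4]    # (T_dec, N, E)
    return decoder_hidden_seq, attention_context_seq
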

class MILSTMCell(LSTMCell):

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
    ):
        (
            hidden_t_prev,
            cell_t_prev,
        ) = states

        # hU^T
        # Shape: [1, batch_size, 4 * hidden_size]
        prev_t = model.FC(
            hidden_t_prev, self.scope('prev_t'), dim_in=self.hidden_size,
            dim_out=4 * self.hidden_size, axis=2)
        # defining MI parameters
        alpha = model.param_init_net.ConstantFill(
            [],
            [self.scope('alpha')],
            shape=[4 * self.hidden_size],
            value=1.0
        )
        beta1 = model.param_init_net.ConstantFill(
            [],
            [self.scope('beta1')],
            shape=[4 * self.hidden_size],
            value=1.0
        )
        beta2 = model.param_init_net.ConstantFill(
            [],
            [self.scope('beta2')],
            shape=[4 * self.hidden_size],
            value=1.0
        )
        b = model.param_init_net.ConstantFill(
            [],
            [self.scope('b')],
            shape=[4 * self.hidden_size],
            value=0.0
        )
        model.params.extend([alpha, beta1, beta2, b])
        # alpha * (xW^T * hU^T)
        # Shape: [1, batch_size, 4 * hidden_size]
        alpha_tdash = model.net.Mul(
            [prev_t, input_t],
            self.scope('alpha_tdash')
        )
        # Shape: [batch_size, 4 * hidden_size]
        alpha_tdash_rs, _ = model.net.Reshape(
            alpha_tdash,
            [self.scope('alpha_tdash_rs'), self.scope('alpha_tdash_old_shape')],
            shape=[-1, 4 * self.hidden_size],
        )
        alpha_t = model.net.Mul(
            [alpha_tdash_rs, alpha],
            self.scope('alpha_t'),
            broadcast=1,
            use_grad_hack=1
        )
        # beta1 * hU^T
        # Shape: [batch_size, 4 * hidden_size]
        prev_t_rs, _ = model.net.Reshape(
            prev_t,
            [self.scope('prev_t_rs'), self.scope('prev_t_old_shape')],
            shape=[-1, 4 * self.hidden_size],
        )
        beta1_t = model.net.Mul(
            [prev_t_rs, beta1],
            self.scope('beta1_t'),
            broadcast=1,
            use_grad_hack=1
        )
        # beta2 * xW^T
        # Shape: [batch_size, 4 * hidden_size]
        input_t_rs, _ = model.net.Reshape(
            input_t,
            [self.scope('input_t_rs'), self.scope('input_t_old_shape')],
            shape=[-1, 4 * self.hidden_size],
        )
        beta2_t = model.net.Mul(
            [input_t_rs, beta2],
            self.scope('beta2_t'),
            broadcast=1,
            use_grad_hack=1
        )
        # Add 'em all up
        gates_tdash = model.net.Sum(
            [alpha_t, beta1_t, beta2_t],
            self.scope('gates_tdash')
        )
        gates_t = model.net.Add(
            [gates_tdash, b],
            self.scope('gates_t'),
            broadcast=1,
            use_grad_hack=1
        )
        # Shape: [1, batch_size, 4 * hidden_size]
        gates_t_rs, _ = model.net.Reshape(
            gates_t,
            [self.scope('gates_t_rs'), self.scope('gates_t_old_shape')],
            shape=[1, -1, 4 * self.hidden_size],
        )

        hidden_t_intermediate, cell_t = model.net.LSTMUnit(
            [hidden_t_prev, cell_t_prev, gates_t_rs, seq_lengths, timestep],
            [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
            forget_bias=self.forget_bias,
        )
        hidden_t = model.Copy(hidden_t_intermediate, self.scope('hidden_t'))
        model.net.AddExternalOutputs(
            cell_t,
            hidden_t,
        )
        if self.memory_optimization:
            self.recompute_blobs = [gates_t]
        return hidden_t, cell_t


def MILSTM(model, input_blob, seq_lengths, initial_states, dim_in, dim_out,
           scope, outputs_with_grads=(0,), memory_optimization=False,
           forget_bias=0.0):
    '''
    Adds the MI flavor of a standard LSTM recurrent network operator to a
    model. See https://arxiv.org/pdf/1606.06630.pdf

    model: CNNModelHelper object to which new operators are added

    input_blob: the input sequence in a format T x N x D
    where T is sequence size, N - batch size and D - input dimension

    seq_lengths: blob containing sequence lengths which would be passed to
    the LSTMUnit operator

    initial_states: a tuple of (hidden_input_blob, cell_input_blob)
    which are going to be inputs to the cell net on the first iteration

    dim_in: input dimension

    dim_out: output dimension

    outputs_with_grads: position indices of output blobs which will receive
    an external error gradient during backpropagation

    memory_optimization: if enabled, the LSTM step is recomputed on the
    backward pass, so that forward activations do not need to be stored for
    every timestep. Saves memory at the cost of extra computation.
    '''
    cell = MILSTMCell(
        input_size=dim_in,
        hidden_size=dim_out,
        forget_bias=forget_bias,
        memory_optimization=memory_optimization,
        name=scope,
    )
    result = cell.apply_over_sequence(
        model=model,
        inputs=input_blob,
        seq_lengths=seq_lengths,
        initial_states=initial_states,
        outputs_with_grads=outputs_with_grads,
    )
    return tuple(result)

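# Numpy sketch of the multiplicative-integration gating assembled above
# (https://arxiv.org/abs/1606.06630): instead of the additive Wx + Uh + b,
# the pre-nonlinearity is alpha * (Wx * Uh) + beta1 * Uh + beta2 * Wx + b,
# computed elementwise over the 4*H stacked gates.
def _example_mi_gates(N=2, H=16):
    Wx = np.random.randn(N, 4 * H).astype(np.float32)  # input proj, x W^T
    Uh = np.random.randn(N, 4 * H).astype(np.float32)  # recurrent proj, h U^T
    alpha = np.ones(4 * H, dtype=np.float32)  # matches ConstantFill(value=1.0)
    beta1 = np.ones(4 * H, dtype=np.float32)
    beta2 = np.ones(4 * H, dtype=np.float32)
    b = np.zeros(4 * H, dtype=np.float32)     # matches ConstantFill(value=0.0)
    gates = alpha * (Wx * Uh) + beta1 * Uh + beta2 * Wx + b
    # With this all-ones init the expression reduces to Wx * Uh + Uh + Wx,
    # which LSTMUnit then splits into the four LSTM gates.
    return gates  # shape (N, 4 * H)
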

class MILSTMWithAttentionCell(LSTMWithAttentionCell):

    def _apply(
        self,
        model,
        input_t,
        seq_lengths,
        states,
        timestep,
    ):
        (
            hidden_t_prev,
            cell_t_prev,
            attention_weighted_encoder_context_t_prev,
        ) = states

        gates_concatenated_input_t, _ = model.net.Concat(
            [hidden_t_prev, attention_weighted_encoder_context_t_prev],
            [
                self.scope('gates_concatenated_input_t'),
                self.scope('_gates_concatenated_input_t_concat_dims'),
            ],
            axis=2,
        )
        # hU^T
        # Shape: [1, batch_size, 4 * hidden_size]
        prev_t = model.FC(
            gates_concatenated_input_t,
            self.scope('prev_t'),
            dim_in=self.decoder_state_dim + self.encoder_output_dim,
            dim_out=4 * self.decoder_state_dim,
            axis=2,
        )
        # defining MI parameters
        alpha = model.param_init_net.ConstantFill(
            [],
            [self.scope('alpha')],
            shape=[4 * self.decoder_state_dim],
            value=1.0
        )
        beta1 = model.param_init_net.ConstantFill(
            [],
            [self.scope('beta1')],
            shape=[4 * self.decoder_state_dim],
            value=1.0
        )
        beta2 = model.param_init_net.ConstantFill(
            [],
            [self.scope('beta2')],
            shape=[4 * self.decoder_state_dim],
            value=1.0
        )
        b = model.param_init_net.ConstantFill(
            [],
            [self.scope('b')],
            shape=[4 * self.decoder_state_dim],
            value=0.0
        )
        model.params.extend([alpha, beta1, beta2, b])
        # alpha * (xW^T * hU^T)
        # Shape: [1, batch_size, 4 * hidden_size]
        alpha_tdash = model.net.Mul(
            [prev_t, input_t],
            self.scope('alpha_tdash')
        )
        # Shape: [batch_size, 4 * hidden_size]
        alpha_tdash_rs, _ = model.net.Reshape(
            alpha_tdash,
            [self.scope('alpha_tdash_rs'), self.scope('alpha_tdash_old_shape')],
            shape=[-1, 4 * self.decoder_state_dim],
        )
        alpha_t = model.net.Mul(
            [alpha_tdash_rs, alpha],
            self.scope('alpha_t'),
            broadcast=1,
            use_grad_hack=1
        )
        # beta1 * hU^T
        # Shape: [batch_size, 4 * hidden_size]
        prev_t_rs, _ = model.net.Reshape(
            prev_t,
            [self.scope('prev_t_rs'), self.scope('prev_t_old_shape')],
            shape=[-1, 4 * self.decoder_state_dim],
        )
        beta1_t = model.net.Mul(
            [prev_t_rs, beta1],
            self.scope('beta1_t'),
            broadcast=1,
            use_grad_hack=1
        )
        # beta2 * xW^T
        # Shape: [batch_size, 4 * hidden_size]
        input_t_rs, _ = model.net.Reshape(
            input_t,
            [self.scope('input_t_rs'), self.scope('input_t_old_shape')],
            shape=[-1, 4 * self.decoder_state_dim],
        )
        beta2_t = model.net.Mul(
            [input_t_rs, beta2],
            self.scope('beta2_t'),
            broadcast=1,
            use_grad_hack=1
        )
        # Add 'em all up
        gates_tdash = model.net.Sum(
            [alpha_t, beta1_t, beta2_t],
            self.scope('gates_tdash')
        )
        gates_t = model.net.Add(
            [gates_tdash, b],
            self.scope('gates_t'),
            broadcast=1,
            use_grad_hack=1
        )
        # Shape: [1, batch_size, 4 * hidden_size]
        gates_t_rs, _ = model.net.Reshape(
            gates_t,
            [self.scope('gates_t_rs'), self.scope('gates_t_old_shape')],
            shape=[1, -1, 4 * self.decoder_state_dim],
        )

        hidden_t_intermediate, cell_t = model.net.LSTMUnit(
            [hidden_t_prev, cell_t_prev, gates_t_rs, seq_lengths, timestep],
            [self.scope('hidden_t_intermediate'), self.scope('cell_t')],
        )

        if self.attention_type == AttentionType.Recurrent:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                self.recompute_blobs,
            ) = apply_recurrent_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
                attention_weighted_encoder_context_t_prev=(
                    attention_weighted_encoder_context_t_prev
                ),
            )
        else:
            (
                attention_weighted_encoder_context_t,
                self.attention_weights_3d,
                self.recompute_blobs,
            ) = apply_regular_attention(
                model=model,
                encoder_output_dim=self.encoder_output_dim,
                encoder_outputs_transposed=self.encoder_outputs_transposed,
                weighted_encoder_outputs=self.weighted_encoder_outputs,
                decoder_hidden_state_t=hidden_t_intermediate,
                decoder_hidden_state_dim=self.decoder_state_dim,
                scope=self.name,
            )
        hidden_t = model.Copy(hidden_t_intermediate, self.scope('hidden_t'))
        model.net.AddExternalOutputs(
            cell_t,
            hidden_t,
            attention_weighted_encoder_context_t,
        )
        return hidden_t, cell_t, attention_weighted_encoder_context_t