Caffe2 - Python API
A deep learning, cross-platform ML framework
seq2seq_util.py
3 """ A bunch of util functions to build Seq2Seq models with Caffe2."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from caffe2.python import rnn_cell
from caffe2.python.cnn import CNNModelHelper

class ModelHelper(CNNModelHelper):

    def __init__(self, init_params=True):
        super(ModelHelper, self).__init__(
            order='NCHW',  # this is only relevant for convolutional networks
            init_params=init_params,
        )
        self.non_trainable_params = []

    def AddParam(self, name, init=None, init_value=None, trainable=True):
        """Adds a parameter to the model's net and, if needed, its initializer.

        Args:
            init: a tuple (<initialization_op_name>, <initialization_op_kwargs>)
            init_value: int, float or str. Can be used instead of `init` as a
                simple constant initializer.
            trainable: bool, whether to compute a gradient for this param.
        """
        if init_value is not None:
            assert init is None
            assert type(init_value) in [int, float, str]
            init = ('ConstantFill', dict(
                shape=[1],
                value=init_value,
            ))

        if self.init_params:
            param = self.param_init_net.__getattr__(init[0])(
                [],
                name,
                **init[1]
            )
        else:
            param = self.net.AddExternalInput(name)

        if trainable:
            self.params.append(param)
        else:
            self.non_trainable_params.append(param)

        return param

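
# Usage sketch (not part of the original file): AddParam accepts either an
# explicit initializer op via `init`, or a scalar constant via `init_value`.
# The blob names, shapes, and GaussianFill kwargs below are illustrative
# assumptions, not values taken from this module.
def _example_add_param():
    model = ModelHelper(init_params=True)
    # Explicit initializer: (op name, op kwargs) run in param_init_net.
    weights = model.AddParam(
        'example_weights',
        init=('GaussianFill', dict(shape=[256, 128], std=0.01)),
    )
    # Scalar constant: expands to a ConstantFill of shape [1]; marking it
    # non-trainable routes it to model.non_trainable_params.
    scale = model.AddParam('example_scale', init_value=1.0, trainable=False)
    return model, weights, scale
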
def rnn_unidirectional_encoder(
    model,
    embedded_inputs,
    input_lengths,
    initial_hidden_state,
    initial_cell_state,
    embedding_size,
    encoder_num_units,
    use_attention,
):
    """Unidirectional (forward pass) LSTM encoder."""

    outputs, final_hidden_state, _, final_cell_state = rnn_cell.LSTM(
        model=model,
        input_blob=embedded_inputs,
        seq_lengths=input_lengths,
        initial_states=(initial_hidden_state, initial_cell_state),
        dim_in=embedding_size,
        dim_out=encoder_num_units,
        scope='encoder',
        outputs_with_grads=([0] if use_attention else [1, 3]),
    )
    return outputs, final_hidden_state, final_cell_state

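
# Usage sketch (not part of the original file): wiring the unidirectional
# encoder. Inputs are time-major: embedded_inputs has shape (seq_length,
# batch_size, embedding_size) and input_lengths has shape (batch_size,).
# The initial states are zero-valued, non-trainable params created through
# AddParam above; all names and sizes here are assumptions.
def _example_unidirectional_encoder(model, embedded_inputs, input_lengths,
                                    embedding_size=128, encoder_num_units=256):
    initial_hidden_state = model.AddParam(
        'encoder_initial_hidden_state', init_value=0.0, trainable=False)
    initial_cell_state = model.AddParam(
        'encoder_initial_cell_state', init_value=0.0, trainable=False)
    # use_attention=False propagates gradients through the final states
    # (outputs_with_grads=[1, 3]) rather than the per-step outputs.
    return rnn_unidirectional_encoder(
        model,
        embedded_inputs,
        input_lengths,
        initial_hidden_state,
        initial_cell_state,
        embedding_size,
        encoder_num_units,
        use_attention=False,
    )
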
def rnn_bidirectional_encoder(
    model,
    embedded_inputs,
    input_lengths,
    initial_hidden_state,
    initial_cell_state,
    embedding_size,
    encoder_num_units,
    use_attention,
):
    """Bidirectional (forward pass and backward pass) LSTM encoder."""

    # Forward pass
    (
        outputs_fw,
        final_hidden_state_fw,
        _,
        final_cell_state_fw,
    ) = rnn_cell.LSTM(
        model=model,
        input_blob=embedded_inputs,
        seq_lengths=input_lengths,
        initial_states=(initial_hidden_state, initial_cell_state),
        dim_in=embedding_size,
        dim_out=encoder_num_units,
        scope='forward_encoder',
        outputs_with_grads=([0] if use_attention else [1, 3]),
    )

    # Backward pass
    reversed_embedded_inputs = model.net.ReversePackedSegs(
        [embedded_inputs, input_lengths],
        ['reversed_embedded_inputs'],
    )

    (
        outputs_bw,
        final_hidden_state_bw,
        _,
        final_cell_state_bw,
    ) = rnn_cell.LSTM(
        model=model,
        input_blob=reversed_embedded_inputs,
        seq_lengths=input_lengths,
        initial_states=(initial_hidden_state, initial_cell_state),
        dim_in=embedding_size,
        dim_out=encoder_num_units,
        scope='backward_encoder',
        outputs_with_grads=([0] if use_attention else [1, 3]),
    )

    outputs_bw = model.net.ReversePackedSegs(
        [outputs_bw, input_lengths],
        ['outputs_bw'],
    )

    # Concatenate forward and backward results
    outputs, _ = model.net.Concat(
        [outputs_fw, outputs_bw],
        ['outputs', 'outputs_dim'],
        axis=2,
    )

    final_hidden_state, _ = model.net.Concat(
        [final_hidden_state_fw, final_hidden_state_bw],
        ['final_hidden_state', 'final_hidden_state_dim'],
        axis=2,
    )

    final_cell_state, _ = model.net.Concat(
        [final_cell_state_fw, final_cell_state_bw],
        ['final_cell_state', 'final_cell_state_dim'],
        axis=2,
    )
    return outputs, final_hidden_state, final_cell_state
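
# End-to-end sketch (not part of the original file): embed token ids with a
# Gather over an embedding table, then run the bidirectional encoder. Since
# forward and backward results are concatenated along axis 2, the outputs and
# final states have hidden dimension 2 * encoder_num_units, so any decoder or
# attention layer consuming them must be sized accordingly. All names, shapes,
# and hyperparameters here are illustrative assumptions.
def _example_bidirectional_pipeline(model, inputs, input_lengths,
                                    vocab_size=10000, embedding_size=128,
                                    encoder_num_units=256):
    # inputs: time-major int32 token ids of shape (seq_length, batch_size).
    embeddings = model.AddParam(
        'example_embeddings',
        init=('UniformFill', dict(
            shape=[vocab_size, embedding_size], min=-0.1, max=0.1)),
    )
    embedded_inputs = model.net.Gather(
        [embeddings, inputs],
        ['example_embedded_inputs'],
    )
    initial_hidden_state = model.AddParam(
        'example_initial_hidden_state', init_value=0.0, trainable=False)
    initial_cell_state = model.AddParam(
        'example_initial_cell_state', init_value=0.0, trainable=False)
    outputs, final_hidden_state, final_cell_state = rnn_bidirectional_encoder(
        model,
        embedded_inputs,
        input_lengths,
        initial_hidden_state,
        initial_cell_state,
        embedding_size,
        encoder_num_units,
        use_attention=True,
    )
    # outputs: (seq_length, batch_size, 2 * encoder_num_units)
    return outputs, final_hidden_state, final_cell_state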