Caffe2 - Python API
A deep learning, cross-platform ML framework
attention.py
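attention.py builds attention over a sequence of encoder outputs out of Caffe2 net operators. For a decoder step it computes additive attention energies e_{ij} = \textbf{v}^T tanh \alpha(\textbf{h}_{i-1}, \textbf{s}_j), normalizes them with a softmax into weights w_{ij}, and returns the attention-weighted context c_i = \sum_j w_{ij}\textbf{s}_j. apply_regular_attention implements this standard form; apply_recurrent_attention (RecAtt, section 4.1 of http://arxiv.org/abs/1601.03317) additionally adds the previous step's attention context into the energy sum.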
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals


class AttentionType:
    Regular, Recurrent = range(2)


def s(scope, name):
    # We have to manually scope due to our internal/external blob
    # relationships.
    return "{}/{}".format(str(scope), str(name))


# c_i = \sum_j w_{ij}\textbf{s}_j
def _calc_weighted_context(
    model,
    encoder_outputs_transposed,
    encoder_output_dim,
    attention_weights_3d,
    scope,
):
    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = model.net.BatchMatMul(
        [encoder_outputs_transposed, attention_weights_3d],
        s(scope, 'attention_weighted_encoder_context'),
    )
    # TODO: somehow I cannot use Squeeze in-place op here
    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context, _ = model.net.Reshape(
        attention_weighted_encoder_context,
        [
            attention_weighted_encoder_context,
            s(scope, 'attention_weighted_encoder_context_old_shape')
        ],
        shape=[1, -1, encoder_output_dim],
    )
    return attention_weighted_encoder_context


# Calculate a softmax over the passed in attention energy logits
def _calc_attention_weights(
    model,
    attention_logits_transposed,
    scope
):
    # TODO: we could try to force some attention weights to be zeros,
    # based on encoder_lengths.
    # [batch_size, encoder_length]
    attention_weights = model.Softmax(
        attention_logits_transposed,
        s(scope, 'attention_weights'),
        engine='CUDNN',
    )
    # TODO: make this operation in-place
    # [batch_size, encoder_length, 1]
    attention_weights_3d = model.net.ExpandDims(
        attention_weights,
        s(scope, 'attention_weights_3d'),
        dims=[2],
    )
    return attention_weights_3d


# e_{ij} = \textbf{v}^T tanh \alpha(\textbf{h}_{i-1}, \textbf{s}_j)
def _calc_attention_logits_from_sum_match(
    model,
    decoder_hidden_encoder_outputs_sum,
    encoder_output_dim,
    scope
):
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Tanh(
        decoder_hidden_encoder_outputs_sum,
        decoder_hidden_encoder_outputs_sum,
    )

    attention_v = model.param_init_net.XavierFill(
        [],
        s(scope, 'attention_v'),
        shape=[1, encoder_output_dim],
    )
    model.add_param(attention_v)

    attention_zeros = model.param_init_net.ConstantFill(
        [],
        s(scope, 'attention_zeros'),
        value=0.0,
        shape=[1],
    )

    # [encoder_length, batch_size, 1]
    attention_logits = model.net.FC(
        [decoder_hidden_encoder_outputs_sum, attention_v, attention_zeros],
        [s(scope, 'attention_logits')],
        axis=2
    )
    # [encoder_length, batch_size]
    attention_logits = model.net.Squeeze(
        [attention_logits],
        [attention_logits],
        dims=[2],
    )
    # [batch_size, encoder_length]
    attention_logits_transposed = model.Transpose(
        attention_logits,
        s(scope, 'attention_logits_transposed'),
        axes=[1, 0],
    )
    return attention_logits_transposed


# \textbf{W}^\alpha used in the context of \alpha_{sum}(a,b)
def _apply_fc_weight_for_sum_match(
    model,
    input,
    dim_in,
    dim_out,
    scope,
    name
):
    output = model.FC(
        input,
        s(scope, name),
        dim_in=dim_in,
        dim_out=dim_out,
        axis=2,
    )
    output = model.net.Squeeze(
        output,
        output,
        dims=[0]
    )
    return output


# Implement RecAtt as described in section 4.1 of http://arxiv.org/abs/1601.03317
def apply_recurrent_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    attention_weighted_encoder_context_t_prev,
    scope,
):
    weighted_prev_attention_context = _apply_fc_weight_for_sum_match(
        model=model,
        input=attention_weighted_encoder_context_t_prev,
        dim_in=encoder_output_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_prev_attention_context'
    )

    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state'
    )

    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum_tmp = model.net.Add(
        [
            weighted_encoder_outputs,
            weighted_decoder_hidden_state
        ],
        s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
        broadcast=1,
        use_grad_hack=1,
    )
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [
            decoder_hidden_encoder_outputs_sum_tmp,
            weighted_prev_attention_context
        ],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
        use_grad_hack=1,
    )

    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope
    )

    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope
    )
    return attention_weighted_encoder_context, attention_weights_3d, [
        decoder_hidden_encoder_outputs_sum_tmp,
        decoder_hidden_encoder_outputs_sum
    ]


def apply_regular_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
):
    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state'
    )

    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [weighted_encoder_outputs, weighted_decoder_hidden_state],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
        use_grad_hack=1,
    )

    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope
    )

    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope
    )

    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope
    )
    return attention_weighted_encoder_context, attention_weights_3d, [
        decoder_hidden_encoder_outputs_sum
    ]
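A minimal usage sketch (not part of attention.py), assuming a Caffe2 model helper object that exposes net, param_init_net, FC, Softmax, Transpose, and add_param; the blob names and dimensions below are illustrative assumptions only:

from caffe2.python.attention import apply_regular_attention

def attend_once(model):
    # Assumed upstream blobs (hypothetical names):
    #   encoder_outputs_transposed: [batch_size, encoder_output_dim, encoder_length]
    #   weighted_encoder_outputs:   [encoder_length, batch_size, encoder_output_dim],
    #       assumed to be the encoder outputs already projected by the attention
    #       weight matrix W^alpha
    #   decoder_hidden_state_t:     [1, batch_size, decoder_hidden_state_dim]
    context, weights_3d, sum_blobs = apply_regular_attention(
        model=model,
        encoder_output_dim=256,
        encoder_outputs_transposed='encoder_outputs_transposed',
        weighted_encoder_outputs='weighted_encoder_outputs',
        decoder_hidden_state_t='decoder_hidden_state_t',
        decoder_hidden_state_dim=512,
        scope='decoder',
    )
    # context is [1, batch_size, encoder_output_dim] after the final Reshape,
    # weights_3d is [batch_size, encoder_length, 1], and sum_blobs holds the
    # intermediate decoder_hidden_encoder_outputs_sum blob.
    return context, weights_3d, sum_blobs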