from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals


class AttentionType:
    Regular, Recurrent = range(2)


def s(scope, name):
    # Fully qualify blob names with their scope so that attention blobs from
    # different instances do not collide.
    return "{}/{}".format(str(scope), str(name))


# c_i = \sum_j w_{ij} * s_j: reduce the encoder states to a single context
# vector using the attention weights.
def _calc_weighted_context(
    model,
    encoder_outputs_transposed,
    encoder_output_dim,
    attention_weights_3d,
    scope,
):
    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = model.net.BatchMatMul(
        [encoder_outputs_transposed, attention_weights_3d],
        s(scope, 'attention_weighted_encoder_context'),
    )
    # [1, batch_size, encoder_output_dim]
    attention_weighted_encoder_context, _ = model.net.Reshape(
        attention_weighted_encoder_context,
        [
            attention_weighted_encoder_context,
            s(scope, 'attention_weighted_encoder_context_old_shape'),
        ],
        shape=[1, -1, encoder_output_dim],
    )
    return attention_weighted_encoder_context
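

# NumPy reference for the BatchMatMul above (a sketch, not part of the
# module): the context is a per-batch weighted sum of encoder states,
# c_b = \sum_j w_{bj} * s_{bj}.
#
#   import numpy as np
#   def weighted_context_ref(enc_T, w3d):
#       # enc_T: [B, D, T], w3d: [B, T, 1] -> [B, D, 1]
#       return np.matmul(enc_T, w3d)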


# Normalizes the attention logits into weights over the encoder positions.
def _calc_attention_weights(
    model,
    attention_logits_transposed,
    scope,
):
    # [batch_size, encoder_length]
    attention_weights = model.Softmax(
        attention_logits_transposed,
        s(scope, 'attention_weights'),
    )
    # [batch_size, encoder_length, 1]
    attention_weights_3d = model.net.ExpandDims(
        attention_weights,
        s(scope, 'attention_weights_3d'),
        dims=[2],
    )
    return attention_weights_3d
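

# NumPy reference for the two ops above (a sketch, not part of the module):
# a row-wise softmax over the encoder axis, then a trailing singleton
# dimension so the weights can feed BatchMatMul in _calc_weighted_context.
#
#   import numpy as np
#   def attention_weights_ref(logits_T):
#       # logits_T: [B, T] -> [B, T, 1]
#       e = np.exp(logits_T - logits_T.max(axis=1, keepdims=True))
#       return (e / e.sum(axis=1, keepdims=True))[:, :, np.newaxis]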


# e_{ij} = v^T * tanh(sum_match), where sum_match combines the projected
# decoder hidden state and encoder outputs.
def _calc_attention_logits_from_sum_match(
    model,
    decoder_hidden_encoder_outputs_sum,
    encoder_output_dim,
    scope,
):
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Tanh(
        decoder_hidden_encoder_outputs_sum,
        decoder_hidden_encoder_outputs_sum,
    )
    # Learned projection vector v, applied as the weight of an FC layer.
    attention_v = model.param_init_net.XavierFill(
        [],
        s(scope, 'attention_v'),
        shape=[1, encoder_output_dim],
    )
    model.add_param(attention_v)

    # Zero bias for the FC layer below.
    attention_zeros = model.param_init_net.ConstantFill(
        [],
        s(scope, 'attention_zeros'),
        value=0.0,
        shape=[1],
    )

    # [encoder_length, batch_size, 1]
    attention_logits = model.net.FC(
        [decoder_hidden_encoder_outputs_sum, attention_v, attention_zeros],
        [s(scope, 'attention_logits')],
        axis=2,
    )
    # [encoder_length, batch_size]
    attention_logits = model.net.Squeeze(
        [attention_logits],
        [attention_logits],
        dims=[2],
    )
    # [batch_size, encoder_length]
    attention_logits_transposed = model.Transpose(
        attention_logits,
        s(scope, 'attention_logits_transposed'),
        axes=[1, 0],
    )
    return attention_logits_transposed
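

# NumPy reference for the energy computation above (a sketch, not part of
# the module): e = v . tanh(sum_match) per (time, batch) pair, then a
# transpose to [batch_size, encoder_length].
#
#   import numpy as np
#   def attention_logits_ref(sum_match, v):
#       # sum_match: [T, B, D], v: [D] -> [B, T]
#       return np.tanh(sum_match).dot(v).T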


# Projects `input` with a learned weight matrix (an FC layer) so it can be
# summed with the other energy terms; squeezes away the leading singleton
# time dimension.
def _apply_fc_weight_for_sum_match(
    model,
    input,
    dim_in,
    dim_out,
    scope,
    name,
):
    # [1, batch_size, dim_out]
    output = model.FC(
        input,
        s(scope, name),
        dim_in=dim_in,
        dim_out=dim_out,
        axis=2,
    )
    # [batch_size, dim_out]
    output = model.net.Squeeze(
        output,
        output,
        dims=[0],
    )
    return output


# Recurrent attention: in addition to the decoder hidden state, the previous
# step's attention context is fed back into the energy computation.
def apply_recurrent_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    attention_weighted_encoder_context_t_prev,
    scope,
):
    weighted_prev_attention_context = _apply_fc_weight_for_sum_match(
        model=model,
        input=attention_weighted_encoder_context_t_prev,
        dim_in=encoder_output_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_prev_attention_context',
    )
    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum_tmp = model.net.Add(
        [
            weighted_encoder_outputs,
            weighted_decoder_hidden_state,
        ],
        s(scope, 'decoder_hidden_encoder_outputs_sum_tmp'),
        broadcast=1,
    )
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [
            decoder_hidden_encoder_outputs_sum_tmp,
            weighted_prev_attention_context,
        ],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
    )
    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )
    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
    )
    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, [
        decoder_hidden_encoder_outputs_sum_tmp,
        decoder_hidden_encoder_outputs_sum,
    ]


# Regular (Bahdanau-style) attention: the energy depends only on the current
# decoder hidden state and the encoder outputs.
def apply_regular_attention(
    model,
    encoder_output_dim,
    encoder_outputs_transposed,
    weighted_encoder_outputs,
    decoder_hidden_state_t,
    decoder_hidden_state_dim,
    scope,
):
    weighted_decoder_hidden_state = _apply_fc_weight_for_sum_match(
        model=model,
        input=decoder_hidden_state_t,
        dim_in=decoder_hidden_state_dim,
        dim_out=encoder_output_dim,
        scope=scope,
        name='weighted_decoder_hidden_state',
    )
    # [encoder_length, batch_size, encoder_output_dim]
    decoder_hidden_encoder_outputs_sum = model.net.Add(
        [weighted_encoder_outputs, weighted_decoder_hidden_state],
        s(scope, 'decoder_hidden_encoder_outputs_sum'),
        broadcast=1,
    )
    attention_logits_transposed = _calc_attention_logits_from_sum_match(
        model=model,
        decoder_hidden_encoder_outputs_sum=decoder_hidden_encoder_outputs_sum,
        encoder_output_dim=encoder_output_dim,
        scope=scope,
    )
    # [batch_size, encoder_length, 1]
    attention_weights_3d = _calc_attention_weights(
        model=model,
        attention_logits_transposed=attention_logits_transposed,
        scope=scope,
    )
    # [batch_size, encoder_output_dim, 1]
    attention_weighted_encoder_context = _calc_weighted_context(
        model=model,
        encoder_outputs_transposed=encoder_outputs_transposed,
        encoder_output_dim=encoder_output_dim,
        attention_weights_3d=attention_weights_3d,
        scope=scope,
    )
    return attention_weighted_encoder_context, attention_weights_3d, [
        decoder_hidden_encoder_outputs_sum,
    ]
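

# Usage sketch (not part of the module): wiring apply_regular_attention into
# a model. Everything below is hypothetical: it assumes a Caffe2 model-helper
# object exposing net, param_init_net, FC/Softmax/Transpose helpers, and
# add_param, plus precomputed encoder blobs with the names shown.
#
#   from caffe2.python import cnn
#   model = cnn.CNNModelHelper(name='attention_example')
#   context, weights_3d, _ = apply_regular_attention(
#       model=model,
#       encoder_output_dim=256,
#       encoder_outputs_transposed='encoder_outputs_transposed',
#       weighted_encoder_outputs='weighted_encoder_outputs',
#       decoder_hidden_state_t='decoder_hidden_state_t',
#       decoder_hidden_state_dim=512,
#       scope='decoder',
#   )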