Module for constructing seq2seq models and dynamic decoding. Builds on top of libraries in tf.contrib.rnn.
This library is composed of two primary components:

- New attention wrappers for tf.contrib.rnn.RNNCell objects.
- A new object-oriented dynamic decoding framework.

Attention wrappers are RNNCell objects that wrap other RNNCell objects and implement attention. The form of attention is determined by a subclass of tf.contrib.seq2seq.AttentionMechanism. These subclasses describe the form of attention (e.g. additive vs. multiplicative) to use when creating the wrapper. An instance of an AttentionMechanism is constructed with a memory tensor, from which lookup keys and values tensors are created.
The two basic attention mechanisms are:

- tf.contrib.seq2seq.BahdanauAttention (additive attention, ref.)
- tf.contrib.seq2seq.LuongAttention (multiplicative attention, ref.)
The memory tensor passed to the attention mechanism's constructor is expected to be shaped [batch_size, memory_max_time, memory_depth]; and often an additional memory_sequence_length vector is accepted. If provided, the memory tensors' rows are masked with zeros past their true sequence lengths.
Attention mechanisms also have a concept of depth, usually determined as a construction parameter num_units. For some kinds of attention (like BahdanauAttention), both queries and memory are projected to tensors of depth num_units. For other kinds (like LuongAttention), num_units should match the depth of the queries; and the memory tensor will be projected to this depth.
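For illustration, here is a minimal sketch of constructing both mechanisms; the placeholder shapes and the encoder_outputs / source_lengths names are assumptions for this example, not part of the guide:

import tensorflow as tf

# Hypothetical encoder memory: [batch_size, memory_max_time, memory_depth].
encoder_outputs = tf.placeholder(tf.float32, [None, None, 256])
# True sequence lengths; memory rows past these lengths are masked with zeros.
source_lengths = tf.placeholder(tf.int32, [None])

# Bahdanau: queries and memory are both projected to depth num_units.
bahdanau = tf.contrib.seq2seq.BahdanauAttention(
    num_units=128,
    memory=encoder_outputs,
    memory_sequence_length=source_lengths)

# Luong: num_units must match the query (cell output) depth; the memory
# tensor is projected to this depth.
luong = tf.contrib.seq2seq.LuongAttention(
    num_units=512,
    memory=encoder_outputs,
    memory_sequence_length=source_lengths)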
The basic attention wrapper is tf.contrib.seq2seq.AttentionWrapper. This wrapper accepts an RNNCell instance, an instance of AttentionMechanism, and an attention depth parameter (attention_size); as well as several optional arguments that allow one to customize intermediate calculations.
At each time step, the basic calculation performed by this wrapper is:
cell_inputs = concat([inputs, prev_state.attention], -1)
cell_output, next_cell_state = cell(cell_inputs, prev_state.cell_state)
score = attention_mechanism(cell_output)
alignments = softmax(score)
context = matmul(alignments, attention_mechanism.values)
attention = tf.layers.Dense(attention_size)(concat([cell_output, context], 1))
next_state = AttentionWrapperState(
    cell_state=next_cell_state,
    attention=attention)
output = attention
return output, next_state
In practice, a number of the intermediate calculations are configurable. For example, the initial concatenation of inputs and prev_state.attention can be replaced with another mixing function. The function softmax can be replaced with alternative options when calculating alignments from the score. Finally, the outputs returned by the wrapper can be configured to be the value cell_output instead of attention.
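As a hedged sketch of these customization points (argument names vary across contrib releases: cell_input_fn and output_attention are the TF 1.2+ names on AttentionWrapper, attention_size appears as attention_layer_size there, and the alignment normalizer probability_fn is set on the mechanism itself; cell and encoder_outputs are assumed from the sketches above):

# probability_fn replaces softmax when computing alignments from the score.
attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
    num_units=128,
    memory=encoder_outputs,
    probability_fn=tf.nn.softmax)  # swap in an alternative normalizer here

attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell,
    attention_mechanism,
    attention_layer_size=256,  # the attention depth (attention_size above)
    # Replace the default mixing function concat([inputs, attention], -1):
    cell_input_fn=lambda inputs, attention: tf.concat([inputs, attention], -1),
    # Return cell_output instead of attention as the wrapper's output:
    output_attention=False)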
The benefit of using an AttentionWrapper is that it plays nicely with other wrappers and the dynamic decoder described below. For example, one can write:
cell = tf.contrib.rnn.DeviceWrapper(tf.contrib.rnn.LSTMCell(512), "/device:GPU:0")
attention_mechanism = tf.contrib.seq2seq.LuongAttention(512, encoder_outputs)
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell, attention_mechanism, attention_size=256)
attn_cell = tf.contrib.rnn.DeviceWrapper(attn_cell, "/device:GPU:1")
top_cell = tf.contrib.rnn.DeviceWrapper(tf.contrib.rnn.LSTMCell(512), "/device:GPU:1")
multi_cell = tf.contrib.rnn.MultiRNNCell([attn_cell, top_cell])
The multi_cell will perform the bottom layer calculations on GPU 0; attention calculations will be performed on GPU 1 and immediately passed up to the top layer, which is also calculated on GPU 1. The attention is also passed forward in time to the next time step and copied to GPU 0 for the next time step of cell. (Note: This is just an example of use, not a suggested device partitioning strategy.)
The second major component of the library is the dynamic decoding framework. Example usage:
cell = # instance of RNNCell

if mode == "train":
  helper = tf.contrib.seq2seq.TrainingHelper(
      inputs=input_vectors,
      sequence_length=input_lengths)
elif mode == "infer":
  helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
      embedding=embedding,
      start_tokens=tf.tile([GO_SYMBOL], [batch_size]),
      end_token=END_SYMBOL)

decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=cell,
    helper=helper,
    initial_state=cell.zero_state(batch_size, tf.float32))
outputs, _ = tf.contrib.seq2seq.dynamic_decode(
    decoder=decoder,
    output_time_major=False,
    impute_finished=True,
    maximum_iterations=20)
The available decoder helpers are:

- tf.contrib.seq2seq.Helper
- tf.contrib.seq2seq.CustomHelper
- tf.contrib.seq2seq.GreedyEmbeddingHelper
- tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper
- tf.contrib.seq2seq.ScheduledOutputTrainingHelper
- tf.contrib.seq2seq.TrainingHelper
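As one illustration, here is a minimal sketch of scheduled sampling with ScheduledEmbeddingTrainingHelper, reusing the hypothetical input_vectors, input_lengths, and embedding names from the example above:

helper = tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper(
    inputs=input_vectors,
    sequence_length=input_lengths,
    embedding=embedding,
    sampling_probability=0.25)  # chance of feeding back sampled ids
                                # instead of the ground-truth inputs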