tfa.seq2seq.gather_tree

Calculates the full beams from the per-step ids and parent beam ids.
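The backtracking idea can be illustrated with a minimal pure-Python sketch. It assumes tensors are nested lists indexed as [time][batch][beam] and at least one time step. It ignores end_token and max_sequence_lengths, so it is an illustration of following parent pointers, not the TensorFlow Addons implementation. The name `backtrack_beams` is hypothetical.

```python
def backtrack_beams(step_ids, parent_ids):
    """Rebuild full beams by following parent pointers from the last step back.

    Hypothetical helper: inputs are nested lists indexed [time][batch][beam].
    """
    max_time = len(step_ids)
    batch_size = len(step_ids[0])
    beam_width = len(step_ids[0][0])
    # Output uses the same [max_time, batch_size, beam_width] layout as step_ids.
    beams = [[[0] * beam_width for _ in range(batch_size)] for _ in range(max_time)]
    for b in range(batch_size):
        for j in range(beam_width):
            # Final beam j ends with its own token at the last step.
            beams[max_time - 1][b][j] = step_ids[max_time - 1][b][j]
            parent = parent_ids[max_time - 1][b][j]
            # Walk backwards: each earlier token comes from the parent beam.
            for t in range(max_time - 2, -1, -1):
                beams[t][b][j] = step_ids[t][b][parent]
                parent = parent_ids[t][b][parent]
    return beams
```

For example, with `step_ids = [[[1, 2]], [[3, 4]], [[5, 6]]]` and `parent_ids = [[[0, 0]], [[1, 0]], [[1, 0]]]`, beam 0 traces back to the path [1, 4, 5] and beam 1 to [2, 3, 6], so the result is `[[[1, 2]], [[4, 3]], [[5, 6]]]`.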

For a given beam, all values past the time step containing the first decoded end_token are filled in with end_token.
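The end-token rule for a single beam can be sketched in plain Python as follows. The helper name `fill_after_end_token` is hypothetical, and it operates on one already-reconstructed token sequence rather than on a Tensor.

```python
def fill_after_end_token(tokens, end_token):
    """Keep tokens up to and including the first end_token; replace the rest.

    Hypothetical helper operating on one beam's token list.
    """
    out = []
    finished = False
    for tok in tokens:
        # Once the first end_token has appeared, every later slot becomes end_token.
        out.append(end_token if finished else tok)
        if tok == end_token:
            finished = True
    return out
```

With `end_token = 2`, the sequence `[7, 2, 9, 2, 5]` becomes `[7, 2, 2, 2, 2]`, while a sequence with no end_token is returned unchanged.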

Args:
  step_ids: The predicted token IDs. An int32 Tensor of shape [max_time, batch_size, beam_width].
  parent_ids: The parent beam indices. An int32 Tensor of shape [max_time, batch_size, beam_width].
  max_sequence_lengths: The maximum sequence length of each batch entry. An int32 Tensor of shape [batch_size].
  end_token: The end token ID.

Returns:
  The token IDs reordered according to parent_ids, with the same shape as step_ids.

Raises:
  InvalidArgumentError: if parent_ids contains an invalid index.
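The arguments, return value, and error described here can be tied together in a pure-Python reference sketch. It is an assumed model of the semantics, not the TensorFlow Addons kernel: tensors are nested lists indexed [time][batch][beam], slots at or beyond max_sequence_lengths[b] are assumed to stay end_token, any parent index outside [0, beam_width) is treated as invalid, and ValueError stands in for InvalidArgumentError. The name `gather_tree_reference` is hypothetical.

```python
def gather_tree_reference(step_ids, parent_ids, max_sequence_lengths, end_token):
    """Hypothetical pure-Python model of gather_tree on nested lists [time][batch][beam]."""
    max_time = len(step_ids)
    batch_size = len(max_sequence_lengths)
    beam_width = len(step_ids[0][0]) if max_time else 0
    # Start with every slot set to end_token; slots past a sequence's length stay that way.
    beams = [[[end_token] * beam_width for _ in range(batch_size)]
             for _ in range(max_time)]
    for b in range(batch_size):
        seq_len = min(max_time, max_sequence_lengths[b])
        if seq_len <= 0:
            continue
        for j in range(beam_width):
            last = seq_len - 1
            beams[last][b][j] = step_ids[last][b][j]
            parent = parent_ids[last][b][j]
            # Follow parent pointers backwards through time.
            for t in range(last - 1, -1, -1):
                if not 0 <= parent < beam_width:
                    # Stand-in for InvalidArgumentError.
                    raise ValueError(
                        f"invalid parent id {parent} at (batch={b}, time={t}, beam={j})")
                beams[t][b][j] = step_ids[t][b][parent]
                parent = parent_ids[t][b][parent]
            # After the first end_token in this beam, overwrite the rest with end_token.
            finished = False
            for t in range(seq_len):
                if finished:
                    beams[t][b][j] = end_token
                elif beams[t][b][j] == end_token:
                    finished = True
    return beams
```

As a usage example with `end_token = 9` and `max_sequence_lengths = [3, 2]`, a first-batch beam whose reconstructed path is [1, 9, 3] comes out as [1, 9, 9], and every beam in the second batch entry has 9 at its third time step because that entry's length is 2. Passing a parent index such as 5 when beam_width is 2 raises ValueError.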