tfa.seq2seq.InferenceSampler

An inference sampler that uses a custom sampling function.

Inherits From: Sampler

Args
sample_fn: A callable that takes outputs and emits a tensor of sample_ids.
sample_shape: Either a list of integers or a 1-D Tensor of type int32; the shape of each sample in the batch returned by sample_fn.
sample_dtype: The dtype of the sample returned by sample_fn.
end_fn: A callable that takes sample_ids and emits a bool vector shaped [batch_size] indicating whether each sample is an end token.
next_inputs_fn: (Optional) A callable that takes sample_ids and returns the next batch of inputs. If not provided, sample_ids is used as the next batch of inputs.
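The three callables are ordinary functions over batched values. The sketch below is a plain-Python stand-in for a greedy decoding setup, using lists in place of tensors; END_ID, EMBEDDINGS, and the function bodies are hypothetical. With TensorFlow, sample_fn would typically be something like tf.argmax(outputs, axis=-1, output_type=tf.int32), paired with sample_shape=[] and sample_dtype=tf.int32 for scalar token ids.

```python
# Plain-Python stand-ins for the three callables; lists play the role of tensors.
END_ID = 2  # hypothetical end-of-sequence token id

# Hypothetical embedding table: one 2-d vector per token id (vocabulary of 4).
EMBEDDINGS = [[0.0, 0.0], [0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]


def sample_fn(outputs):
    """Greedy choice: index of the largest score in each batch row."""
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


def end_fn(sample_ids):
    """One bool per batch entry: True where the sampled id is the end token."""
    return [sid == END_ID for sid in sample_ids]


def next_inputs_fn(sample_ids):
    """Embed the sampled ids so they can be fed to the next decoder step."""
    return [EMBEDDINGS[sid] for sid in sample_ids]


logits = [[0.1, 0.7, 0.2, 0.0],   # batch entry 0 -> id 1
          [0.0, 0.1, 0.9, 0.0]]   # batch entry 1 -> id 2 (the end token)
ids = sample_fn(logits)
print(ids, end_fn(ids), next_inputs_fn(ids))
```

Omitting next_inputs_fn would feed the raw ids themselves back as the next inputs, as described above.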

Attributes
batch_size: Batch size of the tensor returned by sample.

Returns a scalar int32 tensor. The return value might not be available before initialize() is invoked; in that case, a ValueError is raised.

sample_ids_dtype: DType of the tensor returned by sample.

Returns a DType. The return value might not be available before initialize() is invoked.

sample_ids_shape: Shape of the tensor returned by sample, excluding the batch dimension.

Returns a TensorShape. The return value might not be available before initialize() is invoked.
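The availability rules for these attributes can be modeled with Python properties. ToyInferenceSampler below is a hypothetical plain-Python class, not the tfa implementation; it assumes the batch size is read from the leading dimension of the inputs passed to initialize(), while the per-sample shape and dtype are fixed at construction time.

```python
class ToyInferenceSampler:
    """Hypothetical stand-in mirroring the documented attributes (not tfa code)."""

    def __init__(self, sample_shape, sample_dtype):
        self._sample_shape = sample_shape
        self._sample_dtype = sample_dtype
        self._batch_size = None  # unknown until initialize() sees the inputs

    @property
    def batch_size(self):
        # Mirrors the documented behavior: reading it too early is an error.
        if self._batch_size is None:
            raise ValueError("batch_size accessed before initialize() was called")
        return self._batch_size

    @property
    def sample_ids_shape(self):
        # Per-sample shape; the batch dimension is excluded.
        return self._sample_shape

    @property
    def sample_ids_dtype(self):
        return self._sample_dtype

    def initialize(self, start_inputs):
        # Assumption: the leading dimension of the inputs is the batch size.
        self._batch_size = len(start_inputs)


sampler = ToyInferenceSampler(sample_shape=[], sample_dtype="int32")
try:
    sampler.batch_size
except ValueError as err:
    print("before initialize:", err)
sampler.initialize([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
print(sampler.batch_size, sampler.sample_ids_shape, sampler.sample_ids_dtype)
```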

Methods

initialize

Initialize the sampler with the input tensors.

This method must be invoked exactly once before calling other methods of the Sampler.

Args
inputs: A (structure of) input tensors; it could be a nested tuple or a single tensor.
**kwargs: Other kwargs for initialization. These could contain tensors, such as a mask for the inputs, or non-tensor parameters.

Returns
(initial_finished, initial_inputs).
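A minimal sketch of this return contract in plain Python, with lists standing in for tensors. The function name toy_initialize is hypothetical, and the sketch assumes, as is usual for inference, that no sequence is finished before decoding starts and that the start inputs are returned unchanged as the initial inputs.

```python
def toy_initialize(start_inputs):
    """Hypothetical sketch of initialize(): returns (initial_finished, initial_inputs)."""
    batch_size = len(start_inputs)           # leading dimension is the batch
    initial_finished = [False] * batch_size  # nothing has ended yet
    initial_inputs = start_inputs            # fed unchanged to the first step
    return initial_finished, initial_inputs


finished, inputs = toy_initialize([[0.1, 0.2], [0.3, 0.4]])
print(finished, inputs)
```

Because initialize must run exactly once before the other methods, a decoding loop calls it once up front and then typically calls sample and next_inputs at each step.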

next_inputs

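Returns (finished, next_inputs, next_state). Here finished is the output of end_fn applied to sample_ids, and next_inputs is next_inputs_fn(sample_ids), or sample_ids itself when next_inputs_fn is not provided.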
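Following the constructor description (end_fn flags finished entries; next_inputs_fn builds the next inputs, or sample_ids is reused), next_inputs can be sketched as below. toy_next_inputs and is_end are hypothetical plain-Python functions, and passing the decoder state through unchanged as next_state is an assumption of this sketch.

```python
END_ID = 2  # hypothetical end-of-sequence token id


def toy_next_inputs(sample_ids, state, end_fn, next_inputs_fn=None):
    """Hypothetical sketch: returns (finished, next_inputs, next_state)."""
    finished = end_fn(sample_ids)  # one bool per batch entry
    if next_inputs_fn is None:
        next_inputs = sample_ids   # default: feed the sampled ids back in
    else:
        next_inputs = next_inputs_fn(sample_ids)
    next_state = state             # assumption: state passes through unchanged
    return finished, next_inputs, next_state


def is_end(ids):
    return [i == END_ID for i in ids]


def to_float_inputs(ids):
    return [[float(i)] for i in ids]


print(toy_next_inputs([1, 2], state="h", end_fn=is_end))
print(toy_next_inputs([1, 2], state="h", end_fn=is_end,
                      next_inputs_fn=to_float_inputs))
```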

sample

Returns sample_ids, the result of applying sample_fn to the decoder outputs.
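To see how initialize, sample, and next_inputs fit together, the toy loop below performs an initialize step once and then, at each step, a sample step followed by a next_inputs step. Everything here (toy_cell, END_ID, VOCAB, decode) is hypothetical plain Python rather than tfa code; it uses greedy sampling, feeds sampled ids back directly (no next_inputs_fn), and, unlike a real decoder, does not mask entries that have already finished.

```python
END_ID = 0  # hypothetical end-of-sequence id
VOCAB = 3   # hypothetical vocabulary size


def toy_cell(input_ids, state):
    """Hypothetical decoder step: scores id (i + 1) % VOCAB highest for input id i."""
    outputs = []
    for i in input_ids:
        row = [0.0] * VOCAB
        row[(i + 1) % VOCAB] = 1.0
        outputs.append(row)
    return outputs, state + 1


def sample_fn(outputs):
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


def end_fn(sample_ids):
    return [sid == END_ID for sid in sample_ids]


def decode(start_ids, max_steps=10):
    finished = [False] * len(start_ids)  # initialize(): nothing finished yet
    inputs, state = start_ids, 0
    history = []
    for _ in range(max_steps):
        outputs, state = toy_cell(inputs, state)
        sample_ids = sample_fn(outputs)      # sample(): apply sample_fn
        step_finished = end_fn(sample_ids)   # next_inputs(): check for end
        inputs = sample_ids                  # no next_inputs_fn: reuse ids
        history.append(sample_ids)
        # Once an entry has finished, it stays finished.
        finished = [f or s for f, s in zip(finished, step_finished)]
        if all(finished):
            break
    return history


print(decode([1, 2]))
```

Starting from ids [1, 2], the first step samples [2, 0] (entry 1 ends) and the second samples [0, 1] (entry 0 ends), so the loop stops after two steps.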