tfm.nlp.tasks.QuestionAnsweringTask

Task object for question answering.

Inherits From: Task

Args
params: The task configuration instance, which can be any of dataclass, ConfigDict, namedtuple, etc.
logging_dir: A string pointing to where the model, summaries, etc. will be saved. Additional files can also be written to this directory.
name: The task name.

Attributes
logging_dir
task_config

Methods

aggregate_logs

Optional aggregation over logs returned from a validation step.

Given step_logs from a validation step, this function aggregates the logs after each eval_step() (see the eval_reduce() function in official/core/base_trainer.py). It runs on CPU and can be used to aggregate metrics during validation when there are too many metrics to fit into TPU memory. Note that this may increase latency due to data transfer between TPU and CPU. Also, the step output from a validation step may be a tuple of per-replica elements, in which case the elements must be concatenated.

Args
state: The current state of training; for example, a sequence of metrics.
step_logs: Logs from a validation step. Can be a dictionary.
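As a rough illustration of the aggregation described above, the following pure-Python sketch accumulates per-step outputs into a running state on the host and concatenates per-replica tuples. It is not the Model Garden implementation: the dictionary layout of step_logs, the assumption that state starts as None, and the key names in the usage lines are all hypothetical.

```python
from typing import Any, Dict, Optional


def aggregate_logs(state: Optional[Dict[str, list]],
                   step_logs: Dict[str, Any]) -> Dict[str, list]:
    """Hypothetical sketch: append one validation step's outputs to the state."""
    if state is None:
        state = {}
    for key, value in step_logs.items():
        if isinstance(value, tuple):
            # One element per replica: concatenate them into a single list.
            values = [item for replica_value in value for item in replica_value]
        else:
            values = list(value)
        state.setdefault(key, []).extend(values)
    return state


state = aggregate_logs(None, {"start_preds": ([1, 2], [3]), "ids": [10, 11, 12]})
state = aggregate_logs(state, {"start_preds": ([4], [5]), "ids": [13, 14]})
print(state)  # {'start_preds': [1, 2, 3, 4, 5], 'ids': [10, 11, 12, 13, 14]}
```

Keeping the state as plain host-side lists is what lets this step hold more values than would fit in accelerator memory, at the cost of the device-to-host transfer noted above.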

build_inputs

Returns a tf.data.Dataset for the question answering task.

build_losses

Standard interface to compute losses.

Args
labels: Optional label tensors.
model_outputs: A nested structure of output tensors.
aux_losses: Auxiliary loss tensors, i.e. losses in keras.Model.

Returns
The total loss tensor.
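The following pure-Python sketch shows one plausible shape for a question answering loss, assuming (hypothetically) that model_outputs is a (start_logits, end_logits) pair and that labels holds start_positions and end_positions. It averages the start and end sparse cross-entropies and adds any auxiliary losses to form the total; the real task operates on tensors, and its exact loss may differ.

```python
import math
from typing import Dict, List, Optional, Sequence


def sparse_ce_from_logits(label: int, logits: Sequence[float]) -> float:
    """Cross-entropy of one integer label against unnormalized logits."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


def build_losses(labels: Dict[str, List[int]],
                 model_outputs: Sequence[List[List[float]]],
                 aux_losses: Optional[List[float]] = None) -> float:
    """Hypothetical span loss: mean of start and end cross-entropy, plus aux losses."""
    start_logits, end_logits = model_outputs
    start = [sparse_ce_from_logits(y, row)
             for y, row in zip(labels["start_positions"], start_logits)]
    end = [sparse_ce_from_logits(y, row)
           for y, row in zip(labels["end_positions"], end_logits)]
    loss = (sum(start) / len(start) + sum(end) / len(end)) / 2
    if aux_losses:
        loss += sum(aux_losses)  # e.g. regularization losses tracked by the model
    return loss


labels = {"start_positions": [0], "end_positions": [2]}
outputs = ([[0.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, 0.0]])
print(round(build_losses(labels, outputs), 4))                    # 1.3863
print(round(build_losses(labels, outputs, aux_losses=[0.5]), 4))  # 1.8863
```

With uniform logits over four positions, each cross-entropy equals log(4), which is why the first value is about 1.3863.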

build_metrics

Gets metrics for training/validation.

build_model

[Optional] Creates model architecture.

Returns
A model instance.

create_optimizer

Creates a TF optimizer from configurations.

Args
optimizer_config: The parameters of the optimization settings.
runtime_config: The parameters of the runtime.
dp_config: The parameters for differential privacy.

Returns
A tf.optimizers.Optimizer object.

inference_step

Performs the forward step.

With distribution strategies, this method runs on devices.

Args
inputs: A dictionary of input tensors.
model: The keras.Model.

Returns
Model outputs.

initialize

[Optional] A callback function used as CheckpointManager's init_fn.

This function is called when no checkpoint is found for the model. If a checkpoint exists, it is loaded and this function is not called. You can use this callback to load a pretrained checkpoint saved under a directory other than model_dir.

Args
model: The keras.Model built or used by this task.
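The init_fn contract can be sketched with a toy stand-in for a checkpoint manager. TensorFlow's tf.train.CheckpointManager does accept an init_fn argument, but ToyCheckpointManager, PRETRAINED, and the dict-based "model" below are hypothetical and only mimic the control flow: an existing checkpoint wins, and initialize runs only when none is found.

```python
from typing import Callable, Dict, Optional

# Hypothetical pretrained weights stored outside model_dir.
PRETRAINED = {"encoder/kernel": 0.42}


def initialize(model: Dict[str, float]) -> None:
    """Hypothetical init_fn body: copy pretrained weights into the model."""
    model.update(PRETRAINED)


class ToyCheckpointManager:
    """Toy stand-in that mimics restore-or-initialize semantics."""

    def __init__(self, saved: Dict[str, Dict[str, float]],
                 init_fn: Optional[Callable[[], None]] = None) -> None:
        self.saved = saved  # checkpoint name -> saved weights
        self.init_fn = init_fn

    def restore_or_initialize(self, model: Dict[str, float]) -> Optional[str]:
        if self.saved:
            # Dicts keep insertion order, so the last key is the newest save.
            latest = list(self.saved)[-1]
            model.update(self.saved[latest])  # checkpoint found: init_fn is skipped
            return latest
        if self.init_fn is not None:
            self.init_fn()  # no checkpoint: fall back to the init callback
        return None


model = {}
manager = ToyCheckpointManager(saved={}, init_fn=lambda: initialize(model))
print(manager.restore_or_initialize(model), model)  # None {'encoder/kernel': 0.42}
```

Because the callback only runs on a cold start, resuming a job from model_dir never overwrites trained weights with the pretrained ones.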

process_compiled_metrics

Processes and updates compiled_metrics.

Called when using the compile/fit API.

Args
compiled_metrics: The compiled metrics (model.compiled_metrics).
labels: A tensor or a nested structure of tensors.
model_outputs: A tensor or a nested structure of tensors; for example, the output of the Keras model built by self.build_model.

process_metrics

Processes and updates metrics.

Called when using a custom training loop.

Args
metrics: A nested structure of metrics objects, as returned by self.build_metrics.
labels: A tensor or a nested structure of tensors.
model_outputs: A tensor or a nested structure of tensors; for example, the output of the Keras model built by self.build_model.
**kwargs: Other arguments.
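A minimal sketch of how process_metrics might update metric objects in a custom training loop. StreamingAccuracy is a hypothetical pure-Python stand-in that borrows the update_state/result/reset_state method names of Keras metrics; the metric names and the (start_logits, end_logits) output layout are assumptions, not this task's documented behavior.

```python
from typing import Dict, List, Sequence, Tuple


class StreamingAccuracy:
    """Hypothetical metric with Keras-style update_state/result/reset_state."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.reset_state()

    def reset_state(self) -> None:
        self.correct = 0
        self.total = 0

    def update_state(self, y_true: Sequence[int],
                     logits: Sequence[Sequence[float]]) -> None:
        for label, row in zip(y_true, logits):
            pred = max(range(len(row)), key=row.__getitem__)  # argmax over positions
            self.correct += int(pred == label)
            self.total += 1

    def result(self) -> float:
        return self.correct / self.total if self.total else 0.0


def build_metrics() -> List[StreamingAccuracy]:
    # Hypothetical metric names chosen for this sketch.
    return [StreamingAccuracy("start_position_accuracy"),
            StreamingAccuracy("end_position_accuracy")]


def process_metrics(metrics: List[StreamingAccuracy],
                    labels: Dict[str, List[int]],
                    model_outputs: Tuple[List[List[float]], List[List[float]]],
                    **kwargs) -> None:
    start_logits, end_logits = model_outputs
    by_name = {m.name: m for m in metrics}
    by_name["start_position_accuracy"].update_state(labels["start_positions"], start_logits)
    by_name["end_position_accuracy"].update_state(labels["end_positions"], end_logits)


metrics = build_metrics()
labels = {"start_positions": [1, 0], "end_positions": [2, 2]}
outputs = ([[0.1, 0.9, 0.0], [0.8, 0.1, 0.1]],
           [[0.2, 0.3, 0.5], [0.6, 0.3, 0.1]])
process_metrics(metrics, labels, outputs)
print({m.name: m.result() for m in metrics})
# {'start_position_accuracy': 1.0, 'end_position_accuracy': 0.5}
```

The metrics keep running totals across calls, so result() reflects every batch seen since the last reset_state().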

reduce_aggregated_logs

Optional reduction of aggregated logs over validation steps.

This function reduces aggregated logs at the end of validation and can be used to compute the final metrics. It runs on CPU, in each eval_end() call of the base trainer (see the eval_end() function in official/core/base_trainer.py).

Args
aggregated_logs: Aggregated logs over multiple validation steps.
global_step: An optional global step variable.

Returns
A dictionary of reduced results.
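As an illustration of reducing aggregated logs into final metrics, the sketch below computes span-level exact match over the whole evaluation set. The pred_spans/gold_spans keys and the choice of metric are assumptions made for this example; global_step is accepted to match the signature but is unused here.

```python
from typing import Dict, List, Optional, Tuple

Span = Tuple[int, int]  # (start_position, end_position)


def reduce_aggregated_logs(aggregated_logs: Dict[str, List[Span]],
                           global_step: Optional[int] = None) -> Dict[str, float]:
    """Hypothetical reduction: span-level exact match over all validation steps."""
    preds = aggregated_logs["pred_spans"]
    golds = aggregated_logs["gold_spans"]
    matches = sum(int(p == g) for p, g in zip(preds, golds))
    # global_step is unused in this sketch.
    return {"exact_match": matches / len(golds) if golds else 0.0,
            "num_examples": float(len(golds))}


logs = {"pred_spans": [(3, 5), (0, 0), (7, 9)],
        "gold_spans": [(3, 5), (1, 2), (7, 9)]}
print(reduce_aggregated_logs(logs, global_step=1000))
# {'exact_match': 0.6666666666666666, 'num_examples': 3.0}
```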

set_preprocessed_eval_input_path

Sets the path to the preprocessed eval data.

train_step

Performs the forward and backward passes.

With distribution strategies, this method runs on devices.

Args
inputs: A dictionary of input tensors.
model: The model, i.e. the forward pass definition.
optimizer: The optimizer for this training step.
metrics: A nested structure of metrics objects.

Returns
A dictionary of logs.
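The forward/backward/update sequence of a training step is sketched below with a toy one-feature linear model and hand-derived mean-squared-error gradients standing in for a Keras model and tf.GradientTape. ToyLinearModel, ToySGD, and the features/labels input keys are hypothetical; only the overall flow (forward pass, loss, gradients, optimizer update, metric update, logs dictionary) mirrors the description above. The log key "loss" matches the loss class variable.

```python
from typing import Dict, List, Optional, Sequence, Tuple


class ToyLinearModel:
    """Hypothetical one-feature model: y = w * x + b."""

    def __init__(self) -> None:
        self.w = 0.0
        self.b = 0.0

    def __call__(self, xs: Sequence[float]) -> List[float]:
        return [self.w * x + self.b for x in xs]


class ToySGD:
    """Hypothetical plain gradient-descent optimizer."""

    def __init__(self, learning_rate: float) -> None:
        self.learning_rate = learning_rate

    def apply_gradients(self, grads_and_vars: List[Tuple[float, str]],
                        model: ToyLinearModel) -> None:
        for grad, name in grads_and_vars:
            setattr(model, name, getattr(model, name) - self.learning_rate * grad)


def train_step(inputs: Dict[str, List[float]], model: ToyLinearModel,
               optimizer: ToySGD, metrics: Optional[list] = None) -> Dict[str, float]:
    xs, ys = inputs["features"], inputs["labels"]
    n = len(xs)
    preds = model(xs)                                  # forward pass
    errors = [p - y for p, y in zip(preds, ys)]
    loss = sum(e * e for e in errors) / n              # mean squared error
    # Backward pass: analytic gradients of the MSE with respect to w and b.
    grad_w = sum(2 * e * x for e, x in zip(errors, xs)) / n
    grad_b = sum(2 * e for e in errors) / n
    optimizer.apply_gradients([(grad_w, "w"), (grad_b, "b")], model)
    for metric in metrics or []:
        metric.update_state(ys, preds)  # assumes a Keras-style metric interface
    return {"loss": loss}


model, optimizer = ToyLinearModel(), ToySGD(learning_rate=0.1)
batch = {"features": [1.0, 2.0], "labels": [2.0, 4.0]}
print(train_step(batch, model, optimizer)["loss"])            # 10.0
print(round(train_step(batch, model, optimizer)["loss"], 2))  # 1.06
```

The reported loss is computed before the update, so each call logs the loss of the weights it started with.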

validation_step

Validation step.

With distribution strategies, this method runs on devices.

Args
inputs: A dictionary of input tensors.
model: The keras.Model.
metrics: A nested structure of metrics objects.

Returns
A dictionary of logs.

Class Variables
loss: 'loss'