Model training APIs
- Original Link : https://keras.io/api/models/model_training_apis/
- Last Checked at : 2024-11-24
compile method
Model.compile(
    optimizer="rmsprop",
    loss=None,
    loss_weights=None,
    metrics=None,
    weighted_metrics=None,
    run_eagerly=False,
    steps_per_execution=1,
    jit_compile="auto",
    auto_scale_loss=True,
)
Configures the model for training.
Example
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.BinaryCrossentropy(),
    metrics=[
        keras.metrics.BinaryAccuracy(),
        keras.metrics.FalseNegatives(),
    ],
)
Arguments
- optimizer: String (name of optimizer) or optimizer instance. See keras.optimizers.
- loss: Loss function. May be a string (name of loss function), or a keras.losses.Loss instance. See keras.losses. A loss function is any callable with the signature loss = fn(y_true, y_pred), where y_true are the ground truth values, and y_pred are the model’s predictions. y_true should have shape (batch_size, d0, .. dN) (except in the case of sparse loss functions such as sparse categorical crossentropy, which expects integer arrays of shape (batch_size, d0, .. dN-1)). y_pred should have shape (batch_size, d0, .. dN). The loss function should return a float tensor.
- loss_weights: Optional list or dictionary specifying scalar coefficients (Python floats) to weight the loss contributions of different model outputs. The loss value minimized by the model will then be the sum of all individual losses, weighted by the loss_weights coefficients. If a list, it is expected to have a 1:1 mapping to the model’s outputs. If a dict, it is expected to map output names (strings) to scalar coefficients.
- metrics: List of metrics to be evaluated by the model during training and testing. Each of these can be a string (name of a built-in function), a function, or a keras.metrics.Metric instance. See keras.metrics. Typically you will use metrics=['accuracy']. A function is any callable with the signature result = fn(y_true, y_pred). To specify different metrics for different outputs of a multi-output model, you could also pass a dictionary, such as metrics={'a':'accuracy', 'b':['accuracy', 'mse']}. You can also pass a list to specify a metric or a list of metrics for each output, such as metrics=[['accuracy'], ['accuracy', 'mse']] or metrics=['accuracy', ['accuracy', 'mse']]. When you pass the strings "accuracy" or "acc", Keras converts them to one of keras.metrics.BinaryAccuracy, keras.metrics.CategoricalAccuracy, or keras.metrics.SparseCategoricalAccuracy based on the shapes of the targets and of the model output. A similar conversion is done for the strings "crossentropy" and "ce". The metrics passed here are evaluated without sample weighting; if you would like sample weighting to apply, specify your metrics via the weighted_metrics argument instead.
- weighted_metrics: List of metrics to be evaluated and weighted by sample_weight or class_weight during training and testing.
- run_eagerly: Bool. If True, this model’s forward pass will never be compiled. It is recommended to leave this as False when training (for best performance), and to set it to True when debugging.
- steps_per_execution: Int. The number of batches to run during a single compiled function call. Running multiple batches inside a single compiled function call can greatly improve performance on TPUs or small models with a large Python overhead. At most, one full epoch will be run each execution. If a number larger than the size of the epoch is passed, the execution will be truncated to the size of the epoch. Note that if steps_per_execution is set to N, the Callback.on_batch_begin and Callback.on_batch_end methods will only be called every N batches (i.e. before/after each compiled function execution). Not supported with the PyTorch backend.
- jit_compile: Bool or "auto". Whether to use XLA compilation when compiling a model. For the jax and tensorflow backends, jit_compile="auto" enables XLA compilation if the model supports it, and disables it otherwise. For the torch backend, "auto" will default to eager execution and jit_compile=True will run with torch.compile with the "inductor" backend.
- auto_scale_loss: Bool. If True and the model dtype policy is "mixed_float16", the passed optimizer will be automatically wrapped in a LossScaleOptimizer, which will dynamically scale the loss to prevent underflow.
fit method
Model.fit(
    x=None,
    y=None,
    batch_size=None,
    epochs=1,
    verbose="auto",
    callbacks=None,
    validation_split=0.0,
    validation_data=None,
    shuffle=True,
    class_weight=None,
    sample_weight=None,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None,
    validation_batch_size=None,
    validation_freq=1,
)
Trains the model for a fixed number of epochs (dataset iterations).
Arguments
- x: Input data. It could be:
  - A NumPy array (or array-like), or a list of arrays (in case the model has multiple inputs).
  - A tensor, or a list of tensors (in case the model has multiple inputs).
  - A dict mapping input names to the corresponding array/tensors, if the model has named inputs.
  - A tf.data.Dataset. Should return a tuple of either (inputs, targets) or (inputs, targets, sample_weights).
  - A keras.utils.PyDataset returning (inputs, targets) or (inputs, targets, sample_weights).
- y: Target data. Like the input data x, it could be either NumPy array(s) or backend-native tensor(s). If x is a dataset, generator, or keras.utils.PyDataset instance, y should not be specified (since targets will be obtained from x).
- batch_size: Integer or None. Number of samples per gradient update. If unspecified, batch_size will default to 32. Do not specify the batch_size if your data is in the form of datasets, generators, or keras.utils.PyDataset instances (since they generate batches).
- epochs: Integer. Number of epochs to train the model. An epoch is an iteration over the entire x and y data provided (unless the steps_per_epoch flag is set to something other than None). Note that in conjunction with initial_epoch, epochs is to be understood as “final epoch”. The model is not trained for a number of iterations given by epochs, but merely until the epoch of index epochs is reached.
- verbose: "auto", 0, 1, or 2. Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch. "auto" becomes 1 for most cases. Note that the progress bar is not particularly useful when logged to a file, so verbose=2 is recommended when not running interactively (e.g., in a production environment). Defaults to "auto".
- callbacks: List of keras.callbacks.Callback instances. List of callbacks to apply during training. See keras.callbacks. Note that the keras.callbacks.ProgbarLogger and keras.callbacks.History callbacks are created automatically and need not be passed to model.fit(). keras.callbacks.ProgbarLogger is created or not based on the verbose argument in model.fit().
- validation_split: Float between 0 and 1. Fraction of the training data to be used as validation data. The model will set apart this fraction of the training data, will not train on it, and will evaluate the loss and any model metrics on this data at the end of each epoch. The validation data is selected from the last samples in the x and y data provided, before shuffling. This argument is not supported when x is a dataset, generator, or keras.utils.PyDataset instance. If both validation_data and validation_split are provided, validation_data will override validation_split.
- validation_data: Data on which to evaluate the loss and any model metrics at the end of each epoch. The model will not be trained on this data. Note that the validation loss of data provided using validation_split or validation_data is not affected by regularization layers like noise and dropout. validation_data will override validation_split. It could be:
  - A tuple (x_val, y_val) of NumPy arrays or tensors.
  - A tuple (x_val, y_val, val_sample_weights) of NumPy arrays.
  - A tf.data.Dataset.
  - A Python generator or keras.utils.PyDataset returning (inputs, targets) or (inputs, targets, sample_weights).
- shuffle: Boolean, whether to shuffle the training data before each epoch. This argument is ignored when x is a generator or a tf.data.Dataset.
- class_weight: Optional dictionary mapping class indices (integers) to a weight (float) value, used for weighting the loss function (during training only). This can be useful to tell the model to “pay more attention” to samples from an under-represented class. When class_weight is specified and targets have a rank of 2 or greater, either y must be one-hot encoded, or an explicit final dimension of 1 must be included for sparse class labels.
- sample_weight: Optional NumPy array of weights for the training samples, used for weighting the loss function (during training only). You can either pass a flat (1D) NumPy array with the same length as the input samples (1:1 mapping between weights and samples), or in the case of temporal data, you can pass a 2D array with shape (samples, sequence_length), to apply a different weight to every timestep of every sample. This argument is not supported when x is a dataset, generator, or keras.utils.PyDataset instance; instead, provide the sample weights as the third element of x. Note that sample weighting does not apply to metrics specified via the metrics argument in compile(). To apply sample weighting to your metrics, specify them via the weighted_metrics argument in compile() instead.
- initial_epoch: Integer. Epoch at which to start training (useful for resuming a previous training run).
- steps_per_epoch: Integer or None. Total number of steps (batches of samples) before declaring one epoch finished and starting the next epoch. When training with input tensors such as backend-native tensors, the default None is equal to the number of samples in your dataset divided by the batch size, or 1 if that cannot be determined. If x is a tf.data.Dataset, and steps_per_epoch is None, the epoch will run until the input dataset is exhausted. When passing an infinitely repeating dataset, you must specify the steps_per_epoch argument. If steps_per_epoch=-1, the training will run indefinitely with an infinitely repeating dataset.
- validation_steps: Only relevant if validation_data is provided. Total number of steps (batches of samples) to draw before stopping when performing validation at the end of every epoch. If validation_steps is None, validation will run until the validation_data dataset is exhausted. In the case of an infinitely repeated dataset, it will run into an infinite loop. If validation_steps is specified and only part of the dataset is consumed, the evaluation will start from the beginning of the dataset at each epoch. This ensures that the same validation samples are used every time.
- validation_batch_size: Integer or None. Number of samples per validation batch. If unspecified, will default to batch_size. Do not specify the validation_batch_size if your data is in the form of datasets or keras.utils.PyDataset instances (since they generate batches).
- validation_freq: Only relevant if validation data is provided. Specifies how many training epochs to run before a new validation run is performed, e.g. validation_freq=2 runs validation every 2 epochs.
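The arithmetic behind validation_split, steps_per_epoch, and initial_epoch can be sketched in plain NumPy. This is an illustrative approximation of the behavior described above, not Keras source code: the helper split_for_validation is hypothetical, and rounding the step count up (so that a final partial batch counts as a step) is an assumption.

```python
import math
import numpy as np


def split_for_validation(x, y, validation_split):
    # Hypothetical helper: hold out the trailing fraction of the data,
    # taken before any shuffling happens.
    split_at = int(math.floor(len(x) * (1.0 - validation_split)))
    return (x[:split_at], y[:split_at]), (x[split_at:], y[split_at:])


x = np.arange(8).reshape(8, 1)
y = np.arange(8)
(train_x, train_y), (val_x, val_y) = split_for_validation(x, y, validation_split=0.25)

# Steps per epoch when steps_per_epoch is None and the sample count is known.
# Assumption: a final partial batch still counts as one step.
batch_size = 4
steps_per_epoch = math.ceil(len(train_x) / batch_size)

# With initial_epoch set, `epochs` is the index of the final epoch,
# so only the difference is actually trained.
initial_epoch, epochs = 2, 5
epochs_actually_run = epochs - initial_epoch

print(val_y.tolist(), steps_per_epoch, epochs_actually_run)
```

Because the held-out samples come from the end of the arrays, data that is sorted (for example by class) should be shuffled before it is passed to fit() with validation_split.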
Unpacking behavior for iterator-like inputs: A common pattern is to pass an iterator-like object such as a tf.data.Dataset or a keras.utils.PyDataset to fit(), which will in fact yield not only features (x) but optionally targets (y) and sample weights (sample_weight). Keras requires that the output of such iterator-likes be unambiguous. The iterator should return a tuple of length 1, 2, or 3, where the optional second and third elements will be used for y and sample_weight respectively. Any other type provided will be wrapped in a length-one tuple, effectively treating everything as x. When yielding dicts, they should still adhere to the top-level tuple structure, e.g. ({"x0": x0, "x1": x1}, y). Keras will not attempt to separate features, targets, and weights from the keys of a single dict. A notable unsupported data type is the namedtuple. The reason is that it behaves like both an ordered datatype (tuple) and a mapping datatype (dict). So given a namedtuple of the form namedtuple("example_tuple", ["y", "x"]), it is ambiguous whether to reverse the order of the elements when interpreting the value. Even worse is a tuple of the form namedtuple("other_tuple", ["x", "y", "z"]), where it is unclear if the tuple was intended to be unpacked into x, y, and sample_weight or passed through as a single element to x.
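The unpacking rule for tuples and dicts can be sketched as a small pure-Python function. The helper unpack_batch is hypothetical and only mirrors the rule stated above; it is not the function Keras uses internally, and raising ValueError for longer tuples is an assumption of this sketch.

```python
def unpack_batch(data):
    """Hypothetical helper: return (x, y, sample_weight) for one yielded batch."""
    # Anything that is not a tuple is wrapped in a length-one tuple,
    # so the whole object is treated as x.
    if not isinstance(data, tuple):
        data = (data,)
    if len(data) == 1:
        return data[0], None, None
    if len(data) == 2:
        return data[0], data[1], None
    if len(data) == 3:
        return data[0], data[1], data[2]
    # Assumption: longer tuples are rejected as ambiguous.
    raise ValueError("Expected a tuple of length 1, 2, or 3.")


features = {"x0": [1.0, 2.0], "x1": [3.0, 4.0]}
targets = [0, 1]

# Dict features still sit inside a top-level tuple.
print(unpack_batch((features, targets)))

# A bare dict is never split by key: all of it becomes x.
print(unpack_batch(features))
```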
Returns
A History object. Its History.history attribute is a record of training loss values and metrics values at successive epochs, as well as validation loss values and validation metrics values (if applicable).
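A short sketch of reading the returned History object follows. The model architecture and the random data are assumptions chosen only to keep the example small.

```python
import numpy as np
import keras

model = keras.Sequential([
    keras.Input(shape=(3,)),
    keras.layers.Dense(1),
])
model.compile(optimizer="sgd", loss="mse", metrics=["mae"])

# Random data, only to exercise fit().
x = np.random.rand(40, 3).astype("float32")
y = np.random.rand(40, 1).astype("float32")

history = model.fit(
    x,
    y,
    batch_size=8,
    epochs=5,          # index of the final epoch
    initial_epoch=2,   # resume at epoch index 2
    validation_split=0.25,
    verbose=0,
)

# One entry per epoch that actually ran, for each tracked quantity.
print(history.epoch)
for name, values in history.history.items():
    print(name, len(values))
```

With initial_epoch=2 and epochs=5, three epochs run, so each list in history.history holds three values; the validation entries are prefixed with "val_".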
evaluate method
Model.evaluate(
    x=None,
    y=None,
    batch_size=None,
    verbose="auto",
    sample_weight=None,
    steps=None,
    callbacks=None,
    return_dict=False,
    **kwargs
)
Returns the loss value & metrics values for the model in test mode.
Computation is done in batches (see the batch_size arg).
Arguments
- x: Input data. It could be:
  - A NumPy array (or array-like), or a list of arrays (in case the model has multiple inputs).
  - A tensor, or a list of tensors (in case the model has multiple inputs).
  - A dict mapping input names to the corresponding array/tensors, if the model has named inputs.
  - A tf.data.Dataset. Should return a tuple of either (inputs, targets) or (inputs, targets, sample_weights).
  - A generator or keras.utils.PyDataset returning (inputs, targets) or (inputs, targets, sample_weights).
- y: Target data. Like the input data x, it could be either NumPy array(s) or backend-native tensor(s). If x is a tf.data.Dataset or keras.utils.PyDataset instance, y should not be specified (since targets will be obtained from the iterator/dataset).
- batch_size: Integer or None. Number of samples per batch of computation. If unspecified, batch_size will default to 32. Do not specify the batch_size if your data is in the form of datasets, generators, or keras.utils.PyDataset instances (since they generate batches).
- verbose: "auto", 0, 1, or 2. Verbosity mode. 0 = silent, 1 = progress bar, 2 = single line. "auto" becomes 1 for most cases. Note that the progress bar is not particularly useful when logged to a file, so verbose=2 is recommended when not running interactively (e.g. in a production environment). Defaults to "auto".
- sample_weight: Optional NumPy array of weights for the test samples, used for weighting the loss function. You can either pass a flat (1D) NumPy array with the same length as the input samples (1:1 mapping between weights and samples), or in the case of temporal data, you can pass a 2D array with shape (samples, sequence_length), to apply a different weight to every timestep of every sample. This argument is not supported when x is a dataset; instead, pass sample weights as the third element of x.
- steps: Integer or None. Total number of steps (batches of samples) before declaring the evaluation round finished. Ignored with the default value of None. If x is a tf.data.Dataset and steps is None, evaluation will run until the dataset is exhausted.
- callbacks: List of keras.callbacks.Callback instances. List of callbacks to apply during evaluation.
- return_dict: If True, loss and metric results are returned as a dict, with each key being the name of the metric. If False, they are returned as a list.
Returns
Scalar test loss (if the model has a single output and no metrics) or list of scalars (if the model has multiple outputs and/or metrics). The attribute model.metrics_names will give you the display labels for the scalar outputs.
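The two return forms of evaluate() can be compared side by side. This is a minimal sketch: the untrained model and random data are assumptions, so the printed values carry no meaning beyond their structure.

```python
import numpy as np
import keras

model = keras.Sequential([
    keras.Input(shape=(3,)),
    keras.layers.Dense(1),
])
model.compile(optimizer="sgd", loss="mse", metrics=["mae"])

# Random data, only to exercise evaluate().
x = np.random.rand(24, 3).astype("float32")
y = np.random.rand(24, 1).astype("float32")

# Default: a list of scalars (loss first, then metrics).
as_list = model.evaluate(x, y, batch_size=8, verbose=0)

# return_dict=True: the same values keyed by name.
as_dict = model.evaluate(x, y, batch_size=8, verbose=0, return_dict=True)

print(as_list)
print(as_dict)
```

The dict form is usually easier to log, since the keys travel with the values instead of depending on list position.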
predict method
Model.predict(x, batch_size=None, verbose="auto", steps=None, callbacks=None)
Generates output predictions for the input samples.
Computation is done in batches. This method is designed for batch processing of large numbers of inputs. It is not intended for use inside of loops that iterate over your data and process small numbers of inputs at a time.
For small numbers of inputs that fit in one batch, directly use __call__() for faster execution, e.g., model(x), or model(x, training=False) if you have layers such as BatchNormalization that behave differently during inference.
Note: See the FAQ entry for more details about the difference between the Model methods predict() and __call__().
Arguments
- x: Input samples. It could be:
  - A NumPy array (or array-like), or a list of arrays (in case the model has multiple inputs).
  - A tensor, or a list of tensors (in case the model has multiple inputs).
  - A tf.data.Dataset.
  - A keras.utils.PyDataset instance.
- batch_size: Integer or None. Number of samples per batch. If unspecified, batch_size will default to 32. Do not specify the batch_size if your data is in the form of datasets, generators, or keras.utils.PyDataset instances (since they generate batches).
- verbose: "auto", 0, 1, or 2. Verbosity mode. 0 = silent, 1 = progress bar, 2 = single line. "auto" becomes 1 for most cases. Note that the progress bar is not particularly useful when logged to a file, so verbose=2 is recommended when not running interactively (e.g. in a production environment). Defaults to "auto".
- steps: Total number of steps (batches of samples) before declaring the prediction round finished. Ignored with the default value of None. If x is a tf.data.Dataset and steps is None, predict() will run until the input dataset is exhausted.
- callbacks: List of keras.callbacks.Callback instances. List of callbacks to apply during prediction.
Returns
NumPy array(s) of predictions.
train_on_batch method
Model.train_on_batch(
    x, y=None, sample_weight=None, class_weight=None, return_dict=False
)
Runs a single gradient update on a single batch of data.
Arguments
- x: Input data. Must be array-like.
- y: Target data. Must be array-like.
- sample_weight: Optional array of the same length as x, containing weights to apply to the model’s loss for each sample. In the case of temporal data, you can pass a 2D array with shape (samples, sequence_length), to apply a different weight to every timestep of every sample.
- class_weight: Optional dictionary mapping class indices (integers) to a weight (float) to apply to the model’s loss for the samples from this class during training. This can be useful to tell the model to “pay more attention” to samples from an under-represented class. When class_weight is specified and targets have a rank of 2 or greater, either y must be one-hot encoded, or an explicit final dimension of 1 must be included for sparse class labels.
- return_dict: If True, loss and metric results are returned as a dict, with each key being the name of the metric. If False, they are returned as a list.
Returns
A scalar loss value (when no metrics and return_dict=False), a list of loss and metric values (if there are metrics and return_dict=False), or a dict of metric and loss values (if return_dict=True).
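A hand-written epoch built on train_on_batch() could be sketched as follows. The manual slicing, model, and random data are assumptions for illustration; fit() remains the usual entry point for training.

```python
import numpy as np
import keras

model = keras.Sequential([
    keras.Input(shape=(3,)),
    keras.layers.Dense(1),
])
model.compile(optimizer="sgd", loss="mse", metrics=["mae"])

# Random data, sliced into batches by hand.
x = np.random.rand(32, 3).astype("float32")
y = np.random.rand(32, 1).astype("float32")
batch_size = 8

for start in range(0, len(x), batch_size):
    stop = start + batch_size
    # One gradient update per call; results come back keyed by name.
    logs = model.train_on_batch(x[start:stop], y[start:stop], return_dict=True)

print(sorted(logs))
```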
test_on_batch method
Model.test_on_batch(x, y=None, sample_weight=None, return_dict=False)
Test the model on a single batch of samples.
Arguments
- x: Input data. Must be array-like.
- y: Target data. Must be array-like.
- sample_weight: Optional array of the same length as x, containing weights to apply to the model’s loss for each sample. In the case of temporal data, you can pass a 2D array with shape (samples, sequence_length), to apply a different weight to every timestep of every sample.
- return_dict: If True, loss and metric results are returned as a dict, with each key being the name of the metric. If False, they are returned as a list.
Returns
A scalar loss value (when no metrics and return_dict=False), a list of loss and metric values (if there are metrics and return_dict=False), or a dict of metric and loss values (if return_dict=True).
predict_on_batch method
Model.predict_on_batch(x)
Returns predictions for a single batch of samples.
Arguments
- x: Input data. It must be array-like.
Returns
NumPy array(s) of predictions.
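The single-batch methods test_on_batch() and predict_on_batch() can be used together on one batch, as in this minimal sketch. The untrained model and random batch are assumptions, so only the structure of the results is meaningful.

```python
import numpy as np
import keras

model = keras.Sequential([
    keras.Input(shape=(3,)),
    keras.layers.Dense(1),
])
model.compile(optimizer="sgd", loss="mse", metrics=["mae"])

# A single batch of 8 random samples.
x_batch = np.random.rand(8, 3).astype("float32")
y_batch = np.random.rand(8, 1).astype("float32")

# Loss and metrics on this batch only; no weights are updated.
eval_logs = model.test_on_batch(x_batch, y_batch, return_dict=True)

# Predictions for the same batch, returned as a NumPy array.
batch_preds = model.predict_on_batch(x_batch)

print(sorted(eval_logs), batch_preds.shape)
```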