SwapEMAWeights
- Original Link : https://keras.io/api/callbacks/swap_ema_weights/
- Last Checked at : 2024-11-25
class keras.callbacks.SwapEMAWeights(swap_on_epoch=False)
Swaps model weights and EMA weights before and after evaluation.
This callback replaces the model's weight values with the values of the optimizer's EMA weights (the exponential moving average of past model weight values, implementing "Polyak averaging") before model evaluation, and restores the previous weights after evaluation.
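Here is a minimal sketch of the averaging itself. It assumes the standard EMA form ema <- momentum * ema + (1 - momentum) * w, applied after every training step. The momentum value, the list-of-floats weight representation, the starting EMA value, and the ema_update name are illustrative assumptions, not the Keras API.

```python
def ema_update(ema, weights, momentum=0.99):
    """Return the new EMA after one training step.

    Each entry follows ema <- momentum * ema + (1 - momentum) * w, so older
    weight values decay geometrically (Polyak-style averaging).
    """
    return [momentum * e + (1.0 - momentum) * w for e, w in zip(ema, weights)]


# Usage: track the EMA of a hypothetical, constant weight trajectory.
trajectory = [[1.0], [1.0], [1.0]]  # raw weights after training steps 1..3
ema = [0.0]                         # hypothetical starting value of the EMA
for w in trajectory:
    ema = ema_update(ema, w, momentum=0.5)
print(ema)  # [0.875]
```

With momentum=0.5 the average moves 0.5, 0.75, 0.875 toward the constant weight. A momentum closer to 1 (such as 0.99) averages over a longer history of weights.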
The SwapEMAWeights callback is to be used in conjunction with an optimizer that sets use_ema=True.
Note that the weights are swapped in-place in order to save memory. The behavior is undefined if you modify the EMA weights or model weights in other callbacks.
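The in-place swap can be sketched in pure Python. This assumes the weights are plain lists of floats and that evaluate is any hypothetical callable. The names swap_in_place and evaluate_with_ema are illustrative, not Keras functions. Exchanging the values slot by slot avoids allocating a third full copy of the weights, and a second swap restores the original state.

```python
def swap_in_place(model_weights, ema_weights):
    """Exchange the contents of two equal-length lists, slot by slot.

    No second copy of the full weight list is created; each pair of
    values is exchanged directly.
    """
    for i in range(len(model_weights)):
        model_weights[i], ema_weights[i] = ema_weights[i], model_weights[i]


def evaluate_with_ema(model_weights, ema_weights, evaluate):
    """Run `evaluate` on the EMA values, then put the raw weights back."""
    swap_in_place(model_weights, ema_weights)      # model now holds EMA values
    try:
        return evaluate(model_weights)
    finally:
        swap_in_place(model_weights, ema_weights)  # swap back: raw weights restored


# Usage: sum() stands in for a real evaluation function.
model, ema = [1.0, 2.0], [0.5, 1.5]
score = evaluate_with_ema(model, ema, sum)
print(score, model, ema)  # 2.0 [1.0, 2.0] [0.5, 1.5]
```

While the weights are swapped, the ema_weights list temporarily holds the raw training weights. Any code that wrote to it during evaluation would corrupt the values restored afterwards. This is one way to see why modifying either set of weights from another callback is undefined.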
Example
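```python
from keras.callbacks import ModelCheckpoint, SwapEMAWeights
from keras.optimizers import SGD

optimizer = SGD(use_ema=True)  # the optimizer keeps EMA copies of the weights
model.compile(optimizer=optimizer, loss=..., metrics=...)

# Evaluation runs on the EMA weights; training keeps updating the raw weights.
model.fit(X_train, Y_train, callbacks=[SwapEMAWeights()])

# Also swap at epoch boundaries so that ModelCheckpoint saves the EMA weights.
model.fit(
    X_train,
    Y_train,
    callbacks=[SwapEMAWeights(swap_on_epoch=True), ModelCheckpoint(...)],
)
```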
Arguments
- swap_on_epoch: whether to perform swapping at on_epoch_begin() and on_epoch_end(). This is useful if you want other callbacks, such as ModelCheckpoint, to use the EMA weights. Defaults to False.
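The effect of swap_on_epoch can be modelled with a small state machine. This is a simplified, hypothetical stand-in: EpochSwapSketch is not a Keras class, and the swaps the real callback performs around evaluation are omitted. The EMA values are swapped into the model at the end of each epoch and swapped back out at the start of the next, and a flag prevents a double swap. Callbacks that run after it in the same on_epoch_end pass, such as a checkpoint writer, then see the EMA values. This is presumably why the example lists SwapEMAWeights before ModelCheckpoint.

```python
class EpochSwapSketch:
    """Hypothetical stand-in for the epoch hooks of an EMA-swapping callback."""

    def __init__(self, model_weights, ema_weights, swap_on_epoch=False):
        self.model_weights = model_weights  # values the "model" currently uses
        self.ema_weights = ema_weights      # values held in the EMA slots
        self.swap_on_epoch = swap_on_epoch
        self.ema_in_model = False           # True while the model holds EMA values

    def _swap(self):
        # Exchange contents slot by slot (in place) and flip the flag.
        for i in range(len(self.model_weights)):
            self.model_weights[i], self.ema_weights[i] = (
                self.ema_weights[i],
                self.model_weights[i],
            )
        self.ema_in_model = not self.ema_in_model

    def on_epoch_begin(self, epoch):
        # Restore the raw weights before training on this epoch.
        if self.swap_on_epoch and self.ema_in_model:
            self._swap()

    def on_epoch_end(self, epoch):
        # Leave the EMA values in the model so later callbacks see them.
        if self.swap_on_epoch and not self.ema_in_model:
            self._swap()


# Usage: a hypothetical "checkpoint" copies the model weights after epoch end.
model, ema = [1.0, 2.0], [0.5, 1.5]
cb = EpochSwapSketch(model, ema, swap_on_epoch=True)
cb.on_epoch_end(0)
checkpoint = list(model)  # captures the EMA values
cb.on_epoch_begin(1)      # raw weights are back for the next epoch
print(checkpoint, model)  # [0.5, 1.5] [1.0, 2.0]
```

With swap_on_epoch=False, both hooks do nothing, so only evaluation sees the EMA weights.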