CachedMultiHeadAttention layer
- Original Link : https://keras.io/api/keras_nlp/modeling_layers/cached_multi_head_attention/
- Last Checked at : 2024-11-26
CachedMultiHeadAttention
class
keras_nlp.layers.CachedMultiHeadAttention(
num_heads,
key_dim,
value_dim=None,
dropout=0.0,
use_bias=True,
output_shape=None,
attention_axes=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
seed=None,
**kwargs
)
MultiHeadAttention layer with cache support.
This layer is suitable for use in autoregressive decoding. It can be used to cache decoder self-attention and cross-attention. The forward pass can happen in one of three modes:
- No cache, same as regular multi-head attention.
- Static cache (
cache_update_index
is None). In this case, the cached key/value projections will be used and the input values will be ignored. - Updated cache (
cache_update_index
is not None). In this case, new key/value projections are computed using the input, and spliced into the cache at the specified index.
Note that caching is useful only during inference and should not be used during training.
We use the notation B
, T
, S
below, where B
is the batch dimension,
T
is the target sequence length, and S
in the source sequence length.
Note that during generative decoding, T
is usually 1 (you are
generating a target sequence of length one to predict the next token).
Call arguments
- query: Query
Tensor
of shape(B, T, dim)
. - value: Value
Tensor
of shape(B, S*, dim)
. ifcache
is None,
S*must equal
Sand match the shape of
attention_mask. If cache
is notNone
,S*
can be any length less thanS
, and the computed value will be spliced intocache
atcache_update_index
. - key: Optional key
Tensor
of shape(B, S*, dim)
. Ifcache
isNone
,S*
must equalS
and match the shape ofattention_mask
. Ifcache
is notNone
,S*
can be any length less thanS
, and the computed value will be spliced intocache
atcache_update_index
. - attention_mask: a boolean mask of shape
(B, T, S)
.attention_mask
prevents attention to certain positions. The boolean mask specifies which query elements can attend to which key elements, 1 indicates attention and 0 indicates no attention. Broadcasting can happen for the missing batch dimensions and the head dimension. - cache: a dense float Tensor. The key/value cache, of shape
[B, 2, S, num_heads, key_dims]
, whereS
must agree with theattention_mask
shape. This argument is intended for use during generation to avoid recomputing intermediate state. - cache_update_index: a int or int Tensor, the index at which to update
cache
(usually the index of the current token being processed when running generation). Ifcache_update_index=None
whilecache
is set, the cache will not be updated. - training: a boolean indicating whether the layer should behave in training mode or in inference mode.
Returns
An (attention_output, cache)
tuple. attention_output
is the result
of the computation, of shape (B, T, dim)
, where T
is for target
sequence shapes and dim
is the query input last dimension if
output_shape
is None
. Otherwise, the multi-head outputs are
projected to the shape specified by output_shape
. cache
is the
updated cache.