CachedMultiHeadAttention layer
- Source link: https://keras.io/api/keras_nlp/modeling_layers/cached_multi_head_attention/
- Last checked: 2024-11-26
CachedMultiHeadAttention class
keras_nlp.layers.CachedMultiHeadAttention(
    num_heads,
    key_dim,
    value_dim=None,
    dropout=0.0,
    use_bias=True,
    output_shape=None,
    attention_axes=None,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None,
    kernel_constraint=None,
    bias_constraint=None,
    seed=None,
    **kwargs
)

MultiHeadAttention layer with cache support.
This layer is suitable for use in autoregressive decoding. It can be used to cache decoder self-attention and cross-attention. The forward pass can happen in one of three modes:
- No cache, same as regular multi-head attention.
- Static cache (cache_update_index is None). In this case, the cached key/value projections will be used and the input values will be ignored.
- Updated cache (cache_update_index is not None). In this case, new key/value projections are computed using the input and spliced into the cache at the specified index.
Note that caching is useful only during inference and should not be used during training.
We use the notation B, T, S below, where B is the batch dimension,
T is the target sequence length, and S is the source sequence length.
Note that during generative decoding, T is usually 1 (you are
generating a target sequence of length one to predict the next token).
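For the no-cache mode, a minimal self-attention sketch (assuming keras_nlp is installed on a Keras 3 backend; the layer sizes and variable names below are illustrative, not prescribed by the API):

```python
import numpy as np
import keras_nlp

# Illustrative sizes: B = 2, T = S = 4, feature dim = 64.
layer = keras_nlp.layers.CachedMultiHeadAttention(num_heads=8, key_dim=8)

x = np.random.uniform(size=(2, 4, 64)).astype("float32")

# No cache passed: this behaves like a regular multi-head self-attention
# call, and the attention output has shape (B, T, dim) = (2, 4, 64).
outputs = layer(query=x, value=x)
```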
Call arguments
- query: Query Tensor of shape (B, T, dim).
- value: Value Tensor of shape (B, S*, dim). If cache is None, S* must equal S and match the shape of attention_mask. If cache is not None, S* can be any length less than S, and the computed value will be spliced into cache at cache_update_index.
- key: Optional key Tensor of shape (B, S*, dim). If cache is None, S* must equal S and match the shape of attention_mask. If cache is not None, S* can be any length less than S, and the computed value will be spliced into cache at cache_update_index.
- attention_mask: a boolean mask of shape (B, T, S). attention_mask prevents attention to certain positions. The boolean mask specifies which query elements can attend to which key elements; 1 indicates attention and 0 indicates no attention. Broadcasting can happen for the missing batch dimensions and the head dimension.
- cache: a dense float Tensor. The key/value cache, of shape [B, 2, S, num_heads, key_dim], where S must agree with the attention_mask shape. This argument is intended for use during generation to avoid recomputing intermediate state.
- cache_update_index: an int or int Tensor, the index at which to update cache (usually the index of the current token being processed when running generation). If cache_update_index=None while cache is set, the cache will not be updated.
- training: a boolean indicating whether the layer should behave in training mode or in inference mode.
Returns
An (attention_output, cache) tuple. attention_output is the result of the computation, of shape (B, T, dim), where T is the target sequence length and dim is the query input's last dimension if output_shape is None. Otherwise, the multi-head outputs are projected to the shape specified by output_shape. cache is the updated cache.
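Putting the call arguments and the returned tuple together, a sketch of a cached decoding loop might look as follows (again assuming keras_nlp on a Keras 3 backend; the sizes, the zero-initialized cache, and the random hidden states stand in for a real model and are illustrative only):

```python
import numpy as np
import keras_nlp

# Illustrative sizes for this sketch.
batch, max_length, num_heads, key_dim, hidden_dim = 2, 8, 4, 16, 32

layer = keras_nlp.layers.CachedMultiHeadAttention(
    num_heads=num_heads, key_dim=key_dim
)

# Key/value cache of shape (B, 2, S, num_heads, key_dim); axis 1 holds the
# key cache and the value cache.
cache = np.zeros(
    (batch, 2, max_length, num_heads, key_dim), dtype="float32"
)

for index in range(max_length):
    # One new token per step, so T = 1.
    token_hidden = np.random.uniform(
        size=(batch, 1, hidden_dim)
    ).astype("float32")
    # Causal mask of shape (B, T, S): the new token may only attend to
    # positions already written to the cache (<= index).
    attention_mask = np.broadcast_to(
        np.arange(max_length) <= index, (batch, 1, max_length)
    )
    # The new key/value projections are spliced into the cache at `index`,
    # and the updated cache is returned alongside the attention output.
    attention_output, cache = layer(
        query=token_hidden,
        value=token_hidden,
        attention_mask=attention_mask,
        cache=cache,
        cache_update_index=index,
    )
    # attention_output has shape (B, 1, hidden_dim) here, since
    # output_shape is None and hidden_dim is the query's last dimension.
```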