Mixed precision policy API
- Original Link : https://keras.io/api/mixed_precision/policy/
- Last Checked at : 2024-11-25
DTypePolicy
class
keras.dtype_policies.DTypePolicy(name=None)
A dtype policy for a Keras layer.
A dtype policy determines a layer’s computation and variable dtypes. Each
layer has a policy. Policies can be passed to the dtype argument of layer
constructors, or a global policy can be set with keras.config.set_dtype_policy.
Arguments
- name: The policy name, which determines the compute and variable dtypes.
  Can be any dtype name, such as "float32" or "float64", which causes both the
  compute and variable dtypes to be that dtype. Can also be the string
  "mixed_float16" or "mixed_bfloat16", which causes the compute dtype to be
  float16 or bfloat16, respectively, and the variable dtype to be float32.
Typically you only need to interact with dtype policies when using mixed
precision, which is the use of float16 or bfloat16 for computations and
float32 for variables. This is why the term mixed_precision appears in the
API name. Mixed precision can be enabled by passing "mixed_float16" or
"mixed_bfloat16" to keras.mixed_precision.set_dtype_policy(), or to
keras.config.set_dtype_policy() as in the example below.
>>> keras.config.set_dtype_policy("mixed_float16")
>>> layer1 = keras.layers.Dense(10)
>>> layer1.dtype_policy # layer1 will automatically use mixed precision
<DTypePolicy "mixed_float16">
>>> # Can optionally override layer to use float32
>>> # instead of mixed precision.
>>> layer2 = keras.layers.Dense(10, dtype="float32")
>>> layer2.dtype_policy
<DTypePolicy "float32">
>>> # Set policy back to initial float32.
>>> keras.config.set_dtype_policy('float32')
In the example above, passing dtype="float32" to the layer is equivalent to
passing dtype=keras.dtype_policies.DTypePolicy("float32"). In general,
passing a dtype policy name to a layer is equivalent to passing the
corresponding policy, so it is never necessary to explicitly construct a
DTypePolicy object.
DTypePolicyMap
class
keras.dtype_policies.DTypePolicyMap(default_policy=None, policy_map=None)
Dict-like object mapping layer paths to DTypePolicy instances.
DTypePolicyMap can be used in get_config in layers and subclasses to
support complex configurations of dtype policies.
For example, we can modify get_config in layers.MultiHeadAttention as
follows to support mixing dtype policies, such as quantization.
import keras
from keras import dtype_policies

@keras.saving.register_keras_serializable("MyPackage")
class MyMultiHeadAttention(keras.layers.MultiHeadAttention):
    def get_config(self):
        config = super().get_config()
        # Record the policy of every quantized sublayer, keyed by its path.
        dtype_policy_map = dtype_policies.DTypePolicyMap()
        for layer in self._flatten_layers():
            if layer.dtype_policy.quantization_mode is not None:
                dtype_policy_map[layer.path] = layer.dtype_policy
        # Only override "dtype" when at least one sublayer is quantized.
        if len(dtype_policy_map) > 0:
            config.update({"dtype": dtype_policy_map})
        return config
Internally, DTypePolicyMap uses a string as the key and a DTypePolicy as the
value. Typically, the key used for querying is the Layer.path. However, it is
also possible to set a regex as the key. See the docstring of get for more
details.
See below for a usage example. You can define a naming scheme for the keys,
map each key to a DTypePolicy, and then retrieve the corresponding
DTypePolicy instance.
import keras
from keras.dtype_policies import DTypePolicy, DTypePolicyMap, QuantizedDTypePolicy

dtype_policy_map = DTypePolicyMap()
dtype_policy_map["layer/dense_0"] = DTypePolicy("bfloat16")
dtype_policy_map["layer/dense_1"] = QuantizedDTypePolicy("int8", "bfloat16")

policy_0 = dtype_policy_map["layer/dense_0"]
policy_1 = dtype_policy_map["layer/dense_1"]
policy_2 = dtype_policy_map["layer/dense_2"]  # No hit: falls back to the default

assert policy_0 == DTypePolicy("bfloat16")
assert policy_1 == QuantizedDTypePolicy("int8", "bfloat16")
assert policy_2 == keras.config.dtype_policy()
Arguments
- default_policy: An optional DTypePolicy instance specifying the default
  dtype policy. If not specified, it defaults to keras.config.dtype_policy().
- policy_map: An optional dict that maps strings to DTypePolicy instances.
  Defaults to None.
FloatDTypePolicy
class
keras.dtype_policies.FloatDTypePolicy(name=None)
A dtype policy for a Keras layer.
QuantizedDTypePolicy
class
keras.dtype_policies.QuantizedDTypePolicy(mode, source_name=None)
A dtype policy for a Keras layer.
QuantizedFloat8DTypePolicy
class
keras.dtype_policies.QuantizedFloat8DTypePolicy(
mode, source_name=None, amax_history_length=1024
)
A dtype policy for a Keras layer.
dtype_policy
function
keras.config.dtype_policy()
Returns the current default dtype policy object.
set_dtype_policy
function
keras.config.set_dtype_policy(policy)
Sets the default dtype policy globally.
Example
>>> keras.config.set_dtype_policy("mixed_float16")