MixTransformer backbones
- Original link: https://keras.io/api/keras_cv/models/backbones/mix_transformer/
- Last checked: 2024-11-25
MiTBackbone class
keras_cv.models.MiTBackbone(
include_rescaling,
depths,
input_shape=(224, 224, 3),
input_tensor=None,
embedding_dims=None,
**kwargs
)
Base class for Backbone models.
Backbones are reusable layers of models trained on a standard task, such as ImageNet classification, that can be reused in other tasks.
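The depths argument appears in the signature without a description. Assuming it holds the number of transformer blocks in each of MiT's four stages, with per-stage values taken from the SegFormer paper's MiT-B0 to MiT-B5 configurations (an assumption, not read from keras_cv), the per-stage counts should add up to the "transformer blocks" totals listed in the preset table. A minimal check:

```python
# Assumed per-stage block counts for each MiT variant, taken from the
# SegFormer paper configurations; NOT read from keras_cv itself.
ASSUMED_DEPTHS = {
    "mit_b0": [2, 2, 2, 2],
    "mit_b1": [2, 2, 2, 2],
    "mit_b2": [3, 4, 6, 3],
    "mit_b3": [3, 4, 18, 3],
    "mit_b4": [3, 8, 27, 3],
    "mit_b5": [3, 6, 40, 3],
}


def total_blocks(depths):
    """Total number of transformer blocks across all stages."""
    return sum(depths)


for name, depths in ASSUMED_DEPTHS.items():
    print(f"{name}: depths={depths} -> {total_blocks(depths)} blocks")
```

Only the totals are confirmed by the preset table; the split across stages is the assumed part.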
from_preset method
MiTBackbone.from_preset()
Instantiate MiTBackbone model from preset config and weights.
Arguments
- preset: string. Must be one of "mit_b0", "mit_b1", "mit_b2", "mit_b3", "mit_b4", "mit_b5", "mit_b0_imagenet". If you need a preset with pretrained weights, choose "mit_b0_imagenet".
- load_weights: whether to load pretrained weights into the model. Defaults to None, which follows whether the preset has pretrained weights available.
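The load_weights default can be read as a small decision rule. The sketch below uses a hypothetical helper, resolve_load_weights, and a hypothetical PRESETS_WITH_WEIGHTS set built from the preset list above (only "mit_b0_imagenet" ships pretrained weights). It is not keras_cv's implementation, and it deliberately does not model what the library does if load_weights=True is requested for a preset that has no weights.

```python
# Hypothetical registries built from the documented preset names.
KNOWN_PRESETS = {
    "mit_b0", "mit_b1", "mit_b2", "mit_b3", "mit_b4", "mit_b5",
    "mit_b0_imagenet",
}
PRESETS_WITH_WEIGHTS = {"mit_b0_imagenet"}


def resolve_load_weights(preset, load_weights=None):
    """Return whether weights would be loaded for this preset (sketch only)."""
    if preset not in KNOWN_PRESETS:
        raise ValueError(f"Unknown preset: {preset!r}")
    if load_weights is None:
        # Default behaviour: follow whether the preset has pretrained weights.
        return preset in PRESETS_WITH_WEIGHTS
    # An explicit True/False from the caller takes precedence.
    return bool(load_weights)


print(resolve_load_weights("mit_b0_imagenet"))
print(resolve_load_weights("mit_b0_imagenet", load_weights=False))
```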
Examples
import keras_cv

# Load architecture and weights from preset
model = keras_cv.models.MiTBackbone.from_preset(
    "mit_b0_imagenet",
)
# Load the preset architecture with randomly initialized weights
model = keras_cv.models.MiTBackbone.from_preset(
    "mit_b0_imagenet",
    load_weights=False,
)

| Preset name | Parameters | Description |
|---|---|---|
| mit_b0 | 3.32M | MiT (MixTransformer) model with 8 transformer blocks. |
| mit_b1 | 13.16M | MiT (MixTransformer) model with 8 transformer blocks. |
| mit_b2 | 24.20M | MiT (MixTransformer) model with 16 transformer blocks. |
| mit_b3 | 44.08M | MiT (MixTransformer) model with 28 transformer blocks. |
| mit_b4 | 60.85M | MiT (MixTransformer) model with 41 transformer blocks. |
| mit_b5 | 81.45M | MiT (MixTransformer) model with 52 transformer blocks. |
| mit_b0_imagenet | 3.32M | MiT (MixTransformer) model with 8 transformer blocks. Pre-trained on ImageNet-1K and scores 69% top-1 accuracy on the validation set. |
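When choosing among the presets, the parameter counts in the table are the main size signal. As an illustration, a hypothetical helper, largest_preset_within, picks the biggest preset that fits a parameter budget given in millions, using the counts copied from the table. It ignores accuracy and speed, which the table does not report for most presets.

```python
# Parameter counts (in millions) copied from the preset table.
PRESET_PARAMS_M = {
    "mit_b0": 3.32,
    "mit_b1": 13.16,
    "mit_b2": 24.20,
    "mit_b3": 44.08,
    "mit_b4": 60.85,
    "mit_b5": 81.45,
}


def largest_preset_within(budget_m):
    """Return the largest preset whose parameter count fits the budget, or None."""
    fitting = [(params, name) for name, params in PRESET_PARAMS_M.items()
               if params <= budget_m]
    if not fitting:
        return None
    # Tuples compare by parameter count first, so max() picks the biggest fit.
    return max(fitting)[1]


print(largest_preset_within(30))
```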
MiTB0Backbone class
keras_cv.models.MiTB0Backbone(
include_rescaling,
depths,
input_shape=(224, 224, 3),
input_tensor=None,
embedding_dims=None,
**kwargs
)
MiT model.
For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.
Arguments
- include_rescaling: bool, whether to rescale the inputs. If set to True, inputs will be passed through a Rescaling(scale=1 / 255) layer. Defaults to True.
- input_shape: optional shape tuple, defaults to (None, None, 3).
- input_tensor: optional Keras tensor (i.e., output of layers.Input()) to use as image input for the model.
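With include_rescaling=True, the Rescaling(scale=1 / 255) layer computes input * scale + offset (offset defaults to 0), so 8-bit pixel values in [0, 255] are mapped into [0, 1]. A plain-Python sketch of that arithmetic, using a hypothetical rescale helper rather than the Keras layer itself:

```python
SCALE = 1 / 255  # the scale used by the Rescaling layer described above


def rescale(pixels, scale=SCALE, offset=0.0):
    """Mirror Rescaling's output = input * scale + offset for a list of values."""
    return [p * scale + offset for p in pixels]


print(rescale([0, 128, 255]))
```

If your input pipeline already scales images to [0, 1], passing include_rescaling=False avoids dividing by 255 twice.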
Example
import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))
# Randomly initialized backbone
model = keras_cv.models.MiTB0Backbone()
output = model(input_data)

MiTB1Backbone class
keras_cv.models.MiTB1Backbone(
include_rescaling,
depths,
input_shape=(224, 224, 3),
input_tensor=None,
embedding_dims=None,
**kwargs
)
MiT model.
For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.
Arguments
- include_rescaling: bool, whether to rescale the inputs. If set to True, inputs will be passed through a Rescaling(scale=1 / 255) layer. Defaults to True.
- input_shape: optional shape tuple, defaults to (None, None, 3).
- input_tensor: optional Keras tensor (i.e., output of layers.Input()) to use as image input for the model.
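Because input_shape defaults to (None, None, 3), the backbone is not tied to a single resolution. Assuming MiT follows the SegFormer design, where its four stages downsample the input by 4, 8, 16 and 32, the spatial size of each stage's feature map can be estimated as below. The strides and the ceiling-division rounding are assumptions about the architecture, not values documented on this page.

```python
import math

# Assumed downsampling factor of each of the four MiT stages (SegFormer design).
ASSUMED_STAGE_STRIDES = (4, 8, 16, 32)


def stage_resolutions(height, width, strides=ASSUMED_STAGE_STRIDES):
    """Estimated (height, width) of each stage's feature map, rounding up."""
    return [(math.ceil(height / s), math.ceil(width / s)) for s in strides]


print(stage_resolutions(224, 224))  # [(56, 56), (28, 28), (14, 14), (7, 7)]
```

For the 224 x 224 inputs used in the examples, this gives feature maps of 56, 28, 14 and 7 pixels per side.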
Example
import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))
# Randomly initialized backbone
model = keras_cv.models.MiTB1Backbone()
output = model(input_data)

MiTB2Backbone class
keras_cv.models.MiTB2Backbone(
include_rescaling,
depths,
input_shape=(224, 224, 3),
input_tensor=None,
embedding_dims=None,
**kwargs
)
MiT model.
For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.
Arguments
- include_rescaling: bool, whether to rescale the inputs. If set to True, inputs will be passed through a Rescaling(scale=1 / 255) layer. Defaults to True.
- input_shape: optional shape tuple, defaults to (None, None, 3).
- input_tensor: optional Keras tensor (i.e., output of layers.Input()) to use as image input for the model.
Example
import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))
# Randomly initialized backbone
model = keras_cv.models.MiTB2Backbone()
output = model(input_data)

MiTB3Backbone class
keras_cv.models.MiTB3Backbone(
include_rescaling,
depths,
input_shape=(224, 224, 3),
input_tensor=None,
embedding_dims=None,
**kwargs
)
MiT model.
For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.
Arguments
- include_rescaling: bool, whether to rescale the inputs. If set to True, inputs will be passed through a Rescaling(scale=1 / 255) layer. Defaults to True.
- input_shape: optional shape tuple, defaults to (None, None, 3).
- input_tensor: optional Keras tensor (i.e., output of layers.Input()) to use as image input for the model.
Example
import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))
# Randomly initialized backbone
model = keras_cv.models.MiTB3Backbone()
output = model(input_data)

MiTB4Backbone class
keras_cv.models.MiTB4Backbone(
include_rescaling,
depths,
input_shape=(224, 224, 3),
input_tensor=None,
embedding_dims=None,
**kwargs
)
MiT model.
For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.
Arguments
- include_rescaling: bool, whether to rescale the inputs. If set to True, inputs will be passed through a Rescaling(scale=1 / 255) layer. Defaults to True.
- input_shape: optional shape tuple, defaults to (None, None, 3).
- input_tensor: optional Keras tensor (i.e., output of layers.Input()) to use as image input for the model.
Example
import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))
# Randomly initialized backbone
model = keras_cv.models.MiTB4Backbone()
output = model(input_data)

MiTB5Backbone class
keras_cv.models.MiTB5Backbone(
include_rescaling,
depths,
input_shape=(224, 224, 3),
input_tensor=None,
embedding_dims=None,
**kwargs
)
MiT model.
For transfer learning use cases, make sure to read the guide to transfer learning & fine-tuning.
Arguments
- include_rescaling: bool, whether to rescale the inputs. If set to True, inputs will be passed through a Rescaling(scale=1 / 255) layer. Defaults to True.
- input_shape: optional shape tuple, defaults to (None, None, 3).
- input_tensor: optional Keras tensor (i.e., output of layers.Input()) to use as image input for the model.
Example
import tensorflow as tf
import keras_cv

input_data = tf.ones(shape=(8, 224, 224, 3))
# Randomly initialized backbone
model = keras_cv.models.MiTB5Backbone()
output = model(input_data)