MiTBackbone model

MiTBackbone class

keras_hub.models.MiTBackbone(
    depths,
    num_layers,
    blockwise_num_heads,
    blockwise_sr_ratios,
    max_drop_path_rate,
    patch_sizes,
    strides,
    image_shape=(None, None, 3),
    hidden_dims=None,
    **kwargs
)

A backbone with feature pyramid outputs.

MiTBackbone is a FeaturePyramidBackbone, which extends Backbone with a single pyramid_outputs property for accessing the model's feature pyramid outputs. Subclasses should set the pyramid_outputs property in the model constructor.
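To make the idea of pyramid levels concrete, the following minimal sketch estimates the spatial size of each stage's output from the strides argument. It assumes each stage divides the current height and width by its stride, rounding down; the helper pyramid_sizes and the stride values are hypothetical and are not read from any preset.

```python
def pyramid_sizes(image_size, strides):
    """Return the spatial size after each stage.

    Assumption: every stage divides the current size by its stride
    (integer division). This is a sketch, not the library's code.
    """
    sizes = []
    size = image_size
    for stride in strides:
        size = size // stride
        sizes.append(size)
    return sizes


# Hypothetical MiT-B0-style strides, chosen for illustration only.
strides = [4, 2, 2, 2]
sizes = pyramid_sizes(224, strides)
for level, size in enumerate(sizes, start=1):
    print(f"stage {level}: {size}x{size}")
```

Under this assumption, deeper stages produce coarser feature maps, which is what makes the outputs usable as a feature pyramid.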

Example

import keras
import numpy as np
from keras_hub.models import ResNetBackbone

input_data = np.random.uniform(0, 256, size=(2, 224, 224, 3))
# Convert to feature pyramid output format using ResNet.
backbone = ResNetBackbone.from_preset("resnet_50_imagenet")
model = keras.Model(
    inputs=backbone.inputs, outputs=backbone.pyramid_outputs
)
model(input_data)  # A dict containing the keys ["P2", "P3", "P4", "P5"]

from_preset method

MiTBackbone.from_preset(preset, load_weights=True, **kwargs)

Instantiate a keras_hub.models.Backbone from a model preset.

A preset is a directory of configs, weights and other file assets used to save and load a pre-trained model. The preset can be passed as one of:

  1. a built-in preset identifier like 'bert_base_en'
  2. a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'
  3. a Hugging Face handle like 'hf://user/bert_base_en'
  4. a path to a local preset directory like './bert_base_en'
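The four forms above can be told apart by their prefix. The helper below is a hypothetical illustration of that distinction; preset_kind is not part of keras_hub and does not reflect how the library actually resolves presets. In particular, the local-path check is a simple prefix heuristic.

```python
def preset_kind(preset):
    """Classify a preset string by its form (illustrative heuristic only)."""
    if preset.startswith("kaggle://"):
        return "kaggle"
    if preset.startswith("hf://"):
        return "hugging_face"
    # Heuristic: treat explicit relative or absolute paths as local.
    if preset.startswith(("./", "../", "/")):
        return "local"
    return "built_in"


examples = [
    "bert_base_en",
    "kaggle://user/bert/keras/bert_base_en",
    "hf://user/bert_base_en",
    "./bert_base_en",
]
kinds = {preset: preset_kind(preset) for preset in examples}
for preset, kind in kinds.items():
    print(f"{preset} -> {kind}")
```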

This constructor can be called in one of two ways. Either from the base class like keras_hub.models.Backbone.from_preset(), or from a model class like keras_hub.models.GemmaBackbone.from_preset(). If calling from the base class, the subclass of the returning object will be inferred from the config in the preset directory.

For any Backbone subclass, you can run cls.presets.keys() to list all built-in presets available on the class.

Arguments

  • preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory.
  • load_weights: bool. If True, the weights will be loaded into the model architecture. If False, the weights will be randomly initialized.

Examples

import keras_hub

# Load a Gemma backbone with pre-trained weights.
model = keras_hub.models.Backbone.from_preset(
    "gemma_2b_en",
)
# Load a Bert backbone with a pre-trained config and random weights.
model = keras_hub.models.Backbone.from_preset(
    "bert_base_en",
    load_weights=False,
)

Preset name              Parameters  Description
mit_b0_ade20k_512        3.32M       MiT (MixTransformer) model with 8 transformer blocks.
mit_b1_ade20k_512        13.16M      MiT (MixTransformer) model with 8 transformer blocks.
mit_b2_ade20k_512        24.20M      MiT (MixTransformer) model with 16 transformer blocks.
mit_b3_ade20k_512        44.08M      MiT (MixTransformer) model with 28 transformer blocks.
mit_b4_ade20k_512        60.85M      MiT (MixTransformer) model with 41 transformer blocks.
mit_b5_ade20k_640        81.45M      MiT (MixTransformer) model with 52 transformer blocks.
mit_b0_cityscapes_1024   3.32M       MiT (MixTransformer) model with 8 transformer blocks.
mit_b1_cityscapes_1024   13.16M      MiT (MixTransformer) model with 8 transformer blocks.
mit_b2_cityscapes_1024   24.20M      MiT (MixTransformer) model with 16 transformer blocks.
mit_b3_cityscapes_1024   44.08M      MiT (MixTransformer) model with 28 transformer blocks.
mit_b4_cityscapes_1024   60.85M      MiT (MixTransformer) model with 41 transformer blocks.
mit_b5_cityscapes_1024   81.45M      MiT (MixTransformer) model with 52 transformer blocks.