ViTDet model


ViTDetBackbone class

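keras_hub.models.ViTDetBackbone(
    hidden_size,
    num_layers,
    intermediate_dim,
    num_heads,
    global_attention_layer_indices,
    image_shape=(None, None, 3),
    patch_size=16,
    num_output_channels=256,
    use_bias=True,
    use_abs_pos=True,
    use_rel_pos=True,
    window_size=14,
    layer_norm_epsilon=1e-06,
    **kwargs
)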

An implementation of the ViT image encoder.

The ViTDetBackbone uses a windowed transformer encoder and relative positional encodings. The code has been adapted from the Segment Anything paper, the Segment Anything GitHub repository, and Detectron2.
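
Windowed attention restricts self-attention to non-overlapping windows of the patch grid. The NumPy sketch below is modeled on the window_partition / window_unpartition helpers from the Segment Anything code rather than on the keras_hub source, so treat the function names and the zero-padding choice as assumptions. It pads the grid up to a multiple of window_size, splits it into windows that can be attended over independently, and stitches them back together.

```python
import numpy as np


def window_partition(x, window_size):
    """Split a (B, H, W, C) patch grid into (B * num_windows, ws, ws, C).

    H and W are zero-padded up to a multiple of window_size first.
    Returns the windows and the padded (Hp, Wp).
    """
    b, h, w, c = x.shape
    pad_h = (window_size - h % window_size) % window_size
    pad_w = (window_size - w % window_size) % window_size
    x = np.pad(x, ((0, 0), (0, pad_h), (0, pad_w), (0, 0)))
    hp, wp = h + pad_h, w + pad_w
    x = x.reshape(b, hp // window_size, window_size, wp // window_size, window_size, c)
    # Bring the two window-index axes together, then fold them into the batch axis.
    windows = x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, c)
    return windows, (hp, wp)


def window_unpartition(windows, window_size, padded_hw, hw):
    """Inverse of window_partition: rebuild (B, H, W, C) and drop the padding."""
    hp, wp = padded_hw
    h, w = hw
    n_h, n_w = hp // window_size, wp // window_size
    b = windows.shape[0] // (n_h * n_w)
    x = windows.reshape(b, n_h, n_w, window_size, window_size, -1)
    x = x.transpose(0, 1, 3, 2, 4, 5).reshape(b, hp, wp, -1)
    return x[:, :h, :w, :]


# A 1024x1024 image cut into 16x16 patches gives a 64x64 grid; window_size=14
# pads it to 70x70, i.e. 5 * 5 = 25 windows per image.
grid = np.random.rand(2, 64, 64, 8).astype("float32")
windows, padded_hw = window_partition(grid, window_size=14)
print(windows.shape, padded_hw)  # (50, 14, 14, 8) (70, 70)
restored = window_unpartition(windows, 14, padded_hw, (64, 64))
print(np.allclose(restored, grid))  # True
```

Padding only the bottom and right edges keeps the original tokens at the same coordinates, so unpartitioning reduces to a reshape followed by a crop.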

Arguments

  • hidden_size (int): The dimensionality of the latent representation output by each stacked windowed transformer encoder block.
  • num_layers (int): The number of transformer encoder layers to stack in the Vision Transformer.
  • intermediate_dim (int): The dimensionality of the hidden Dense layer in the MLP block of each transformer encoder.
  • num_heads (int): The number of heads to use in the MultiHeadAttentionWithRelativePE layer of each transformer encoder.
  • global_attention_layer_indices (list): Indices of the transformer encoder blocks that use global attention; all other blocks use windowed attention.
  • image_shape (tuple[int], optional): The size of the input image in (H, W, C) format. Defaults to (None, None, 3).
  • patch_size (int, optional): The patch size supplied to the Patching layer, which turns input images into a flattened sequence of patches. Defaults to 16.
  • num_output_channels (int, optional): The number of channels (features) in the output (image encodings). Defaults to 256.
  • use_bias (bool, optional): Whether to use bias to project the keys, queries, and values in the attention layer. Defaults to True.
  • use_abs_pos (bool, optional): Whether to add absolute positional embeddings to the output patches. Defaults to True.
  • use_rel_pos (bool, optional): Whether to use relative positional encodings in the attention layer. Defaults to True.
  • window_size (int, optional): The size of the window for windowed attention in the transformer encoder blocks. Defaults to 14.
  • layer_norm_epsilon (float, optional): The epsilon to use in the layer normalization blocks of the transformer encoder. Defaults to 1e-6.
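
These arguments interact: patch_size fixes the token grid, global_attention_layer_indices picks which of the num_layers blocks attend over that whole grid, and the remaining blocks attend within window_size windows. The pure-Python sketch below uses a hypothetical helper, describe_vitdet_layers (not part of keras_hub), to tabulate this per layer along with the per-head width hidden_size // num_heads. The relative-position table length of 2 * span - 1 per axis is an assumption based on the decomposed relative positions used in Segment Anything. The ViT-B-like values in the call are chosen only for illustration.

```python
def describe_vitdet_layers(
    image_size,
    patch_size,
    num_layers,
    hidden_size,
    num_heads,
    global_attention_layer_indices,
    window_size,
):
    """Hypothetical helper: summarize the attention span of each encoder block."""
    if image_size % patch_size:
        raise ValueError("image_size should be divisible by patch_size")
    if hidden_size % num_heads:
        raise ValueError("hidden_size should be divisible by num_heads")
    grid = image_size // patch_size  # tokens per side after patching
    head_dim = hidden_size // num_heads
    # Indices outside range(num_layers) cannot refer to any block.
    unused = sorted(i for i in global_attention_layer_indices if not 0 <= i < num_layers)
    layers = []
    for i in range(num_layers):
        if i in global_attention_layer_indices:
            kind, span = "global", grid
        else:
            kind, span = "windowed", window_size
        # Assumed decomposed relative-position table: 2 * span - 1 entries per axis.
        layers.append((i, kind, span, 2 * span - 1))
    return {"grid": grid, "head_dim": head_dim, "unused_indices": unused, "layers": layers}


summary = describe_vitdet_layers(
    image_size=1024,
    patch_size=16,
    num_layers=12,
    hidden_size=768,
    num_heads=12,
    global_attention_layer_indices=[2, 5, 8, 11],
    window_size=14,
)
print(summary["grid"], summary["head_dim"])  # 64 64
for layer in summary["layers"][:3]:
    print(layer)
# (0, 'windowed', 14, 27)
# (1, 'windowed', 14, 27)
# (2, 'global', 64, 127)
```

Run on a small configuration such as num_layers=2 with global_attention_layer_indices=[2, 5, 8, 11], the helper reports all four indices as unused, since none of them names an existing block.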

Examples

import numpy as np
import keras_hub

input_data = np.ones((2, 16, 16, 3), dtype="float32")

# Pretrained ViTDet backbone.
model = keras_hub.models.ViTDetBackbone.from_preset("vit_det")

# Randomly initialized ViTDetBackbone with a custom config. The input batch
# matches image_shape, and the global attention index names one of the
# two layers that exist when num_layers=2.
model = keras_hub.models.ViTDetBackbone(
    image_shape=(16, 16, 3),
    patch_size=2,
    hidden_size=4,
    num_layers=2,
    global_attention_layer_indices=[1],
    intermediate_dim=4 * 4,
    num_heads=2,
    num_output_channels=2,
    window_size=2,
)
# Image encodings with num_output_channels channels.
output = model(input_data)


from_preset method

ViTDetBackbone.from_preset(preset, load_weights=True, **kwargs)

Instantiate a keras_hub.models.Backbone from a model preset.

A preset is a directory of configs, weights and other file assets used to save and load a pre-trained model. The preset can be passed as one of:

  1. a built-in preset identifier like 'bert_base_en'
  2. a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'
  3. a Hugging Face handle like 'hf://user/bert_base_en'
  4. a path to a local preset directory like './bert_base_en'
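
The four accepted forms can mostly be told apart by prefix. The sketch below is a hypothetical helper (classify_preset is not a keras_hub function), and the real loader's resolution rules may differ, for example in how it treats a built-in name that also exists as a local directory.

```python
import os


def classify_preset(preset):
    """Hypothetical helper: guess which of the four preset forms a string uses.

    Assumption: this is an illustration, not keras_hub's resolution logic.
    """
    if preset.startswith("kaggle://"):
        return "kaggle"
    if preset.startswith("hf://"):
        return "huggingface"
    # Anything that looks like a filesystem path is treated as a local directory.
    if os.path.isdir(preset) or preset.startswith((".", "/", "~")) or os.sep in preset:
        return "local"
    return "builtin"


for p in [
    "bert_base_en",
    "kaggle://user/bert/keras/bert_base_en",
    "hf://user/bert_base_en",
    "./bert_base_en",
]:
    print(p, "->", classify_preset(p))
```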

This constructor can be called in one of two ways. Either from the base class like keras_hub.models.Backbone.from_preset(), or from a model class like keras_hub.models.GemmaBackbone.from_preset(). If calling from the base class, the subclass of the returned object will be inferred from the config in the preset directory.

For any Backbone subclass, you can run cls.presets.keys() to list all built-in presets available on the class.

Arguments

  • preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory.
  • load_weights: bool. If True, the weights will be loaded into the model architecture. If False, the weights will be randomly initialized.
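
The base-class versus subclass behavior, the presets listing, and the load_weights flag can be illustrated with a toy registry. Everything below is hypothetical: the Sketch* classes, the preset names, and the class_name lookup are stand-ins chosen for illustration, assuming only that a preset stores a serialized config naming its class. This is not how keras_hub is implemented.

```python
# Hypothetical preset store: preset name -> serialized config.
PRESET_CONFIGS = {
    "bert_tiny_sketch": {"class_name": "SketchBertBackbone", "config": {"num_layers": 2}},
    "gemma_tiny_sketch": {"class_name": "SketchGemmaBackbone", "config": {"num_layers": 4}},
}


class SketchBackbone:
    """Toy stand-in for a backbone base class; not the keras_hub implementation."""

    registry = {}  # class_name -> subclass
    presets = dict(PRESET_CONFIGS)  # the base class sees every preset

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        SketchBackbone.registry[cls.__name__] = cls
        # Each subclass only lists the presets whose config names it.
        cls.presets = {
            name: cfg
            for name, cfg in PRESET_CONFIGS.items()
            if cfg["class_name"] == cls.__name__
        }

    def __init__(self, num_layers):
        self.num_layers = num_layers
        self.weights_loaded = False  # i.e. randomly initialized

    @classmethod
    def from_preset(cls, preset, load_weights=True):
        saved = PRESET_CONFIGS[preset]
        target = SketchBackbone.registry[saved["class_name"]]
        # On the base class, the returned subclass is inferred from the config;
        # on a subclass, the preset has to belong to that subclass.
        if not issubclass(target, cls):
            raise ValueError(f"{preset!r} does not hold a {cls.__name__}")
        model = target(**saved["config"])
        if load_weights:
            model.weights_loaded = True  # stand-in for restoring saved weights
        return model


class SketchBertBackbone(SketchBackbone):
    pass


class SketchGemmaBackbone(SketchBackbone):
    pass


gemma = SketchBackbone.from_preset("gemma_tiny_sketch")
bert = SketchBackbone.from_preset("bert_tiny_sketch", load_weights=False)
print(type(gemma).__name__, gemma.weights_loaded)  # SketchGemmaBackbone True
print(type(bert).__name__, bert.weights_loaded)  # SketchBertBackbone False
print(list(SketchBertBackbone.presets.keys()))  # ['bert_tiny_sketch']
```

In this sketch, calling SketchBertBackbone.from_preset("gemma_tiny_sketch") raises a ValueError, since a subclass only accepts presets saved for its own architecture.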

Examples

# Load a Gemma backbone with pre-trained weights.
model = keras_hub.models.Backbone.from_preset(
# Load a Bert backbone with a pre-trained config and random weights.
model = keras_hub.models.Backbone.from_preset(