MobileViT: A mobile-friendly Transformer-based model for image classification
- Original Link : https://keras.io/examples/vision/mobilevit/
- Last Checked at : 2024-11-19
Author: Sayak Paul
Date created: 2021/10/20
Last modified: 2024/02/11
Description: MobileViT for image classification with combined benefits of convolutions and Transformers.
Introduction
In this example, we implement the MobileViT architecture (Mehta et al.), which combines the benefits of Transformers (Vaswani et al.) and convolutions. With Transformers, we can capture long-range dependencies that result in global representations. With convolutions, we can capture spatial relationships that model locality.
Besides combining the properties of Transformers and convolutions, the authors introduce MobileViT as a general-purpose mobile-friendly backbone for different image recognition tasks. Their findings suggest that, performance-wise, MobileViT is better than other models with the same or higher complexity (MobileNetV3, for example), while being efficient on mobile devices.
Note: This example should be run with TensorFlow 2.13 or higher.
Imports
import os
import tensorflow as tf
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
from keras import layers
from keras import backend
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
Hyperparameters
# Values are from table 4.
patch_size = 4 # 2x2, for the Transformer blocks.
image_size = 256
expansion_factor = 2 # expansion factor for the MobileNetV2 blocks.
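As a rough check on these values, the sketch below computes how many patches each MobileViT block would see. It assumes, as the model summary later in this example shows, that the three MobileViT blocks receive 32x32, 16x16 and 8x8 feature maps. The helper name `num_patches_for` is hypothetical and is not part of the original example.

```python
# Values copied from the hyperparameters above.
patch_size = 4  # area of one 2x2 patch
image_size = 256


def num_patches_for(height, width, patch_area=patch_size):
    """Hypothetical helper: number of non-overlapping patches in a feature map."""
    if (height * width) % patch_area != 0:
        raise ValueError("feature map area must be divisible by the patch area")
    return (height * width) // patch_area


# Assumed downsampling factors (8, 16, 32) for the three MobileViT stages.
stage_sizes = [image_size // factor for factor in (8, 16, 32)]
patch_counts = [num_patches_for(size, size) for size in stage_sizes]

print(stage_sizes)   # [32, 16, 8]
print(patch_counts)  # [256, 64, 16]
```

These counts correspond to the second dimension of the reshaped tensors in the model summary, for example (None, 4, 256, 64) for the first MobileViT block.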
MobileViT utilities
The MobileViT architecture is comprised of the following blocks:
- Strided 3x3 convolutions that process the input image.
- MobileNetV2-style inverted residual blocks for downsampling the resolution of the intermediate feature maps.
- MobileViT blocks that combine the benefits of Transformers and convolutions. The MobileViT block is presented in the figure below (taken from the original paper):
def conv_block(x, filters=16, kernel_size=3, strides=2):
    conv_layer = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        activation=keras.activations.swish,
        padding="same",
    )
    return conv_layer(x)
# Reference: https://github.com/keras-team/keras/blob/e3858739d178fe16a0c77ce7fab88b0be6dbbdc7/keras/applications/imagenet_utils.py#L413C17-L435
def correct_pad(inputs, kernel_size):
    img_dim = 2 if backend.image_data_format() == "channels_first" else 1
    input_size = inputs.shape[img_dim : (img_dim + 2)]
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if input_size[0] is None:
        adjust = (1, 1)
    else:
        adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)
    correct = (kernel_size[0] // 2, kernel_size[1] // 2)
    return (
        (correct[0] - adjust[0], correct[0]),
        (correct[1] - adjust[1], correct[1]),
    )
# Reference: https://git.io/JKgtC
def inverted_residual_block(x, expanded_channels, output_channels, strides=1):
    m = layers.Conv2D(expanded_channels, 1, padding="same", use_bias=False)(x)
    m = layers.BatchNormalization()(m)
    m = keras.activations.swish(m)
    if strides == 2:
        m = layers.ZeroPadding2D(padding=correct_pad(m, 3))(m)
    m = layers.DepthwiseConv2D(
        3, strides=strides, padding="same" if strides == 1 else "valid", use_bias=False
    )(m)
    m = layers.BatchNormalization()(m)
    m = keras.activations.swish(m)
    m = layers.Conv2D(output_channels, 1, padding="same", use_bias=False)(m)
    m = layers.BatchNormalization()(m)
    if keras.ops.equal(x.shape[-1], output_channels) and strides == 1:
        return layers.Add()([m, x])
    return m
# Reference:
# https://keras.io/examples/vision/image_classification_with_vision_transformer/
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=keras.activations.swish)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, x])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(
            x3,
            hidden_units=[x.shape[-1] * 2, x.shape[-1]],
            dropout_rate=0.1,
        )
        # Skip connection 2.
        x = layers.Add()([x3, x2])
    return x
def mobilevit_block(x, num_blocks, projection_dim, strides=1):
    # Local projection with convolutions.
    local_features = conv_block(x, filters=projection_dim, strides=strides)
    local_features = conv_block(
        local_features, filters=projection_dim, kernel_size=1, strides=strides
    )
    # Unfold into patches and then pass through Transformers.
    num_patches = int((local_features.shape[1] * local_features.shape[2]) / patch_size)
    non_overlapping_patches = layers.Reshape((patch_size, num_patches, projection_dim))(
        local_features
    )
    global_features = transformer_block(
        non_overlapping_patches, num_blocks, projection_dim
    )
    # Fold into conv-like feature-maps.
    folded_feature_map = layers.Reshape((*local_features.shape[1:-1], projection_dim))(
        global_features
    )
    # Apply point-wise conv -> concatenate with the input features.
    folded_feature_map = conv_block(
        folded_feature_map, filters=x.shape[-1], kernel_size=1, strides=strides
    )
    local_global_features = layers.Concatenate(axis=-1)([x, folded_feature_map])
    # Fuse the local and global features using a convolution layer.
    local_global_features = conv_block(
        local_global_features, filters=projection_dim, strides=strides
    )
    return local_global_features
More on the MobileViT block:
- First, the feature representations (A) go through convolution blocks that capture local relationships. The expected shape of a single entry here would be (h, w, num_channels).
- Then they get unfolded into another vector with shape (p, n, num_channels), where p is the area of a small patch, and n is (h * w) / p. So, we end up with n non-overlapping patches.
- This unfolded vector is then passed through a Transformer block that captures global relationships between the patches.
- The output vector (B) is again folded into a vector of shape (h, w, num_channels) resembling a feature map coming out of convolutions.
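The unfold and fold steps above can be sketched with plain NumPy reshapes, mirroring the two `layers.Reshape` calls in `mobilevit_block`. This is a minimal sketch, assuming a single 32x32 feature map with 64 channels. The array names are illustrative, and the reshape simply regroups values in row-major order.

```python
import numpy as np

patch_size = 4  # p: area of one patch
h, w, channels = 32, 32, 64

# A stand-in for one feature map produced by the local convolutions.
local_features = np.arange(h * w * channels, dtype=np.float32).reshape(h, w, channels)

# Unfold: (h, w, c) -> (p, n, c) with n = (h * w) / p.
num_patches = (h * w) // patch_size
unfolded = local_features.reshape(patch_size, num_patches, channels)

# Fold: (p, n, c) -> (h, w, c), restoring the conv-like layout.
folded = unfolded.reshape(h, w, channels)

print(unfolded.shape)  # (4, 256, 64)
print(folded.shape)    # (32, 32, 64)
print(np.array_equal(folded, local_features))  # True
```

Because the fold is the exact inverse of the unfold, any change between the two tensors in the real model comes from the Transformer block applied in between.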
Vectors A and B are then passed through two more convolutional layers to fuse the local and global representations. Notice how the spatial resolution of the final vector remains unchanged at this point. The authors also present an explanation of how the MobileViT block resembles a convolution block of a CNN. For more details, please refer to the original paper.
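Resolution is reduced instead by the stride-2 inverted residual blocks defined earlier. The following sketch reproduces the per-dimension padding arithmetic of `correct_pad` and then applies the standard output-size formula for a "valid" convolution. The function names `pad_for_stride2` and `valid_conv_output` are hypothetical helpers written for this illustration, and square inputs are assumed.

```python
def pad_for_stride2(input_size, kernel_size=3):
    """Hypothetical mirror of the per-dimension logic in `correct_pad`."""
    adjust = 1 - input_size % 2
    correct = kernel_size // 2
    return (correct - adjust, correct)


def valid_conv_output(input_size, kernel_size=3, stride=2):
    """Output length of a 'valid' convolution along one dimension."""
    return (input_size - kernel_size) // stride + 1


for size in (128, 64, 32, 16):
    before, after = pad_for_stride2(size)
    padded = size + before + after
    print(size, "->", padded, "->", valid_conv_output(padded))
# 128 -> 129 -> 64
# 64 -> 65 -> 32
# 32 -> 33 -> 16
# 16 -> 17 -> 8
```

The intermediate sizes 129, 65, 33 and 17 match the `ZeroPadding2D` output shapes reported in the model summary.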
Next, we combine these blocks together and implement the MobileViT architecture (XXS variant). The following figure (taken from the original paper) presents a schematic representation of the architecture:
def create_mobilevit(num_classes=5):
    inputs = keras.Input((image_size, image_size, 3))
    x = layers.Rescaling(scale=1.0 / 255)(inputs)
    # Initial conv-stem -> MV2 block.
    x = conv_block(x, filters=16)
    x = inverted_residual_block(
        x, expanded_channels=16 * expansion_factor, output_channels=16
    )
    # Downsampling with MV2 block.
    x = inverted_residual_block(
        x, expanded_channels=16 * expansion_factor, output_channels=24, strides=2
    )
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=24
    )
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=24
    )
    # First MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=24 * expansion_factor, output_channels=48, strides=2
    )
    x = mobilevit_block(x, num_blocks=2, projection_dim=64)
    # Second MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=64 * expansion_factor, output_channels=64, strides=2
    )
    x = mobilevit_block(x, num_blocks=4, projection_dim=80)
    # Third MV2 -> MobileViT block.
    x = inverted_residual_block(
        x, expanded_channels=80 * expansion_factor, output_channels=80, strides=2
    )
    x = mobilevit_block(x, num_blocks=3, projection_dim=96)
    x = conv_block(x, filters=320, kernel_size=1, strides=1)
    # Classification head.
    x = layers.GlobalAvgPool2D()(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)
    return keras.Model(inputs, outputs)
mobilevit_xxs = create_mobilevit()
mobilevit_xxs.summary()
Result
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 256, 256, 3) 0
__________________________________________________________________________________________________
rescaling (Rescaling) (None, 256, 256, 3) 0 input_1[0][0]
__________________________________________________________________________________________________
conv2d (Conv2D) (None, 128, 128, 16) 448 rescaling[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 128, 128, 32) 512 conv2d[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 128, 128, 32) 128 conv2d_1[0][0]
__________________________________________________________________________________________________
tf.nn.silu (TFOpLambda) (None, 128, 128, 32) 0 batch_normalization[0][0]
__________________________________________________________________________________________________
depthwise_conv2d (DepthwiseConv (None, 128, 128, 32) 288 tf.nn.silu[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 128, 128, 32) 128 depthwise_conv2d[0][0]
__________________________________________________________________________________________________
tf.nn.silu_1 (TFOpLambda) (None, 128, 128, 32) 0 batch_normalization_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 128, 128, 16) 512 tf.nn.silu_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 128, 128, 16) 64 conv2d_2[0][0]
__________________________________________________________________________________________________
add (Add) (None, 128, 128, 16) 0 batch_normalization_2[0][0]
conv2d[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 128, 128, 32) 512 add[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 128, 128, 32) 128 conv2d_3[0][0]
__________________________________________________________________________________________________
tf.nn.silu_2 (TFOpLambda) (None, 128, 128, 32) 0 batch_normalization_3[0][0]
__________________________________________________________________________________________________
zero_padding2d (ZeroPadding2D) (None, 129, 129, 32) 0 tf.nn.silu_2[0][0]
__________________________________________________________________________________________________
depthwise_conv2d_1 (DepthwiseCo (None, 64, 64, 32) 288 zero_padding2d[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 64, 64, 32) 128 depthwise_conv2d_1[0][0]
__________________________________________________________________________________________________
tf.nn.silu_3 (TFOpLambda) (None, 64, 64, 32) 0 batch_normalization_4[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 64, 64, 24) 768 tf.nn.silu_3[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 64, 64, 24) 96 conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 64, 64, 48) 1152 batch_normalization_5[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 64, 64, 48) 192 conv2d_5[0][0]
__________________________________________________________________________________________________
tf.nn.silu_4 (TFOpLambda) (None, 64, 64, 48) 0 batch_normalization_6[0][0]
__________________________________________________________________________________________________
depthwise_conv2d_2 (DepthwiseCo (None, 64, 64, 48) 432 tf.nn.silu_4[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 64, 64, 48) 192 depthwise_conv2d_2[0][0]
__________________________________________________________________________________________________
tf.nn.silu_5 (TFOpLambda) (None, 64, 64, 48) 0 batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 64, 64, 24) 1152 tf.nn.silu_5[0][0]
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 64, 64, 24) 96 conv2d_6[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 64, 64, 24) 0 batch_normalization_8[0][0]
batch_normalization_5[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 64, 64, 48) 1152 add_1[0][0]
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 64, 64, 48) 192 conv2d_7[0][0]
__________________________________________________________________________________________________
tf.nn.silu_6 (TFOpLambda) (None, 64, 64, 48) 0 batch_normalization_9[0][0]
__________________________________________________________________________________________________
depthwise_conv2d_3 (DepthwiseCo (None, 64, 64, 48) 432 tf.nn.silu_6[0][0]
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 64, 64, 48) 192 depthwise_conv2d_3[0][0]
__________________________________________________________________________________________________
tf.nn.silu_7 (TFOpLambda) (None, 64, 64, 48) 0 batch_normalization_10[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 64, 64, 24) 1152 tf.nn.silu_7[0][0]
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 64, 64, 24) 96 conv2d_8[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 64, 64, 24) 0 batch_normalization_11[0][0]
add_1[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 64, 64, 48) 1152 add_2[0][0]
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 64, 64, 48) 192 conv2d_9[0][0]
__________________________________________________________________________________________________
tf.nn.silu_8 (TFOpLambda) (None, 64, 64, 48) 0 batch_normalization_12[0][0]
__________________________________________________________________________________________________
zero_padding2d_1 (ZeroPadding2D (None, 65, 65, 48) 0 tf.nn.silu_8[0][0]
__________________________________________________________________________________________________
depthwise_conv2d_4 (DepthwiseCo (None, 32, 32, 48) 432 zero_padding2d_1[0][0]
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 32, 32, 48) 192 depthwise_conv2d_4[0][0]
__________________________________________________________________________________________________
tf.nn.silu_9 (TFOpLambda) (None, 32, 32, 48) 0 batch_normalization_13[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 32, 32, 48) 2304 tf.nn.silu_9[0][0]
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 32, 32, 48) 192 conv2d_10[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 32, 32, 64) 27712 batch_normalization_14[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 32, 32, 64) 4160 conv2d_11[0][0]
__________________________________________________________________________________________________
reshape (Reshape) (None, 4, 256, 64) 0 conv2d_12[0][0]
__________________________________________________________________________________________________
layer_normalization (LayerNorma (None, 4, 256, 64) 128 reshape[0][0]
__________________________________________________________________________________________________
multi_head_attention (MultiHead (None, 4, 256, 64) 33216 layer_normalization[0][0]
layer_normalization[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 4, 256, 64) 0 multi_head_attention[0][0]
reshape[0][0]
__________________________________________________________________________________________________
layer_normalization_1 (LayerNor (None, 4, 256, 64) 128 add_3[0][0]
__________________________________________________________________________________________________
dense (Dense) (None, 4, 256, 128) 8320 layer_normalization_1[0][0]
__________________________________________________________________________________________________
dropout (Dropout) (None, 4, 256, 128) 0 dense[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 4, 256, 64) 8256 dropout[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout) (None, 4, 256, 64) 0 dense_1[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 4, 256, 64) 0 dropout_1[0][0]
add_3[0][0]
__________________________________________________________________________________________________
layer_normalization_2 (LayerNor (None, 4, 256, 64) 128 add_4[0][0]
__________________________________________________________________________________________________
multi_head_attention_1 (MultiHe (None, 4, 256, 64) 33216 layer_normalization_2[0][0]
layer_normalization_2[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 4, 256, 64) 0 multi_head_attention_1[0][0]
add_4[0][0]
__________________________________________________________________________________________________
layer_normalization_3 (LayerNor (None, 4, 256, 64) 128 add_5[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 4, 256, 128) 8320 layer_normalization_3[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout) (None, 4, 256, 128) 0 dense_2[0][0]
__________________________________________________________________________________________________
dense_3 (Dense) (None, 4, 256, 64) 8256 dropout_2[0][0]
__________________________________________________________________________________________________
dropout_3 (Dropout) (None, 4, 256, 64) 0 dense_3[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 4, 256, 64) 0 dropout_3[0][0]
add_5[0][0]
__________________________________________________________________________________________________
reshape_1 (Reshape) (None, 32, 32, 64) 0 add_6[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 32, 32, 48) 3120 reshape_1[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate) (None, 32, 32, 96) 0 batch_normalization_14[0][0]
conv2d_13[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D) (None, 32, 32, 64) 55360 concatenate[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D) (None, 32, 32, 128) 8192 conv2d_14[0][0]
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 32, 32, 128) 512 conv2d_15[0][0]
__________________________________________________________________________________________________
tf.nn.silu_10 (TFOpLambda) (None, 32, 32, 128) 0 batch_normalization_15[0][0]
__________________________________________________________________________________________________
zero_padding2d_2 (ZeroPadding2D (None, 33, 33, 128) 0 tf.nn.silu_10[0][0]
__________________________________________________________________________________________________
depthwise_conv2d_5 (DepthwiseCo (None, 16, 16, 128) 1152 zero_padding2d_2[0][0]
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 16, 16, 128) 512 depthwise_conv2d_5[0][0]
__________________________________________________________________________________________________
tf.nn.silu_11 (TFOpLambda) (None, 16, 16, 128) 0 batch_normalization_16[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D) (None, 16, 16, 64) 8192 tf.nn.silu_11[0][0]
__________________________________________________________________________________________________
batch_normalization_17 (BatchNo (None, 16, 16, 64) 256 conv2d_16[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D) (None, 16, 16, 80) 46160 batch_normalization_17[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D) (None, 16, 16, 80) 6480 conv2d_17[0][0]
__________________________________________________________________________________________________
reshape_2 (Reshape) (None, 4, 64, 80) 0 conv2d_18[0][0]
__________________________________________________________________________________________________
layer_normalization_4 (LayerNor (None, 4, 64, 80) 160 reshape_2[0][0]
__________________________________________________________________________________________________
multi_head_attention_2 (MultiHe (None, 4, 64, 80) 51760 layer_normalization_4[0][0]
layer_normalization_4[0][0]
__________________________________________________________________________________________________
add_7 (Add) (None, 4, 64, 80) 0 multi_head_attention_2[0][0]
reshape_2[0][0]
__________________________________________________________________________________________________
layer_normalization_5 (LayerNor (None, 4, 64, 80) 160 add_7[0][0]
__________________________________________________________________________________________________
dense_4 (Dense) (None, 4, 64, 160) 12960 layer_normalization_5[0][0]
__________________________________________________________________________________________________
dropout_4 (Dropout) (None, 4, 64, 160) 0 dense_4[0][0]
__________________________________________________________________________________________________
dense_5 (Dense) (None, 4, 64, 80) 12880 dropout_4[0][0]
__________________________________________________________________________________________________
dropout_5 (Dropout) (None, 4, 64, 80) 0 dense_5[0][0]
__________________________________________________________________________________________________
add_8 (Add) (None, 4, 64, 80) 0 dropout_5[0][0]
add_7[0][0]
__________________________________________________________________________________________________
layer_normalization_6 (LayerNor (None, 4, 64, 80) 160 add_8[0][0]
__________________________________________________________________________________________________
multi_head_attention_3 (MultiHe (None, 4, 64, 80) 51760 layer_normalization_6[0][0]
layer_normalization_6[0][0]
__________________________________________________________________________________________________
add_9 (Add) (None, 4, 64, 80) 0 multi_head_attention_3[0][0]
add_8[0][0]
__________________________________________________________________________________________________
layer_normalization_7 (LayerNor (None, 4, 64, 80) 160 add_9[0][0]
__________________________________________________________________________________________________
dense_6 (Dense) (None, 4, 64, 160) 12960 layer_normalization_7[0][0]
__________________________________________________________________________________________________
dropout_6 (Dropout) (None, 4, 64, 160) 0 dense_6[0][0]
__________________________________________________________________________________________________
dense_7 (Dense) (None, 4, 64, 80) 12880 dropout_6[0][0]
__________________________________________________________________________________________________
dropout_7 (Dropout) (None, 4, 64, 80) 0 dense_7[0][0]
__________________________________________________________________________________________________
add_10 (Add) (None, 4, 64, 80) 0 dropout_7[0][0]
add_9[0][0]
__________________________________________________________________________________________________
layer_normalization_8 (LayerNor (None, 4, 64, 80) 160 add_10[0][0]
__________________________________________________________________________________________________
multi_head_attention_4 (MultiHe (None, 4, 64, 80) 51760 layer_normalization_8[0][0]
layer_normalization_8[0][0]
__________________________________________________________________________________________________
add_11 (Add) (None, 4, 64, 80) 0 multi_head_attention_4[0][0]
add_10[0][0]
__________________________________________________________________________________________________
layer_normalization_9 (LayerNor (None, 4, 64, 80) 160 add_11[0][0]
__________________________________________________________________________________________________
dense_8 (Dense) (None, 4, 64, 160) 12960 layer_normalization_9[0][0]
__________________________________________________________________________________________________
dropout_8 (Dropout) (None, 4, 64, 160) 0 dense_8[0][0]
__________________________________________________________________________________________________
dense_9 (Dense) (None, 4, 64, 80) 12880 dropout_8[0][0]
__________________________________________________________________________________________________
dropout_9 (Dropout) (None, 4, 64, 80) 0 dense_9[0][0]
__________________________________________________________________________________________________
add_12 (Add) (None, 4, 64, 80) 0 dropout_9[0][0]
add_11[0][0]
__________________________________________________________________________________________________
layer_normalization_10 (LayerNo (None, 4, 64, 80) 160 add_12[0][0]
__________________________________________________________________________________________________
multi_head_attention_5 (MultiHe (None, 4, 64, 80) 51760 layer_normalization_10[0][0]
layer_normalization_10[0][0]
__________________________________________________________________________________________________
add_13 (Add) (None, 4, 64, 80) 0 multi_head_attention_5[0][0]
add_12[0][0]
__________________________________________________________________________________________________
layer_normalization_11 (LayerNo (None, 4, 64, 80) 160 add_13[0][0]
__________________________________________________________________________________________________
dense_10 (Dense) (None, 4, 64, 160) 12960 layer_normalization_11[0][0]
__________________________________________________________________________________________________
dropout_10 (Dropout) (None, 4, 64, 160) 0 dense_10[0][0]
__________________________________________________________________________________________________
dense_11 (Dense) (None, 4, 64, 80) 12880 dropout_10[0][0]
__________________________________________________________________________________________________
dropout_11 (Dropout) (None, 4, 64, 80) 0 dense_11[0][0]
__________________________________________________________________________________________________
add_14 (Add) (None, 4, 64, 80) 0 dropout_11[0][0]
add_13[0][0]
__________________________________________________________________________________________________
reshape_3 (Reshape) (None, 16, 16, 80) 0 add_14[0][0]
__________________________________________________________________________________________________
conv2d_19 (Conv2D) (None, 16, 16, 64) 5184 reshape_3[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 16, 16, 128) 0 batch_normalization_17[0][0]
conv2d_19[0][0]
__________________________________________________________________________________________________
conv2d_20 (Conv2D) (None, 16, 16, 80) 92240 concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_21 (Conv2D) (None, 16, 16, 160) 12800 conv2d_20[0][0]
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, 16, 16, 160) 640 conv2d_21[0][0]
__________________________________________________________________________________________________
tf.nn.silu_12 (TFOpLambda) (None, 16, 16, 160) 0 batch_normalization_18[0][0]
__________________________________________________________________________________________________
zero_padding2d_3 (ZeroPadding2D (None, 17, 17, 160) 0 tf.nn.silu_12[0][0]
__________________________________________________________________________________________________
depthwise_conv2d_6 (DepthwiseCo (None, 8, 8, 160) 1440 zero_padding2d_3[0][0]
__________________________________________________________________________________________________
batch_normalization_19 (BatchNo (None, 8, 8, 160) 640 depthwise_conv2d_6[0][0]
__________________________________________________________________________________________________
tf.nn.silu_13 (TFOpLambda) (None, 8, 8, 160) 0 batch_normalization_19[0][0]
__________________________________________________________________________________________________
conv2d_22 (Conv2D) (None, 8, 8, 80) 12800 tf.nn.silu_13[0][0]
__________________________________________________________________________________________________
batch_normalization_20 (BatchNo (None, 8, 8, 80) 320 conv2d_22[0][0]
__________________________________________________________________________________________________
conv2d_23 (Conv2D) (None, 8, 8, 96) 69216 batch_normalization_20[0][0]
__________________________________________________________________________________________________
conv2d_24 (Conv2D) (None, 8, 8, 96) 9312 conv2d_23[0][0]
__________________________________________________________________________________________________
reshape_4 (Reshape) (None, 4, 16, 96) 0 conv2d_24[0][0]
__________________________________________________________________________________________________
layer_normalization_12 (LayerNo (None, 4, 16, 96) 192 reshape_4[0][0]
__________________________________________________________________________________________________
multi_head_attention_6 (MultiHe (None, 4, 16, 96) 74400 layer_normalization_12[0][0]
layer_normalization_12[0][0]
__________________________________________________________________________________________________
add_15 (Add) (None, 4, 16, 96) 0 multi_head_attention_6[0][0]
reshape_4[0][0]
__________________________________________________________________________________________________
layer_normalization_13 (LayerNo (None, 4, 16, 96) 192 add_15[0][0]
__________________________________________________________________________________________________
dense_12 (Dense) (None, 4, 16, 192) 18624 layer_normalization_13[0][0]
__________________________________________________________________________________________________
dropout_12 (Dropout) (None, 4, 16, 192) 0 dense_12[0][0]
__________________________________________________________________________________________________
dense_13 (Dense) (None, 4, 16, 96) 18528 dropout_12[0][0]
__________________________________________________________________________________________________
dropout_13 (Dropout) (None, 4, 16, 96) 0 dense_13[0][0]
__________________________________________________________________________________________________
add_16 (Add) (None, 4, 16, 96) 0 dropout_13[0][0]
add_15[0][0]
__________________________________________________________________________________________________
layer_normalization_14 (LayerNo (None, 4, 16, 96) 192 add_16[0][0]
__________________________________________________________________________________________________
multi_head_attention_7 (MultiHe (None, 4, 16, 96) 74400 layer_normalization_14[0][0]
layer_normalization_14[0][0]
__________________________________________________________________________________________________
add_17 (Add) (None, 4, 16, 96) 0 multi_head_attention_7[0][0]
add_16[0][0]
__________________________________________________________________________________________________
layer_normalization_15 (LayerNo (None, 4, 16, 96) 192 add_17[0][0]
__________________________________________________________________________________________________
dense_14 (Dense) (None, 4, 16, 192) 18624 layer_normalization_15[0][0]
__________________________________________________________________________________________________
dropout_14 (Dropout) (None, 4, 16, 192) 0 dense_14[0][0]
__________________________________________________________________________________________________
dense_15 (Dense) (None, 4, 16, 96) 18528 dropout_14[0][0]
__________________________________________________________________________________________________
dropout_15 (Dropout) (None, 4, 16, 96) 0 dense_15[0][0]
__________________________________________________________________________________________________
add_18 (Add) (None, 4, 16, 96) 0 dropout_15[0][0]
add_17[0][0]
__________________________________________________________________________________________________
layer_normalization_16 (LayerNo (None, 4, 16, 96) 192 add_18[0][0]
__________________________________________________________________________________________________
multi_head_attention_8 (MultiHe (None, 4, 16, 96) 74400 layer_normalization_16[0][0]
layer_normalization_16[0][0]
__________________________________________________________________________________________________
add_19 (Add) (None, 4, 16, 96) 0 multi_head_attention_8[0][0]
add_18[0][0]
__________________________________________________________________________________________________
layer_normalization_17 (LayerNo (None, 4, 16, 96) 192 add_19[0][0]
__________________________________________________________________________________________________
dense_16 (Dense) (None, 4, 16, 192) 18624 layer_normalization_17[0][0]
__________________________________________________________________________________________________
dropout_16 (Dropout) (None, 4, 16, 192) 0 dense_16[0][0]
__________________________________________________________________________________________________
dense_17 (Dense) (None, 4, 16, 96) 18528 dropout_16[0][0]
__________________________________________________________________________________________________
dropout_17 (Dropout) (None, 4, 16, 96) 0 dense_17[0][0]
__________________________________________________________________________________________________
add_20 (Add) (None, 4, 16, 96) 0 dropout_17[0][0]
add_19[0][0]
__________________________________________________________________________________________________
reshape_5 (Reshape) (None, 8, 8, 96) 0 add_20[0][0]
__________________________________________________________________________________________________
conv2d_25 (Conv2D) (None, 8, 8, 80) 7760 reshape_5[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 8, 8, 160) 0 batch_normalization_20[0][0]
conv2d_25[0][0]
__________________________________________________________________________________________________
conv2d_26 (Conv2D) (None, 8, 8, 96) 138336 concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_27 (Conv2D) (None, 8, 8, 320) 31040 conv2d_26[0][0]
__________________________________________________________________________________________________
global_average_pooling2d (Globa (None, 320) 0 conv2d_27[0][0]
__________________________________________________________________________________________________
dense_18 (Dense) (None, 5) 1605 global_average_pooling2d[0][0]
==================================================================================================
Total params: 1,307,621
Trainable params: 1,305,077
Non-trainable params: 2,544
__________________________________________________________________________________________________
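As a sanity check, several of the parameter counts in the summary above can be reproduced by hand. The sketch below is plain Python arithmetic and is not part of the original example. The helper names are hypothetical, and the kernel sizes, bias settings, and attention configuration (2 heads with a key dimension of 96) are assumptions inferred from the reported counts rather than read from the model code.

```python
def conv2d_params(kernel, in_ch, out_ch, use_bias=True):
    # Weights of a kernel x kernel convolution, plus one bias per output channel.
    return kernel * kernel * in_ch * out_ch + (out_ch if use_bias else 0)


def dense_params(in_dim, out_dim):
    # Weight matrix plus bias vector.
    return in_dim * out_dim + out_dim


def mha_params(embed_dim, num_heads, key_dim):
    # Query, key and value projections (each with bias), then the output projection.
    qkv = 3 * (embed_dim * num_heads * key_dim + num_heads * key_dim)
    out = num_heads * key_dim * embed_dim + embed_dim
    return qkv + out


# (computed value, value reported in the summary)
checks = {
    "depthwise_conv2d_6": (3 * 3 * 160, 1440),  # one 3x3 filter per channel, no bias
    "batch_normalization_19": (4 * 160, 640),  # gamma, beta, moving mean, moving variance
    "conv2d_22": (conv2d_params(1, 160, 80, use_bias=False), 12800),
    "conv2d_23": (conv2d_params(3, 80, 96), 69216),
    "conv2d_24": (conv2d_params(1, 96, 96), 9312),
    "layer_normalization_12": (2 * 96, 192),  # gamma and beta
    "multi_head_attention_6": (mha_params(96, 2, 96), 74400),
    "dense_12": (dense_params(96, 192), 18624),
    "dense_13": (dense_params(192, 96), 18528),
    "conv2d_25": (conv2d_params(1, 96, 80), 7760),
    "conv2d_26": (conv2d_params(3, 160, 96), 138336),
    "conv2d_27": (conv2d_params(1, 96, 320), 31040),
    "dense_18": (dense_params(320, 5), 1605),
}

for name, (computed, reported) in checks.items():
    status = "ok" if computed == reported else "MISMATCH"
    print(f"{name:<24} computed={computed:>7} reported={reported:>7} {status}")

# Trainable and non-trainable parameters should add up to the total.
total_params = 1_305_077 + 2_544
print(f"trainable + non-trainable = {total_params:,}")
```

The reshape to `(None, 4, 16, 96)` fits the same picture: an 8x8 feature map holds 64 positions, which are unfolded into 4 positions per patch across 16 patches.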
Dataset preparation
We will be using the tf_flowers dataset to demonstrate the model. Unlike other Transformer-based architectures, MobileViT uses a simple augmentation pipeline, primarily because it has the properties of a CNN.
batch_size = 64
auto = tf.data.AUTOTUNE
resize_bigger = 280
num_classes = 5
def preprocess_dataset(is_training=True):
    def _pp(image, label):
        if is_training:
            # Resize to a bigger spatial resolution and take the random
            # crops.
            image = tf.image.resize(image, (resize_bigger, resize_bigger))
            image = tf.image.random_crop(image, (image_size, image_size, 3))
            image = tf.image.random_flip_left_right(image)
        else:
            image = tf.image.resize(image, (image_size, image_size))
        label = tf.one_hot(label, depth=num_classes)
        return image, label

    return _pp


def prepare_dataset(dataset, is_training=True):
    if is_training:
        dataset = dataset.shuffle(batch_size * 10)
    dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=auto)
    return dataset.batch(batch_size).prefetch(auto)
The authors use a multi-scale data sampler to help the model learn representations of varied scales. In this example, we omit this part and train at a single resolution.
Load and prepare the dataset
train_dataset, val_dataset = tfds.load(
    "tf_flowers", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f"Number of training examples: {num_train}")
print(f"Number of validation examples: {num_val}")
train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)
Result
Number of training examples: 3303
Number of validation examples: 367
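These counts also determine how many batches make up one epoch. The short calculation below is an illustrative sketch (the variable names `train_steps`, `val_steps`, and `last_train_batch` are not from the original example). It relies on the fact that `dataset.batch(batch_size)` keeps the final partial batch by default.

```python
import math

batch_size = 64
num_train, num_val = 3303, 367

# batch() without drop_remainder keeps the last, smaller batch,
# so the number of steps is the ceiling of examples / batch_size.
train_steps = math.ceil(num_train / batch_size)
val_steps = math.ceil(num_val / batch_size)

# Size of the final (partial) training batch.
last_train_batch = num_train - (train_steps - 1) * batch_size

print(f"Training steps per epoch: {train_steps}")
print(f"Validation steps: {val_steps}")
print(f"Examples in the last training batch: {last_train_batch}")
```

The resulting step counts should match the `52/52` and `6/6` progress indicators in the training log further down.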
Train a MobileViT (XXS) model
learning_rate = 0.002
label_smoothing_factor = 0.1
epochs = 30
optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
loss_fn = keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing_factor)
def run_experiment(epochs=epochs):
    mobilevit_xxs = create_mobilevit(num_classes=num_classes)
    mobilevit_xxs.compile(optimizer=optimizer, loss=loss_fn, metrics=["accuracy"])

    # When using `save_weights_only=True` in `ModelCheckpoint`, the filepath provided must end in `.weights.h5`
    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    mobilevit_xxs.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[checkpoint_callback],
    )
    # Restore the weights from the epoch with the best validation accuracy.
    mobilevit_xxs.load_weights(checkpoint_filepath)
    _, accuracy = mobilevit_xxs.evaluate(val_dataset)
    print(f"Validation accuracy: {round(accuracy * 100, 2)}%")
    return mobilevit_xxs


mobilevit_xxs = run_experiment()
Result
Epoch 1/30 52/52 [==============================] - 47s 459ms/step - loss: 1.3397 - accuracy: 0.4832 - val_loss: 1.7250 - val_accuracy: 0.1662
Epoch 2/30 52/52 [==============================] - 21s 404ms/step - loss: 1.1167 - accuracy: 0.6210 - val_loss: 1.9844 - val_accuracy: 0.1907
Epoch 3/30 52/52 [==============================] - 21s 403ms/step - loss: 1.0217 - accuracy: 0.6709 - val_loss: 1.8187 - val_accuracy: 0.1907
Epoch 4/30 52/52 [==============================] - 21s 409ms/step - loss: 0.9682 - accuracy: 0.7048 - val_loss: 2.0329 - val_accuracy: 0.1907
Epoch 5/30 52/52 [==============================] - 21s 408ms/step - loss: 0.9552 - accuracy: 0.7196 - val_loss: 2.1150 - val_accuracy: 0.1907
Epoch 6/30 52/52 [==============================] - 21s 407ms/step - loss: 0.9186 - accuracy: 0.7318 - val_loss: 2.9713 - val_accuracy: 0.1907
Epoch 7/30 52/52 [==============================] - 21s 407ms/step - loss: 0.8986 - accuracy: 0.7457 - val_loss: 3.2062 - val_accuracy: 0.1907
Epoch 8/30 52/52 [==============================] - 21s 408ms/step - loss: 0.8831 - accuracy: 0.7542 - val_loss: 3.8631 - val_accuracy: 0.1907
Epoch 9/30 52/52 [==============================] - 21s 408ms/step - loss: 0.8433 - accuracy: 0.7714 - val_loss: 1.8029 - val_accuracy: 0.3542
Epoch 10/30 52/52 [==============================] - 21s 408ms/step - loss: 0.8489 - accuracy: 0.7763 - val_loss: 1.7920 - val_accuracy: 0.4796
Epoch 11/30 52/52 [==============================] - 21s 409ms/step - loss: 0.8256 - accuracy: 0.7884 - val_loss: 1.4992 - val_accuracy: 0.5477
Epoch 12/30 52/52 [==============================] - 21s 407ms/step - loss: 0.7859 - accuracy: 0.8123 - val_loss: 0.9236 - val_accuracy: 0.7330
Epoch 13/30 52/52 [==============================] - 21s 409ms/step - loss: 0.7702 - accuracy: 0.8159 - val_loss: 0.8059 - val_accuracy: 0.8011
Epoch 14/30 52/52 [==============================] - 21s 403ms/step - loss: 0.7670 - accuracy: 0.8153 - val_loss: 1.1535 - val_accuracy: 0.7084
Epoch 15/30 52/52 [==============================] - 21s 408ms/step - loss: 0.7332 - accuracy: 0.8344 - val_loss: 0.7746 - val_accuracy: 0.8147
Epoch 16/30 52/52 [==============================] - 21s 404ms/step - loss: 0.7284 - accuracy: 0.8335 - val_loss: 1.0342 - val_accuracy: 0.7330
Epoch 17/30 52/52 [==============================] - 21s 409ms/step - loss: 0.7484 - accuracy: 0.8262 - val_loss: 1.0523 - val_accuracy: 0.7112
Epoch 18/30 52/52 [==============================] - 21s 408ms/step - loss: 0.7209 - accuracy: 0.8450 - val_loss: 0.8146 - val_accuracy: 0.8174
Epoch 19/30 52/52 [==============================] - 21s 409ms/step - loss: 0.7141 - accuracy: 0.8435 - val_loss: 0.8016 - val_accuracy: 0.7875
Epoch 20/30 52/52 [==============================] - 21s 410ms/step - loss: 0.7075 - accuracy: 0.8435 - val_loss: 0.9352 - val_accuracy: 0.7439
Epoch 21/30 52/52 [==============================] - 21s 406ms/step - loss: 0.7066 - accuracy: 0.8504 - val_loss: 1.0171 - val_accuracy: 0.7139
Epoch 22/30 52/52 [==============================] - 21s 405ms/step - loss: 0.6913 - accuracy: 0.8532 - val_loss: 0.7059 - val_accuracy: 0.8610
Epoch 23/30 52/52 [==============================] - 21s 408ms/step - loss: 0.6681 - accuracy: 0.8671 - val_loss: 0.8007 - val_accuracy: 0.8147
Epoch 24/30 52/52 [==============================] - 21s 409ms/step - loss: 0.6636 - accuracy: 0.8747 - val_loss: 0.9490 - val_accuracy: 0.7302
Epoch 25/30 52/52 [==============================] - 21s 408ms/step - loss: 0.6637 - accuracy: 0.8722 - val_loss: 0.6913 - val_accuracy: 0.8556
Epoch 26/30 52/52 [==============================] - 21s 406ms/step - loss: 0.6443 - accuracy: 0.8837 - val_loss: 1.0483 - val_accuracy: 0.7139
Epoch 27/30 52/52 [==============================] - 21s 407ms/step - loss: 0.6555 - accuracy: 0.8695 - val_loss: 0.9448 - val_accuracy: 0.7602
Epoch 28/30 52/52 [==============================] - 21s 409ms/step - loss: 0.6409 - accuracy: 0.8807 - val_loss: 0.9337 - val_accuracy: 0.7302
Epoch 29/30 52/52 [==============================] - 21s 408ms/step - loss: 0.6300 - accuracy: 0.8910 - val_loss: 0.7461 - val_accuracy: 0.8256
Epoch 30/30 52/52 [==============================] - 21s 408ms/step - loss: 0.6093 - accuracy: 0.8968 - val_loss: 0.8651 - val_accuracy: 0.7766
6/6 [==============================] - 0s 65ms/step - loss: 0.7059 - accuracy: 0.8610
Validation accuracy: 86.1%
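The training loss in the log levels off around 0.6 rather than approaching zero. Part of the reason is label smoothing: with `label_smoothing=0.1`, Keras replaces each one-hot target `y` with `y * (1 - 0.1) + 0.1 / num_classes`, and the cross-entropy against such a target can never fall below the entropy of the target itself. The sketch below works through that arithmetic in plain Python; the `smooth` helper and the `loss_floor` name are hypothetical and only illustrate the formula.

```python
import math

num_classes = 5
label_smoothing = 0.1


def smooth(one_hot, eps):
    # Same rule Keras applies: shrink the true class and spread eps evenly.
    k = len(one_hot)
    return [y * (1.0 - eps) + eps / k for y in one_hot]


target = smooth([0, 0, 1, 0, 0], label_smoothing)
print("Smoothed target:", [round(p, 2) for p in target])

# Cross-entropy is minimized when the prediction equals the target,
# in which case it equals the entropy of the target distribution.
loss_floor = -sum(p * math.log(p) for p in target)
print(f"Lowest achievable loss: {loss_floor:.4f}")
```

Under these settings the true class receives a target of roughly 0.92 and every other class roughly 0.02, so the loss floor sits near 0.39 even for a perfect classifier.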
Results and TFLite conversion
With about one million parameters, getting to ~86% top-1 validation accuracy on tf_flowers at 256x256 resolution is a strong result. This MobileViT model is fully compatible with TensorFlow Lite (TFLite) and can be converted with the following code:
# Serialize the model as a SavedModel.
tf.saved_model.save(mobilevit_xxs, "mobilevit_xxs")

# Convert to TFLite. This form of quantization is called
# post-training dynamic-range quantization in TFLite.
converter = tf.lite.TFLiteConverter.from_saved_model("mobilevit_xxs")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # Enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # Enable TensorFlow ops.
]
tflite_model = converter.convert()

# Write the serialized model to disk and close the file afterwards.
with open("mobilevit_xxs.tflite", "wb") as f:
    f.write(tflite_model)
To learn more about different quantization recipes available in TFLite and running inference with TFLite models, check out this official resource.
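For a quick local check, the converted file can be loaded with `tf.lite.Interpreter`. The helper below is a minimal sketch and not part of the original example: the function name `run_tflite_inference` is hypothetical, and it assumes the model takes a single float32 image batch of shape `(1, 256, 256, 3)`, resized the same way as the validation pipeline above. Because the conversion enables `SELECT_TF_OPS`, it also assumes the full `tensorflow` package is installed rather than a slim TFLite runtime.

```python
def run_tflite_inference(model_path, image_batch):
    """Run one forward pass of a .tflite model and return its raw output."""
    # Imported inside the function so that defining the helper stays cheap.
    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    # This sketch assumes a single input tensor and a single output tensor.
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]

    # Cast to the dtype the model expects (float32 for dynamic-range quantization).
    batch = np.asarray(image_batch, dtype=input_details["dtype"])
    interpreter.set_tensor(input_details["index"], batch)
    interpreter.invoke()
    return interpreter.get_tensor(output_details["index"])


# Example usage, assuming "mobilevit_xxs.tflite" was written by the code above:
#
#   import numpy as np
#   image = np.random.uniform(0, 255, size=(1, 256, 256, 3)).astype("float32")
#   scores = run_tflite_inference("mobilevit_xxs.tflite", image)
#   print(scores.shape, scores.argmax(axis=-1))
```

Comparing these scores with `mobilevit_xxs.predict` on the same image is a simple way to confirm that quantization has not changed the predicted class.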
You can use the trained model hosted on Hugging Face Hub and try the demo on Hugging Face Spaces.