SAMImageSegmenter model
- Original link: https://keras.io/api/keras_hub/models/sam/sam_image_segmenter/
- Last checked: 2024-11-26
SAMImageSegmenter class
keras_hub.models.SAMImageSegmenter(backbone, preprocessor=None, **kwargs)
The Segment Anything (SAM) image segmenter model.
SAM works by prompting the input images. There are three ways to prompt:
(1) Labelled points: Foreground points (points with label 1) are encoded such that the output masks generated by the mask decoder contain them, and background points (points with label 0) are encoded such that the generated masks don't contain them.
(2) Box: A box tells the model which part/crop of the image to segment.
(3) Mask: An input mask can be used to refine the output of the mask decoder.
These prompts can be mixed and matched, but at least one of them must be present. To turn off a particular prompt, simply exclude it from the inputs to the model.
(1) For point prompts, the expected shape is (batch, num_points, 2). The labels must have a corresponding shape of (batch, num_points).
(2) For box prompts, the expected shape is (batch, 1, 2, 2).
(3) Similarly, mask prompts have shape (batch, 1, H, W, 1).
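These shape rules can be checked before calling the model. The following is a minimal NumPy sketch; `check_sam_inputs` is a hypothetical helper written for illustration, not part of keras_hub. It only checks the shapes listed in this section, and it accepts masks with any H and W and any count along axis 1.

```python
import numpy as np

# Keys that turn a prompt on; "labels" always travels with "points".
PROMPT_KEYS = ("points", "boxes", "masks")


def check_sam_inputs(inputs):
    """Hypothetical helper: check the prompt shapes described above.

    Returns the prompt types present in `inputs`, or raises ValueError.
    """
    present = [key for key in PROMPT_KEYS if key in inputs]
    if not present:
        raise ValueError("At least one of points, boxes or masks is required.")
    batch = inputs["images"].shape[0]

    if "points" in inputs:
        points = inputs["points"]
        labels = inputs.get("labels")
        if points.ndim != 3 or points.shape[0] != batch or points.shape[2] != 2:
            raise ValueError(f"points must be (batch, num_points, 2), got {points.shape}")
        if labels is None or labels.shape != points.shape[:2]:
            raise ValueError("labels must have shape (batch, num_points)")

    if "boxes" in inputs and inputs["boxes"].shape != (batch, 1, 2, 2):
        raise ValueError(f"boxes must be (batch, 1, 2, 2), got {inputs['boxes'].shape}")

    if "masks" in inputs:
        masks = inputs["masks"]
        # Only rank, batch size and the trailing channel are checked; H and W vary.
        if masks.ndim != 5 or masks.shape[0] != batch or masks.shape[-1] != 1:
            raise ValueError(f"masks must be (batch, 1, H, W, 1), got {masks.shape}")

    return present


inputs = {
    "images": np.ones((1, 1024, 1024, 3), dtype="float32"),
    "points": np.array([[[512.0, 512.0]]], dtype="float32"),
    "labels": np.array([[1.0]], dtype="float32"),
    "boxes": np.array([[[[384.0, 384.0], [640.0, 640.0]]]], dtype="float32"),
}
print(check_sam_inputs(inputs))  # ['points', 'boxes']
```

Leaving a key out of the dictionary is how a prompt is turned off, so the helper reports which prompts are active rather than requiring all of them.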
Arguments
- backbone: A keras_hub.models.SAMBackbone instance.
Example
Load a pretrained model using from_preset.
```python
import numpy as np
import keras_hub

# The SAM presets expect 1024x1024 RGB images.
image_size = 1024
batch_size = 2
input_data = {
    "images": np.ones(
        (batch_size, image_size, image_size, 3),
        dtype="float32",
    ),
    "points": np.ones((batch_size, 1, 2), dtype="float32"),
    "labels": np.ones((batch_size, 1), dtype="float32"),
    "boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
    # Zero masks along axis 1, i.e. no mask prompt is supplied.
    "masks": np.zeros(
        (batch_size, 0, image_size, image_size, 1)
    ),
}
sam = keras_hub.models.SAMImageSegmenter.from_preset("sam_base_sa1b")
sam(input_data)
```
Load a Segment Anything image segmenter with a custom backbone
```python
import numpy as np
from keras_hub.layers import SAMMaskDecoder, SAMPromptEncoder
from keras_hub.models import SAMBackbone, SAMImageSegmenter, ViTDetBackbone

image_size = 128
batch_size = 2
# Dummy input images matching the backbone's image_shape.
images = np.ones(
    (batch_size, image_size, image_size, 3),
    dtype="float32",
)
# A deliberately tiny image encoder: 128 / 16 = 8, so it yields an 8x8 embedding.
image_encoder = ViTDetBackbone(
    hidden_size=16,
    num_layers=16,
    intermediate_dim=16 * 4,
    num_heads=16,
    global_attention_layer_indices=[2, 5, 8, 11],
    patch_size=16,
    num_output_channels=8,
    window_size=2,
    image_shape=(image_size, image_size, 3),
)
prompt_encoder = SAMPromptEncoder(
    hidden_size=8,
    image_embedding_size=(8, 8),
    input_image_size=(
        image_size,
        image_size,
    ),
    mask_in_channels=16,
)
mask_decoder = SAMMaskDecoder(
    num_layers=2,
    hidden_size=8,
    intermediate_dim=32,
    num_heads=8,
    embedding_dim=8,
    num_multimask_outputs=3,
    iou_head_depth=3,
    iou_head_hidden_dim=8,
)
backbone = SAMBackbone(
    image_encoder=image_encoder,
    prompt_encoder=prompt_encoder,
    mask_decoder=mask_decoder,
    image_shape=(image_size, image_size, 3),
)
sam = SAMImageSegmenter(
    backbone=backbone
)
```
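The numbers in this configuration are not independent. The plain-Python sketch below shows the arithmetic under stated assumptions: `sam_geometry` and `channels_agree` are hypothetical helpers, the 4x mask-prompt upscaling factor mirrors the original SAM design, and the equal-width requirement is an assumption drawn from this example (8 channels throughout) rather than a documented keras_hub rule.

```python
def sam_geometry(image_size, patch_size, mask_scale=4):
    """Hypothetical helper: derive the sizes SAM's components must agree on.

    Assumes a square image, an image encoder that downsamples by `patch_size`,
    and a mask prompt `mask_scale` times larger than the image embedding.
    """
    if image_size % patch_size:
        raise ValueError("image_size must be divisible by patch_size")
    side = image_size // patch_size
    return {
        "image_embedding_size": (side, side),
        "mask_input_size": (side * mask_scale, side * mask_scale),
    }


def channels_agree(encoder_channels, prompt_hidden_size, decoder_hidden_size):
    # Assumption: the three embedding widths must be equal so that image and
    # prompt embeddings can be combined in the mask decoder.
    return encoder_channels == prompt_hidden_size == decoder_hidden_size


# Values from the custom-backbone example above.
print(sam_geometry(128, 16))
# {'image_embedding_size': (8, 8), 'mask_input_size': (32, 32)}
print(channels_agree(8, 8, 8))  # True
# A 1024-pixel image with 16-pixel patches gives a 256x256 mask prompt.
print(sam_geometry(1024, 16)["mask_input_size"])  # (256, 256)
```

The 128 / 16 = 8 result is why the prompt encoder above is given image_embedding_size=(8, 8).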
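For example, to pass all the prompts to a pretrained model, do:
```python
import numpy as np

# Placeholder input: a batch of one 1024x1024 RGB image. The point, box and
# mask coordinates below assume this image size.
image = np.ones((1, 1024, 1024, 3), dtype="float32")
points = np.array([[[512., 512.], [100., 100.]]])
# For labels: 1 means foreground point, 0 means background
labels = np.array([[1., 0.]])
box = np.array([[[[384., 384.], [640., 640.]]]])
input_mask = np.ones((1, 1, 256, 256, 1))
```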
Prepare an input dictionary:
```python
inputs = {
    "images": image,
    "points": points,
    "labels": labels,
    "boxes": box,
    "masks": input_mask,
}
outputs = sam.predict(inputs)
masks, iou_pred = outputs["masks"], outputs["iou_pred"]
```
The first mask in the output masks (i.e. masks[:, 0, ...]) is the best mask predicted by the model based on the prompts. The other masks (i.e. masks[:, 1:, ...]) are alternate predictions that can be used if they are preferred over the first one.
If only point and box prompts are present, simply exclude the masks:
```python
inputs = {
    "images": image,
    "points": points,
    "labels": labels,
    "boxes": box,
}
outputs = sam.predict(inputs)
masks, iou_pred = outputs["masks"], outputs["iou_pred"]
```
In another example, only point prompts are present. Note that if point prompts are present but no box prompt is, the points must be padded with a zero point and a -1 label:
```python
padded_points = np.concatenate(
    [points, np.zeros((1, 1, 2))], axis=1
)
padded_labels = np.concatenate(
    [labels, -np.ones((1, 1))], axis=1
)
inputs = {
    "images": image,
    "points": padded_points,
    "labels": padded_labels,
}
outputs = sam.predict(inputs)
masks, iou_pred = outputs["masks"], outputs["iou_pred"]
```
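The padding step for point-only prompts hard-codes a batch size of 1. A small generalization, sketched in NumPy under the same rule (one extra point at (0, 0) labelled -1 per batch item), might look like this; `pad_points_without_box` is a hypothetical helper, not a keras_hub function.

```python
import numpy as np


def pad_points_without_box(points, labels):
    """Hypothetical helper: append one padding point per batch item.

    Each batch item gets an extra point at (0, 0) with label -1.
    points: (batch, num_points, 2); labels: (batch, num_points).
    """
    batch = points.shape[0]
    pad_point = np.zeros((batch, 1, 2), dtype=points.dtype)
    pad_label = -np.ones((batch, 1), dtype=labels.dtype)
    return (
        np.concatenate([points, pad_point], axis=1),
        np.concatenate([labels, pad_label], axis=1),
    )


points = np.array([[[512.0, 512.0], [100.0, 100.0]]])
labels = np.array([[1.0, 0.0]])
padded_points, padded_labels = pad_points_without_box(points, labels)
print(padded_points.shape)     # (1, 3, 2)
print(padded_labels.tolist())  # [[1.0, 0.0, -1.0]]
```

Taking the batch size from the input keeps the point and label arrays aligned for any batch, which is what the (batch, num_points) shape rule requires.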
from_preset method
SAMImageSegmenter.from_preset(preset, load_weights=True, **kwargs)
Instantiate a keras_hub.models.Task from a model preset.
A preset is a directory of configs, weights and other file assets used to save and load a pre-trained model. The preset can be passed as one of:
- a built-in preset identifier like 'bert_base_en'
- a Kaggle Models handle like 'kaggle://user/bert/keras/bert_base_en'
- a Hugging Face handle like 'hf://user/bert_base_en'
- a path to a local preset directory like './bert_base_en'
For any Task subclass, you can run cls.presets.keys() to list all built-in presets available on the class.
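To make the four accepted forms concrete, here is a plain-Python sketch that labels a preset string by form. `classify_preset` is hypothetical and uses naive string checks; keras_hub performs its own preset resolution, which this sketch does not reproduce.

```python
def classify_preset(preset):
    """Hypothetical helper: name which documented form a preset string takes.

    Simple string checks for illustration only; this is not how keras_hub
    resolves presets internally.
    """
    if preset.startswith("kaggle://"):
        return "kaggle"
    if preset.startswith("hf://"):
        return "hugging_face"
    # Assumption: built-in identifiers never contain path separators.
    if preset.startswith((".", "/", "~")) or "/" in preset or "\\" in preset:
        return "local"
    return "builtin"


for handle in [
    "bert_base_en",
    "kaggle://user/bert/keras/bert_base_en",
    "hf://user/bert_base_en",
    "./bert_base_en",
]:
    print(handle, "->", classify_preset(handle))
```

The URL-style schemes are checked before the path test because Kaggle and Hugging Face handles also contain "/" characters.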
This constructor can be called in one of two ways: either from a task-specific base class like keras_hub.models.CausalLM.from_preset(), or from a model class like keras_hub.models.BertTextClassifier.from_preset(). If calling from a base class, the subclass of the returned object will be inferred from the config in the preset directory.
Arguments
- preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory.
- load_weights: bool. If True, saved weights will be loaded into the model architecture. If False, all weights will be randomly initialized.
Examples
```python
import keras_hub

# Load a Gemma generative task.
causal_lm = keras_hub.models.CausalLM.from_preset(
    "gemma_2b_en",
)
# Load a Bert classification task.
model = keras_hub.models.TextClassifier.from_preset(
    "bert_base_en",
    num_classes=2,
)
```
| Preset name | Parameters | Description |
|---|---|---|
| sam_base_sa1b | 93.74M | The base SAM model trained on the SA1B dataset. |
| sam_large_sa1b | 312.34M | The large SAM model trained on the SA1B dataset. |
| sam_huge_sa1b | 641.09M | The huge SAM model trained on the SA1B dataset. |
backbone property
keras_hub.models.SAMImageSegmenter.backbone
A keras_hub.models.Backbone model with the core architecture.
preprocessor property
keras_hub.models.SAMImageSegmenter.preprocessor
A keras_hub.models.Preprocessor layer used to preprocess input.