Point cloud classification with PointNet
- Original Link : https://keras.io/examples/vision/pointnet/
- Last Checked at : 2024-11-20
Author: David Griffiths
Date created: 2020/05/25
Last modified: 2024/01/09
Description: Implementation of PointNet for ModelNet10 classification.
Introduction
Classification, detection and segmentation of unordered 3D point sets, i.e. point clouds, is a core problem in computer vision. This example implements the seminal point cloud deep learning paper PointNet (Qi et al., 2017). For a detailed introduction to PointNet see this blog post.
Setup
If running in Colab, first install trimesh with !pip install trimesh.
import os
import glob
import trimesh
import numpy as np
from tensorflow import data as tf_data
from keras import ops
import keras
from keras import layers
from matplotlib import pyplot as plt
keras.utils.set_random_seed(seed=42)
Load dataset
We use the ModelNet10 dataset, the smaller 10-class version of the ModelNet40 dataset. First download the data:
DATA_DIR = keras.utils.get_file(
"modelnet.zip",
"http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip",
extract=True,
)
DATA_DIR = os.path.join(os.path.dirname(DATA_DIR), "ModelNet10")
Result
Downloading data from http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip
473402300/473402300 ━━━━━━━━━━━━━━━━━━━━ 12s 0us/step
We can use the trimesh package to read and visualize the .off mesh files.
mesh = trimesh.load(os.path.join(DATA_DIR, "chair/train/chair_0001.off"))
mesh.show()
To convert a mesh file to a point cloud we first need to sample points on the mesh surface. .sample() performs a uniform random sampling. Here we sample at 2048 locations and visualize in matplotlib.
points = mesh.sample(2048)
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection="3d")
ax.scatter(points[:, 0], points[:, 1], points[:, 2])
ax.set_axis_off()
plt.show()
To generate a tf.data.Dataset() we first need to parse the ModelNet data folders. Each mesh is loaded and sampled into a point cloud before being added to a standard Python list and converted to a numpy array. We also store the current enumerate index value as the object label and use a dictionary to recall this later.
def parse_dataset(num_points=2048):
train_points = []
train_labels = []
test_points = []
test_labels = []
class_map = {}
folders = glob.glob(os.path.join(DATA_DIR, "[!README]*"))
for i, folder in enumerate(folders):
print("processing class: {}".format(os.path.basename(folder)))
# store folder name with ID so we can retrieve later
class_map[i] = folder.split("/")[-1]
# gather all files
train_files = glob.glob(os.path.join(folder, "train/*"))
test_files = glob.glob(os.path.join(folder, "test/*"))
for f in train_files:
train_points.append(trimesh.load(f).sample(num_points))
train_labels.append(i)
for f in test_files:
test_points.append(trimesh.load(f).sample(num_points))
test_labels.append(i)
return (
np.array(train_points),
np.array(test_points),
np.array(train_labels),
np.array(test_labels),
class_map,
)
Set the number of points to sample and the batch size, then parse the dataset. This can take ~5 minutes to complete.
NUM_POINTS = 2048
NUM_CLASSES = 10
BATCH_SIZE = 32
train_points, test_points, train_labels, test_labels, CLASS_MAP = parse_dataset(
NUM_POINTS
)
Result
processing class: bathtub
processing class: monitor
processing class: desk
processing class: dresser
processing class: toilet
processing class: bed
processing class: sofa
processing class: chair
processing class: night_stand
processing class: table
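Since parsing and sampling every mesh takes a few minutes, one optional convenience (not part of the original example) is to cache the parsed arrays with numpy so that later runs can load them instead of calling parse_dataset again. The modelnet10.npz filename is just an illustrative choice.
# Optional caching sketch (illustrative, not in the original example).
np.savez(
    "modelnet10.npz",
    train_points=train_points,
    test_points=test_points,
    train_labels=train_labels,
    test_labels=test_labels,
)
# On a later run the arrays could be restored with:
# cached = np.load("modelnet10.npz")
# train_points, test_points = cached["train_points"], cached["test_points"]
# train_labels, test_labels = cached["train_labels"], cached["test_labels"]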
Our data can now be read into a tf.data.Dataset() object. We set the shuffle buffer size to the entire size of the dataset, as prior to this the data is ordered by class. Data augmentation is important when working with point cloud data. We create an augmentation function to jitter and shuffle the train dataset; because the function is mapped before batching, each sample's points are jittered and shuffled independently.
def augment(points, label):
# jitter points
points += keras.random.uniform(points.shape, -0.005, 0.005, dtype="float64")
# shuffle points
points = keras.random.shuffle(points)
return points, label
train_size = 0.8
dataset = tf_data.Dataset.from_tensor_slices((train_points, train_labels))
test_dataset = tf_data.Dataset.from_tensor_slices((test_points, test_labels))
train_dataset_size = int(len(dataset) * train_size)
dataset = dataset.shuffle(len(train_points)).map(augment)
test_dataset = test_dataset.shuffle(len(test_points)).batch(BATCH_SIZE)
train_dataset = dataset.take(train_dataset_size).batch(BATCH_SIZE)
validation_dataset = dataset.skip(train_dataset_size).batch(BATCH_SIZE)
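Before building the model, it can help to peek at a single batch and confirm the pipeline yields the shapes we expect. This quick check is not part of the original example.
# Sanity check (illustrative): one training batch of point clouds and labels.
example_points, example_labels = next(iter(train_dataset))
print(example_points.shape)  # (32, 2048, 3) -> (batch, points, xyz)
print(example_labels.shape)  # (32,)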
Build a model
Each convolution and fully-connected layer (with the exception of the end layers) consists of Convolution / Dense -> Batch Normalization -> ReLU Activation.
def conv_bn(x, filters):
x = layers.Conv1D(filters, kernel_size=1, padding="valid")(x)
x = layers.BatchNormalization(momentum=0.0)(x)
return layers.Activation("relu")(x)
def dense_bn(x, filters):
x = layers.Dense(filters)(x)
x = layers.BatchNormalization(momentum=0.0)(x)
return layers.Activation("relu")(x)
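Note that a Conv1D with kernel_size=1 acts as a shared per-point MLP: it mixes channels but never mixes points, which is what keeps the network agnostic to point ordering. A quick illustrative shape check (not part of the original example):
# Illustrative only: a kernel_size=1 Conv1D changes the channel dimension but
# leaves the point axis untouched.
probe = keras.Input(shape=(NUM_POINTS, 3))
print(conv_bn(probe, 32).shape)  # (None, 2048, 32)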
PointNet consists of two core components: the primary MLP network and the transformer net (T-net). The T-net aims to learn an affine transformation matrix with its own mini network. It is used twice: first to transform the input points (n, 3) into a canonical representation, and a second time to apply an affine transformation for alignment in feature space. As per the original paper we constrain the transformation to be close to an orthogonal matrix (i.e. ||X*X^T - I|| = 0).
class OrthogonalRegularizer(keras.regularizers.Regularizer):
def __init__(self, num_features, l2reg=0.001):
self.num_features = num_features
self.l2reg = l2reg
self.eye = ops.eye(num_features)
def __call__(self, x):
x = ops.reshape(x, (-1, self.num_features, self.num_features))
xxt = ops.tensordot(x, x, axes=(2, 2))
xxt = ops.reshape(xxt, (-1, self.num_features, self.num_features))
return ops.sum(self.l2reg * ops.square(xxt - self.eye))
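As a quick sanity check (not part of the original example), the penalty should be close to zero when the predicted transform is orthogonal, such as the identity matrix the T-net bias is initialised to below, and positive otherwise.
# Illustrative check: an identity transform incurs ~0 penalty, a random one does not.
reg_check = OrthogonalRegularizer(num_features=3)
identity_flat = ops.reshape(ops.eye(3), (1, 9))
random_flat = keras.random.uniform((1, 9))
print(float(reg_check(identity_flat)))  # ~0.0
print(float(reg_check(random_flat)))    # > 0.0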
We can then define a general function to build T-net layers.
def tnet(inputs, num_features):
# Initialise bias as the identity matrix
bias = keras.initializers.Constant(np.eye(num_features).flatten())
reg = OrthogonalRegularizer(num_features)
x = conv_bn(inputs, 32)
x = conv_bn(x, 64)
x = conv_bn(x, 512)
x = layers.GlobalMaxPooling1D()(x)
x = dense_bn(x, 256)
x = dense_bn(x, 128)
x = layers.Dense(
num_features * num_features,
kernel_initializer="zeros",
bias_initializer=bias,
activity_regularizer=reg,
)(x)
feat_T = layers.Reshape((num_features, num_features))(x)
# Apply affine transformation to input features
return layers.Dot(axes=(2, 1))([inputs, feat_T])
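The Dot layer with axes=(2, 1) is a batched matrix multiply: every point (or per-point feature vector) is multiplied by the learned (num_features, num_features) transform. The small check below, shown only for illustration and not part of the model, verifies the equivalence with ops.matmul.
# Illustrative equivalence check: Dot(axes=(2, 1)) on (batch, n, f) and
# (batch, f, f) tensors matches a batched matmul.
pts = keras.random.uniform((2, 4, 3))
transform = keras.random.uniform((2, 3, 3))
via_dot = layers.Dot(axes=(2, 1))([pts, transform])
via_matmul = ops.matmul(pts, transform)
print(ops.all(ops.isclose(via_dot, via_matmul)))  # True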
The main network can then be implemented in the same manner, with the T-net mini models dropped in as layers in the graph. Here we replicate the network architecture published in the original paper but with half the number of weights at each layer, as we are using the smaller 10-class ModelNet dataset.
inputs = keras.Input(shape=(NUM_POINTS, 3))
x = tnet(inputs, 3)
x = conv_bn(x, 32)
x = conv_bn(x, 32)
x = tnet(x, 32)
x = conv_bn(x, 32)
x = conv_bn(x, 64)
x = conv_bn(x, 512)
x = layers.GlobalMaxPooling1D()(x)
x = dense_bn(x, 256)
x = layers.Dropout(0.3)(x)
x = dense_bn(x, 128)
x = layers.Dropout(0.3)(x)
outputs = layers.Dense(NUM_CLASSES, activation="softmax")(x)
model = keras.Model(inputs=inputs, outputs=outputs, name="pointnet")
model.summary()
Result
Model: "pointnet"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer │ (None, 2048, 3) │ 0 │ - │
│ (InputLayer) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv1d (Conv1D) │ (None, 2048, 32) │ 128 │ input_layer[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalization │ (None, 2048, 32) │ 128 │ conv1d[0][0] │
│ (BatchNormalizatio… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation │ (None, 2048, 32) │ 0 │ batch_normalization… │
│ (Activation) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv1d_1 (Conv1D) │ (None, 2048, 64) │ 2,112 │ activation[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 2048, 64) │ 256 │ conv1d_1[0][0] │
│ (BatchNormalizatio… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_1 │ (None, 2048, 64) │ 0 │ batch_normalization… │
│ (Activation) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv1d_2 (Conv1D) │ (None, 2048, 512) │ 33,280 │ activation_1[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 2048, 512) │ 2,048 │ conv1d_2[0][0] │
│ (BatchNormalizatio… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_2 │ (None, 2048, 512) │ 0 │ batch_normalization… │
│ (Activation) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ global_max_pooling… │ (None, 512) │ 0 │ activation_2[0][0] │
│ (GlobalMaxPooling1… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense (Dense) │ (None, 256) │ 131,328 │ global_max_pooling1… │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 256) │ 1,024 │ dense[0][0] │
│ (BatchNormalizatio… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_3 │ (None, 256) │ 0 │ batch_normalization… │
│ (Activation) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense_1 (Dense) │ (None, 128) │ 32,896 │ activation_3[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 128) │ 512 │ dense_1[0][0] │
│ (BatchNormalizatio… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_4 │ (None, 128) │ 0 │ batch_normalization… │
│ (Activation) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense_2 (Dense) │ (None, 9) │ 1,161 │ activation_4[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ reshape (Reshape) │ (None, 3, 3) │ 0 │ dense_2[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dot (Dot) │ (None, 2048, 3) │ 0 │ input_layer[0][0], │
│ │ │ │ reshape[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv1d_3 (Conv1D) │ (None, 2048, 32) │ 128 │ dot[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 2048, 32) │ 128 │ conv1d_3[0][0] │
│ (BatchNormalizatio… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_5 │ (None, 2048, 32) │ 0 │ batch_normalization… │
│ (Activation) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv1d_4 (Conv1D) │ (None, 2048, 32) │ 1,056 │ activation_5[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 2048, 32) │ 128 │ conv1d_4[0][0] │
│ (BatchNormalizatio… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_6 │ (None, 2048, 32) │ 0 │ batch_normalization… │
│ (Activation) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv1d_5 (Conv1D) │ (None, 2048, 32) │ 1,056 │ activation_6[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 2048, 32) │ 128 │ conv1d_5[0][0] │
│ (BatchNormalizatio… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_7 │ (None, 2048, 32) │ 0 │ batch_normalization… │
│ (Activation) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv1d_6 (Conv1D) │ (None, 2048, 64) │ 2,112 │ activation_7[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 2048, 64) │ 256 │ conv1d_6[0][0] │
│ (BatchNormalizatio… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_8 │ (None, 2048, 64) │ 0 │ batch_normalization… │
│ (Activation) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv1d_7 (Conv1D) │ (None, 2048, 512) │ 33,280 │ activation_8[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 2048, 512) │ 2,048 │ conv1d_7[0][0] │
│ (BatchNormalizatio… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_9 │ (None, 2048, 512) │ 0 │ batch_normalization… │
│ (Activation) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ global_max_pooling… │ (None, 512) │ 0 │ activation_9[0][0] │
│ (GlobalMaxPooling1… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense_3 (Dense) │ (None, 256) │ 131,328 │ global_max_pooling1… │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 256) │ 1,024 │ dense_3[0][0] │
│ (BatchNormalizatio… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_10 │ (None, 256) │ 0 │ batch_normalization… │
│ (Activation) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense_4 (Dense) │ (None, 128) │ 32,896 │ activation_10[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 128) │ 512 │ dense_4[0][0] │
│ (BatchNormalizatio… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_11 │ (None, 128) │ 0 │ batch_normalization… │
│ (Activation) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense_5 (Dense) │ (None, 1024) │ 132,096 │ activation_11[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ reshape_1 (Reshape) │ (None, 32, 32) │ 0 │ dense_5[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dot_1 (Dot) │ (None, 2048, 32) │ 0 │ activation_6[0][0], │
│ │ │ │ reshape_1[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv1d_8 (Conv1D) │ (None, 2048, 32) │ 1,056 │ dot_1[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 2048, 32) │ 128 │ conv1d_8[0][0] │
│ (BatchNormalizatio… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_12 │ (None, 2048, 32) │ 0 │ batch_normalization… │
│ (Activation) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv1d_9 (Conv1D) │ (None, 2048, 64) │ 2,112 │ activation_12[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 2048, 64) │ 256 │ conv1d_9[0][0] │
│ (BatchNormalizatio… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_13 │ (None, 2048, 64) │ 0 │ batch_normalization… │
│ (Activation) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ conv1d_10 (Conv1D) │ (None, 2048, 512) │ 33,280 │ activation_13[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 2048, 512) │ 2,048 │ conv1d_10[0][0] │
│ (BatchNormalizatio… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_14 │ (None, 2048, 512) │ 0 │ batch_normalization… │
│ (Activation) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ global_max_pooling… │ (None, 512) │ 0 │ activation_14[0][0] │
│ (GlobalMaxPooling1… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense_6 (Dense) │ (None, 256) │ 131,328 │ global_max_pooling1… │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 256) │ 1,024 │ dense_6[0][0] │
│ (BatchNormalizatio… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_15 │ (None, 256) │ 0 │ batch_normalization… │
│ (Activation) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dropout (Dropout) │ (None, 256) │ 0 │ activation_15[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense_7 (Dense) │ (None, 128) │ 32,896 │ dropout[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ batch_normalizatio… │ (None, 128) │ 512 │ dense_7[0][0] │
│ (BatchNormalizatio… │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ activation_16 │ (None, 128) │ 0 │ batch_normalization… │
│ (Activation) │ │ │ │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dropout_1 (Dropout) │ (None, 128) │ 0 │ activation_16[0][0] │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense_8 (Dense) │ (None, 10) │ 1,290 │ dropout_1[0][0] │
└─────────────────────┴───────────────────┴─────────┴──────────────────────┘
Total params: 748,979 (2.86 MB)
Trainable params: 742,899 (2.83 MB)
Non-trainable params: 6,080 (23.75 KB)
Train model
Once the model is defined it can be trained like any other standard classification model using .compile() and .fit().
model.compile(
loss="sparse_categorical_crossentropy",
optimizer=keras.optimizers.Adam(learning_rate=0.001),
metrics=["sparse_categorical_accuracy"],
)
model.fit(train_dataset, epochs=20, validation_data=validation_dataset)
Result
Epoch 1/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 16:59 10s/step - loss: 70.7465 - sparse_categorical_accuracy: 0.2188
2/100 ━━━━━━━━━━━━━━━━━━━━ 2:06 1s/step - loss: 69.8872 - sparse_categorical_accuracy: 0.1953
...
99/100 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step - loss: 46.0937 - sparse_categorical_accuracy: 0.2151
100/100 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - loss: 46.0345 - sparse_categorical_accuracy: 0.2154
100/100 ━━━━━━━━━━━━━━━━━━━━ 119s 1s/step - loss: 45.9764 - sparse_categorical_accuracy: 0.2156 - val_loss: 4122951.0000 - val_sparse_categorical_accuracy: 0.3154
Epoch 2/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 1:44 1s/step - loss: 36.7920 - sparse_categorical_accuracy: 0.2500
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 0s 1s/step - loss: 36.6402 - sparse_categorical_accuracy: 0.2749
100/100 ━━━━━━━━━━━━━━━━━━━━ 108s 1s/step - loss: 36.6386 - sparse_categorical_accuracy: 0.2751 - val_loss: 20961250112658389073920.0000 - val_sparse_categorical_accuracy: 0.3191
Epoch 3/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 57:33 35s/step - loss: 35.9745 - sparse_categorical_accuracy: 0.3438
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 142s 1s/step - loss: 36.4148 - sparse_categorical_accuracy: 0.3150 - val_loss: 14661139300352.0000 - val_sparse_categorical_accuracy: 0.2240
Epoch 4/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 1:40 1s/step - loss: 36.7380 - sparse_categorical_accuracy: 0.5312
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 110s 1s/step - loss: 36.7658 - sparse_categorical_accuracy: 0.3286 - val_loss: 2640681721921536.0000 - val_sparse_categorical_accuracy: 0.3542
Epoch 5/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 1:43 1s/step - loss: 36.6004 - sparse_categorical_accuracy: 0.2188
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 112s 1s/step - loss: 36.9928 - sparse_categorical_accuracy: 0.3100 - val_loss: 2087371157504536015273984.0000 - val_sparse_categorical_accuracy: 0.3004
Epoch 6/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 1:43 1s/step - loss: 37.1168 - sparse_categorical_accuracy: 0.1875
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 108s 1s/step - loss: 36.6758 - sparse_categorical_accuracy: 0.3182 - val_loss: 598952362161209344.0000 - val_sparse_categorical_accuracy: 0.4180
Epoch 7/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 1:46 1s/step - loss: 36.5799 - sparse_categorical_accuracy: 0.2188
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 107s 1s/step - loss: 37.8231 - sparse_categorical_accuracy: 0.3192 - val_loss: 1330149064704.0000 - val_sparse_categorical_accuracy: 0.3367
Epoch 8/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 1:42 1s/step - loss: 36.6512 - sparse_categorical_accuracy: 0.2500
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 107s 1s/step - loss: 36.4611 - sparse_categorical_accuracy: 0.3198 - val_loss: 55461990629376.0000 - val_sparse_categorical_accuracy: 0.3805
Epoch 9/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 1:48 1s/step - loss: 36.1902 - sparse_categorical_accuracy: 0.4062
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 107s 1s/step - loss: 36.3207 - sparse_categorical_accuracy: 0.3371 - val_loss: 79361986265088.0000 - val_sparse_categorical_accuracy: 0.3680
Epoch 10/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 58:50 36s/step - loss: 36.7173 - sparse_categorical_accuracy: 0.4062
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 142s 1s/step - loss: 36.0947 - sparse_categorical_accuracy: 0.3475 - val_loss: 14927241216.0000 - val_sparse_categorical_accuracy: 0.3054
Epoch 11/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 58:42 36s/step - loss: 36.1768 - sparse_categorical_accuracy: 0.3438
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 141s 1s/step - loss: 38.5024 - sparse_categorical_accuracy: 0.3187 - val_loss: 1930753792.0000 - val_sparse_categorical_accuracy: 0.2315
Epoch 12/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 1:00:07 36s/step - loss: 42.1152 - sparse_categorical_accuracy: 0.3750
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 142s 1s/step - loss: 37.4206 - sparse_categorical_accuracy: 0.3256 - val_loss: 1793616557963500563988480.0000 - val_sparse_categorical_accuracy: 0.2328
Epoch 13/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 59:52 36s/step - loss: 43.0665 - sparse_categorical_accuracy: 0.1875
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 142s 1s/step - loss: 56.5366 - sparse_categorical_accuracy: 0.3226 - val_loss: 505209651200.0000 - val_sparse_categorical_accuracy: 0.2528
Epoch 14/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 1:46 1s/step - loss: 72.5004 - sparse_categorical_accuracy: 0.2812
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 107s 1s/step - loss: 407.6649 - sparse_categorical_accuracy: 0.3007 - val_loss: 35970580884750336.0000 - val_sparse_categorical_accuracy: 0.3392
Epoch 15/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 1:41 1s/step - loss: 67.1360 - sparse_categorical_accuracy: 0.1875
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 106s 1s/step - loss: 77.4712 - sparse_categorical_accuracy: 0.3377 - val_loss: 2983669504.0000 - val_sparse_categorical_accuracy: 0.2966
Epoch 16/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 1:38 992ms/step - loss: 59.8730 - sparse_categorical_accuracy: 0.2188
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 107s 1s/step - loss: 67.4222 - sparse_categorical_accuracy: 0.3189 - val_loss: 37.0687 - val_sparse_categorical_accuracy: 0.1477
Epoch 17/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 58:50 36s/step - loss: 54.1712 - sparse_categorical_accuracy: 0.5312
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 141s 1s/step - loss: 58.2869 - sparse_categorical_accuracy: 0.3282 - val_loss: 4191578574815232.0000 - val_sparse_categorical_accuracy: 0.3129
Epoch 18/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 1:39 1s/step - loss: 51.9365 - sparse_categorical_accuracy: 0.4375
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 106s 1s/step - loss: 55.1443 - sparse_categorical_accuracy: 0.3351 - val_loss: 50221851662203486208.0000 - val_sparse_categorical_accuracy: 0.3242
Epoch 19/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 1:41 1s/step - loss: 48.0290 - sparse_categorical_accuracy: 0.2188
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 108s 1s/step - loss: 49.9001 - sparse_categorical_accuracy: 0.3317 - val_loss: 69256328.0000 - val_sparse_categorical_accuracy: 0.3579
Epoch 20/20
1/100 ━━━━━━━━━━━━━━━━━━━━ 1:42 1s/step - loss: 45.8100 - sparse_categorical_accuracy: 0.4062
...
100/100 ━━━━━━━━━━━━━━━━━━━━ 106s 1s/step - loss: 47.4913 - sparse_categorical_accuracy: 0.3404 - val_loss: 1814011445248.0000 - val_sparse_categorical_accuracy: 0.3592
<keras.src.callbacks.history.History at 0x7f596cb7b8e0>
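Before visualising individual predictions, you could also report aggregate metrics over the whole test split with model.evaluate. This extra step is not part of the original run.
# Optional: aggregate loss/accuracy on the held-out test set (illustrative).
test_loss, test_acc = model.evaluate(test_dataset)
print(f"Test sparse categorical accuracy: {test_acc:.3f}")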
Visualize predictions
We can use matplotlib to visualize the performance of our trained model.
data = test_dataset.take(1)
points, labels = list(data)[0]
points = points[:8, ...]
labels = labels[:8, ...]
# run test data through model
preds = model.predict(points)
preds = ops.argmax(preds, -1)
points = points.numpy()
# plot points with predicted class and label
fig = plt.figure(figsize=(15, 10))
for i in range(8):
ax = fig.add_subplot(2, 4, i + 1, projection="3d")
ax.scatter(points[i, :, 0], points[i, :, 1], points[i, :, 2])
ax.set_title(
"pred: {:}, label: {:}".format(
CLASS_MAP[preds[i].numpy()], CLASS_MAP[labels.numpy()[i]]
)
)
ax.set_axis_off()
plt.show()
Result
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 405ms/step