Imbalanced classification: credit card fraud detection
- Original Link : https://keras.io/examples/structured_data/imbalanced_classification/
- Last Checked at : 2024-11-22
Author: fchollet
Date created: 2019/05/28
Last modified: 2020/04/17
Description: Demonstration of how to handle highly imbalanced classification problems.
Introduction
This example looks at the Kaggle Credit Card Fraud Detection dataset to demonstrate how to train a classification model on data with highly imbalanced classes.
First, vectorize the CSV data
import csv
import numpy as np
# Get the real data from https://www.kaggle.com/mlg-ulb/creditcardfraud/
fname = "/Users/fchollet/Downloads/creditcard.csv"
all_features = []
all_targets = []
with open(fname) as f:
for i, line in enumerate(f):
if i == 0:
print("HEADER:", line.strip())
continue # Skip header
fields = line.strip().split(",")
all_features.append([float(v.replace('"', "")) for v in fields[:-1]])
all_targets.append([int(fields[-1].replace('"', ""))])
if i == 1:
print("EXAMPLE FEATURES:", all_features[-1])
features = np.array(all_features, dtype="float32")
targets = np.array(all_targets, dtype="uint8")
print("features.shape:", features.shape)
print("targets.shape:", targets.shape)
Result
HEADER: "Time","V1","V2","V3","V4","V5","V6","V7","V8","V9","V10","V11","V12","V13","V14","V15","V16","V17","V18","V19","V20","V21","V22","V23","V24","V25","V26","V27","V28","Amount","Class"
EXAMPLE FEATURES: [0.0, -1.3598071336738, -0.0727811733098497, 2.53634673796914, 1.37815522427443, -0.338320769942518, 0.462387777762292, 0.239598554061257, 0.0986979012610507, 0.363786969611213, 0.0907941719789316, -0.551599533260813, -0.617800855762348, -0.991389847235408, -0.311169353699879, 1.46817697209427, -0.470400525259478, 0.207971241929242, 0.0257905801985591, 0.403992960255733, 0.251412098239705, -0.018306777944153, 0.277837575558899, -0.110473910188767, 0.0669280749146731, 0.128539358273528, -0.189114843888824, 0.133558376740387, -0.0210530534538215, 149.62]
features.shape: (284807, 30)
targets.shape: (284807, 1)
Prepare a validation set
num_val_samples = int(len(features) * 0.2)
train_features = features[:-num_val_samples]
train_targets = targets[:-num_val_samples]
val_features = features[-num_val_samples:]
val_targets = targets[-num_val_samples:]
print("Number of training samples:", len(train_features))
print("Number of validation samples:", len(val_features))
Result
Number of training samples: 227846
Number of validation samples: 56961
Analyze class imbalance in the targets
counts = np.bincount(train_targets[:, 0])
print(
"Number of positive samples in training data: {} ({:.2f}% of total)".format(
counts[1], 100 * float(counts[1]) / len(train_targets)
)
)
weight_for_0 = 1.0 / counts[0]
weight_for_1 = 1.0 / counts[1]
Result
Number of positive samples in training data: 417 (0.18% of total)
Normalize the data using training set statistics
mean = np.mean(train_features, axis=0)
train_features -= mean
val_features -= mean
std = np.std(train_features, axis=0)
train_features /= std
val_features /= std
Build a binary classification model
import keras
model = keras.Sequential(
[
keras.Input(shape=train_features.shape[1:]),
keras.layers.Dense(256, activation="relu"),
keras.layers.Dense(256, activation="relu"),
keras.layers.Dropout(0.3),
keras.layers.Dense(256, activation="relu"),
keras.layers.Dropout(0.3),
keras.layers.Dense(1, activation="sigmoid"),
]
)
model.summary()
Result
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Layer (type) ┃ Output Shape ┃ Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ dense (Dense) │ (None, 256) │ 7,936 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense_1 (Dense) │ (None, 256) │ 65,792 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout (Dropout) │ (None, 256) │ 0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense_2 (Dense) │ (None, 256) │ 65,792 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout_1 (Dropout) │ (None, 256) │ 0 │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense_3 (Dense) │ (None, 1) │ 257 │
└─────────────────────────────────┴───────────────────────────┴────────────┘
Total params: 139,777 (546.00 KB)
Trainable params: 139,777 (546.00 KB)
Non-trainable params: 0 (0.00 B)
Train the model with class_weight
argument
metrics = [
keras.metrics.FalseNegatives(name="fn"),
keras.metrics.FalsePositives(name="fp"),
keras.metrics.TrueNegatives(name="tn"),
keras.metrics.TruePositives(name="tp"),
keras.metrics.Precision(name="precision"),
keras.metrics.Recall(name="recall"),
]
model.compile(
optimizer=keras.optimizers.Adam(1e-2), loss="binary_crossentropy", metrics=metrics
)
callbacks = [keras.callbacks.ModelCheckpoint("fraud_model_at_epoch_{epoch}.keras")]
class_weight = {0: weight_for_0, 1: weight_for_1}
model.fit(
train_features,
train_targets,
batch_size=2048,
epochs=30,
verbose=2,
callbacks=callbacks,
validation_data=(val_features, val_targets),
class_weight=class_weight,
)
Result
Epoch 1/30
112/112 - 3s - 24ms/step - fn: 39.0000 - fp: 25593.0000 - loss: 2.2586e-06 - precision: 0.0146 - recall: 0.9065 - tn: 201836.0000 - tp: 378.0000 - val_fn: 5.0000 - val_fp: 3430.0000 - val_loss: 0.1872 - val_precision: 0.0200 - val_recall: 0.9333 - val_tn: 53456.0000 - val_tp: 70.0000
Epoch 2/30
112/112 - 0s - 991us/step - fn: 32.0000 - fp: 7936.0000 - loss: 1.5505e-06 - precision: 0.0463 - recall: 0.9233 - tn: 219493.0000 - tp: 385.0000 - val_fn: 7.0000 - val_fp: 2351.0000 - val_loss: 0.1930 - val_precision: 0.0281 - val_recall: 0.9067 - val_tn: 54535.0000 - val_tp: 68.0000
Epoch 3/30
112/112 - 0s - 1ms/step - fn: 31.0000 - fp: 6716.0000 - loss: 1.2987e-06 - precision: 0.0544 - recall: 0.9257 - tn: 220713.0000 - tp: 386.0000 - val_fn: 4.0000 - val_fp: 3374.0000 - val_loss: 0.1781 - val_precision: 0.0206 - val_recall: 0.9467 - val_tn: 53512.0000 - val_tp: 71.0000
Epoch 4/30
112/112 - 0s - 1ms/step - fn: 25.0000 - fp: 7348.0000 - loss: 1.1292e-06 - precision: 0.0506 - recall: 0.9400 - tn: 220081.0000 - tp: 392.0000 - val_fn: 6.0000 - val_fp: 1405.0000 - val_loss: 0.0796 - val_precision: 0.0468 - val_recall: 0.9200 - val_tn: 55481.0000 - val_tp: 69.0000
Epoch 5/30
112/112 - 0s - 926us/step - fn: 19.0000 - fp: 6720.0000 - loss: 8.0334e-07 - precision: 0.0559 - recall: 0.9544 - tn: 220709.0000 - tp: 398.0000 - val_fn: 11.0000 - val_fp: 315.0000 - val_loss: 0.0212 - val_precision: 0.1689 - val_recall: 0.8533 - val_tn: 56571.0000 - val_tp: 64.0000
Epoch 6/30
112/112 - 0s - 1ms/step - fn: 19.0000 - fp: 6706.0000 - loss: 8.6899e-07 - precision: 0.0560 - recall: 0.9544 - tn: 220723.0000 - tp: 398.0000 - val_fn: 8.0000 - val_fp: 1262.0000 - val_loss: 0.0801 - val_precision: 0.0504 - val_recall: 0.8933 - val_tn: 55624.0000 - val_tp: 67.0000
Epoch 7/30
112/112 - 0s - 1ms/step - fn: 15.0000 - fp: 5161.0000 - loss: 6.5298e-07 - precision: 0.0723 - recall: 0.9640 - tn: 222268.0000 - tp: 402.0000 - val_fn: 7.0000 - val_fp: 1157.0000 - val_loss: 0.0623 - val_precision: 0.0555 - val_recall: 0.9067 - val_tn: 55729.0000 - val_tp: 68.0000
Epoch 8/30
112/112 - 0s - 1ms/step - fn: 11.0000 - fp: 6381.0000 - loss: 6.7164e-07 - precision: 0.0598 - recall: 0.9736 - tn: 221048.0000 - tp: 406.0000 - val_fn: 10.0000 - val_fp: 346.0000 - val_loss: 0.0270 - val_precision: 0.1582 - val_recall: 0.8667 - val_tn: 56540.0000 - val_tp: 65.0000
Epoch 9/30
112/112 - 0s - 1ms/step - fn: 16.0000 - fp: 7259.0000 - loss: 8.9098e-07 - precision: 0.0523 - recall: 0.9616 - tn: 220170.0000 - tp: 401.0000 - val_fn: 7.0000 - val_fp: 1998.0000 - val_loss: 0.1073 - val_precision: 0.0329 - val_recall: 0.9067 - val_tn: 54888.0000 - val_tp: 68.0000
Epoch 10/30
112/112 - 0s - 999us/step - fn: 19.0000 - fp: 7792.0000 - loss: 9.2179e-07 - precision: 0.0486 - recall: 0.9544 - tn: 219637.0000 - tp: 398.0000 - val_fn: 7.0000 - val_fp: 1515.0000 - val_loss: 0.0800 - val_precision: 0.0430 - val_recall: 0.9067 - val_tn: 55371.0000 - val_tp: 68.0000
Epoch 11/30
112/112 - 0s - 1ms/step - fn: 13.0000 - fp: 5828.0000 - loss: 6.4193e-07 - precision: 0.0648 - recall: 0.9688 - tn: 221601.0000 - tp: 404.0000 - val_fn: 9.0000 - val_fp: 794.0000 - val_loss: 0.0410 - val_precision: 0.0767 - val_recall: 0.8800 - val_tn: 56092.0000 - val_tp: 66.0000
Epoch 12/30
112/112 - 0s - 959us/step - fn: 10.0000 - fp: 6400.0000 - loss: 7.4358e-07 - precision: 0.0598 - recall: 0.9760 - tn: 221029.0000 - tp: 407.0000 - val_fn: 8.0000 - val_fp: 593.0000 - val_loss: 0.0466 - val_precision: 0.1015 - val_recall: 0.8933 - val_tn: 56293.0000 - val_tp: 67.0000
Epoch 13/30
112/112 - 0s - 913us/step - fn: 9.0000 - fp: 5756.0000 - loss: 6.8158e-07 - precision: 0.0662 - recall: 0.9784 - tn: 221673.0000 - tp: 408.0000 - val_fn: 11.0000 - val_fp: 280.0000 - val_loss: 0.0336 - val_precision: 0.1860 - val_recall: 0.8533 - val_tn: 56606.0000 - val_tp: 64.0000
Epoch 14/30
112/112 - 0s - 960us/step - fn: 13.0000 - fp: 6699.0000 - loss: 1.0667e-06 - precision: 0.0569 - recall: 0.9688 - tn: 220730.0000 - tp: 404.0000 - val_fn: 9.0000 - val_fp: 1165.0000 - val_loss: 0.0885 - val_precision: 0.0536 - val_recall: 0.8800 - val_tn: 55721.0000 - val_tp: 66.0000
Epoch 15/30
112/112 - 0s - 1ms/step - fn: 15.0000 - fp: 6705.0000 - loss: 6.8100e-07 - precision: 0.0566 - recall: 0.9640 - tn: 220724.0000 - tp: 402.0000 - val_fn: 10.0000 - val_fp: 750.0000 - val_loss: 0.0367 - val_precision: 0.0798 - val_recall: 0.8667 - val_tn: 56136.0000 - val_tp: 65.0000
Epoch 16/30
112/112 - 0s - 1ms/step - fn: 8.0000 - fp: 4288.0000 - loss: 4.1541e-07 - precision: 0.0871 - recall: 0.9808 - tn: 223141.0000 - tp: 409.0000 - val_fn: 11.0000 - val_fp: 351.0000 - val_loss: 0.0199 - val_precision: 0.1542 - val_recall: 0.8533 - val_tn: 56535.0000 - val_tp: 64.0000
Epoch 17/30
112/112 - 0s - 949us/step - fn: 8.0000 - fp: 4598.0000 - loss: 4.3510e-07 - precision: 0.0817 - recall: 0.9808 - tn: 222831.0000 - tp: 409.0000 - val_fn: 10.0000 - val_fp: 688.0000 - val_loss: 0.0296 - val_precision: 0.0863 - val_recall: 0.8667 - val_tn: 56198.0000 - val_tp: 65.0000
Epoch 18/30
112/112 - 0s - 946us/step - fn: 7.0000 - fp: 5544.0000 - loss: 4.6239e-07 - precision: 0.0689 - recall: 0.9832 - tn: 221885.0000 - tp: 410.0000 - val_fn: 8.0000 - val_fp: 444.0000 - val_loss: 0.0260 - val_precision: 0.1311 - val_recall: 0.8933 - val_tn: 56442.0000 - val_tp: 67.0000
Epoch 19/30
112/112 - 0s - 972us/step - fn: 3.0000 - fp: 2920.0000 - loss: 2.7543e-07 - precision: 0.1242 - recall: 0.9928 - tn: 224509.0000 - tp: 414.0000 - val_fn: 9.0000 - val_fp: 510.0000 - val_loss: 0.0245 - val_precision: 0.1146 - val_recall: 0.8800 - val_tn: 56376.0000 - val_tp: 66.0000
Epoch 20/30
112/112 - 0s - 1ms/step - fn: 6.0000 - fp: 5351.0000 - loss: 5.7495e-07 - precision: 0.0713 - recall: 0.9856 - tn: 222078.0000 - tp: 411.0000 - val_fn: 9.0000 - val_fp: 547.0000 - val_loss: 0.0255 - val_precision: 0.1077 - val_recall: 0.8800 - val_tn: 56339.0000 - val_tp: 66.0000
Epoch 21/30
112/112 - 0s - 1ms/step - fn: 6.0000 - fp: 3808.0000 - loss: 5.1475e-07 - precision: 0.0974 - recall: 0.9856 - tn: 223621.0000 - tp: 411.0000 - val_fn: 10.0000 - val_fp: 624.0000 - val_loss: 0.0320 - val_precision: 0.0943 - val_recall: 0.8667 - val_tn: 56262.0000 - val_tp: 65.0000
Epoch 22/30
112/112 - 0s - 1ms/step - fn: 6.0000 - fp: 5117.0000 - loss: 5.5465e-07 - precision: 0.0743 - recall: 0.9856 - tn: 222312.0000 - tp: 411.0000 - val_fn: 10.0000 - val_fp: 836.0000 - val_loss: 0.0556 - val_precision: 0.0721 - val_recall: 0.8667 - val_tn: 56050.0000 - val_tp: 65.0000
Epoch 23/30
112/112 - 0s - 939us/step - fn: 8.0000 - fp: 5583.0000 - loss: 5.5407e-07 - precision: 0.0683 - recall: 0.9808 - tn: 221846.0000 - tp: 409.0000 - val_fn: 12.0000 - val_fp: 501.0000 - val_loss: 0.0300 - val_precision: 0.1117 - val_recall: 0.8400 - val_tn: 56385.0000 - val_tp: 63.0000
Epoch 24/30
112/112 - 0s - 958us/step - fn: 5.0000 - fp: 3933.0000 - loss: 4.7133e-07 - precision: 0.0948 - recall: 0.9880 - tn: 223496.0000 - tp: 412.0000 - val_fn: 12.0000 - val_fp: 211.0000 - val_loss: 0.0326 - val_precision: 0.2299 - val_recall: 0.8400 - val_tn: 56675.0000 - val_tp: 63.0000
Epoch 25/30
112/112 - 0s - 1ms/step - fn: 7.0000 - fp: 5695.0000 - loss: 7.1277e-07 - precision: 0.0672 - recall: 0.9832 - tn: 221734.0000 - tp: 410.0000 - val_fn: 9.0000 - val_fp: 802.0000 - val_loss: 0.0598 - val_precision: 0.0760 - val_recall: 0.8800 - val_tn: 56084.0000 - val_tp: 66.0000
Epoch 26/30
112/112 - 0s - 949us/step - fn: 5.0000 - fp: 3853.0000 - loss: 4.1797e-07 - precision: 0.0966 - recall: 0.9880 - tn: 223576.0000 - tp: 412.0000 - val_fn: 8.0000 - val_fp: 771.0000 - val_loss: 0.0409 - val_precision: 0.0800 - val_recall: 0.8933 - val_tn: 56115.0000 - val_tp: 67.0000
Epoch 27/30
112/112 - 0s - 947us/step - fn: 4.0000 - fp: 3873.0000 - loss: 3.7369e-07 - precision: 0.0964 - recall: 0.9904 - tn: 223556.0000 - tp: 413.0000 - val_fn: 6.0000 - val_fp: 2208.0000 - val_loss: 0.1370 - val_precision: 0.0303 - val_recall: 0.9200 - val_tn: 54678.0000 - val_tp: 69.0000
Epoch 28/30
112/112 - 0s - 892us/step - fn: 5.0000 - fp: 4619.0000 - loss: 4.1290e-07 - precision: 0.0819 - recall: 0.9880 - tn: 222810.0000 - tp: 412.0000 - val_fn: 8.0000 - val_fp: 551.0000 - val_loss: 0.0273 - val_precision: 0.1084 - val_recall: 0.8933 - val_tn: 56335.0000 - val_tp: 67.0000
Epoch 29/30
112/112 - 0s - 931us/step - fn: 1.0000 - fp: 3336.0000 - loss: 2.5478e-07 - precision: 0.1109 - recall: 0.9976 - tn: 224093.0000 - tp: 416.0000 - val_fn: 9.0000 - val_fp: 487.0000 - val_loss: 0.0238 - val_precision: 0.1193 - val_recall: 0.8800 - val_tn: 56399.0000 - val_tp: 66.0000
Epoch 30/30
112/112 - 0s - 1ms/step - fn: 2.0000 - fp: 3521.0000 - loss: 4.1991e-07 - precision: 0.1054 - recall: 0.9952 - tn: 223908.0000 - tp: 415.0000 - val_fn: 10.0000 - val_fp: 462.0000 - val_loss: 0.0331 - val_precision: 0.1233 - val_recall: 0.8667 - val_tn: 56424.0000 - val_tp: 65.0000
<keras.src.callbacks.history.History at 0x7f22b41f3430>
Conclusions
At the end of training, out of 56,961 validation transactions, we are:
- Correctly identifying 66 of them as fraudulent
- Missing 9 fraudulent transactions
- At the cost of incorrectly flagging 441 legitimate transactions
In the real world, one would put an even higher weight on class 1, so as to reflect that False Negatives are more costly than False Positives.
Next time your credit card gets declined in an online purchase – this is why.
Example available on HuggingFace.