Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

Dit is een eenvoudige implementatie van een convolutioneel neuraal netwerk voor de Fashion-MNIST-dataset van Zalando. De data bestaan uit 70k grijswaardebeelden met een resolutie van 28 bij 28 pixels, opgesplitst in 60k trainings- en 10k testbeelden. Ze behoren tot de volgende 10 categorieën (in de data gecodeerd als labels 0 t/m 9, in deze volgorde):

  1. T-shirt/top

  2. Trouser

  3. Pullover

  4. Dress

  5. Coat

  6. Sandal

  7. Shirt

  8. Sneaker

  9. Bag

  10. Ankle boot

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
from plotly.subplots import make_subplots
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_fscore_support
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToPILImage, ToTensor

# Human-readable names for the ten Fashion-MNIST categories; position i in this
# list is the name of integer label i.
classes = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
Using device: cpu
# Download (on first run) and load both Fashion-MNIST splits into ./data;
# ToTensor converts every image into a (1, 28, 28) tensor.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)
print(f"Training set size: {len(training_data)}")
print(f"Test set size: {len(test_data)}")

# Inspect the very first training example: its tensor shape and integer label.
first_example = training_data[0]
sample_img, sample_label = first_example
print(f"Sample image shape: {sample_img.shape}")
print(f"Label: {sample_label}")

# Last expression of the cell: the notebook displays the resulting PIL image.
ToPILImage()(sample_img)
Training set size: 60000
Test set size: 10000
Sample image shape: torch.Size([1, 28, 28])
Label: 9
<PIL.Image.Image image mode=L size=28x28>
batch_size = 64

# Shuffle the training set so every epoch presents the mini-batches in a fresh
# order (without this, SGD sees the identical batch sequence each epoch). The test
# set is only evaluated, so its order does not matter.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Peek at a single test batch to confirm tensor shapes and the label dtype.
X, y = next(iter(test_dataloader))
print(f"Shape of X [Batch Size, Channels, Height, Width]:{X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")
print(f"Unique values of y: {torch.unique(y)}")
Shape of X [Batch Size, Channels, Height ,Width]:torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
Unique values of y: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
# Analyze class distribution.
# FashionMNIST keeps every label in its `targets` tensor, so read the labels from
# there. Iterating the dataset itself would load and ToTensor-transform all 70k
# images just to throw the pixels away.
train_labels = training_data.targets.numpy()
test_labels = test_data.targets.numpy()

# Count labels. minlength guarantees exactly one bin per class, even if a class were
# absent, so the counts always line up with `classes` on the x-axis.
train_counts = np.bincount(train_labels, minlength=len(classes))
test_counts = np.bincount(test_labels, minlength=len(classes))

# Create grouped bar chart: one bar series per split.
fig = go.Figure()
for split_name, counts, color in (
    ("Training Set", train_counts, "lightblue"),
    ("Test Set", test_counts, "lightcoral"),
):
    fig.add_trace(
        go.Bar(
            name=split_name,
            x=classes,
            y=counts,
            text=counts,
            textposition="auto",
            marker_color=color,
        )
    )

fig.update_layout(
    title="Class Distribution in Training and Test Sets",
    xaxis_title="Class",
    yaxis_title="Number of Samples",
    barmode="group",
    width=1000,
    height=500,
)

fig.show()

# Print statistics (share of each class in the training set).
print("Training set distribution:")
n_train = len(training_data)
for class_name, count in zip(classes, train_counts):
    percentage = (count / n_train) * 100
    print(f"{class_name:15} {count:5d} ({percentage:.1f}%)")

print(f"\nTotal training samples: {len(training_data)}")
print(f"Total test samples: {len(test_data)}")
Loading...
Training set distribution:
T-shirt/top      6000 (10.0%)
Trouser          6000 (10.0%)
Pullover         6000 (10.0%)
Dress            6000 (10.0%)
Coat             6000 (10.0%)
Sandal           6000 (10.0%)
Shirt            6000 (10.0%)
Sneaker          6000 (10.0%)
Bag              6000 (10.0%)
Ankle boot       6000 (10.0%)

Total training samples: 60000
Total test samples: 10000
class NeuralNetwork(nn.Module):
    """Small CNN classifier for single-channel 28x28 images.

    The network has two stages. Each stage is a 3x3 convolution (padding 1, so the
    spatial size is kept), then ReLU, then 2x2 max-pooling. Together they shrink the
    input from 28x28x1 to 7x7x64. A 128-unit hidden layer with dropout and a linear
    output layer follow, producing one logit per class.

    Args:
        num_classes: Number of output logits. Defaults to 10, the number of
            Fashion-MNIST categories.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Convolutional feature extractor
        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),  # 28x28x32
            nn.ReLU(),
            nn.MaxPool2d(2),  # 14x14x32
            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # 14x14x64
            nn.ReLU(),
            nn.MaxPool2d(2),  # 7x7x64
        )
        self.flatten = nn.Flatten()
        # Fully connected classifier head. The in_features value 7 * 7 * 64 fixes the
        # expected input resolution to 28x28.
        self.fc_layers = nn.Sequential(
            nn.Linear(7 * 7 * 64, 128),
            nn.ReLU(),
            nn.Dropout(0.5),  # only active in train() mode
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Map a batch of images of shape (N, 1, 28, 28) to logits of shape (N, num_classes)."""
        x = self.conv_layers(x)
        x = self.flatten(x)
        logits = self.fc_layers(x)
        return logits


model = NeuralNetwork().to(device)
print(model)

# Count only the parameters the optimizer will actually update.
n_trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(f"Total number of trainable parameters: {n_trainable}")
NeuralNetwork(
  (conv_layers): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc_layers): Sequential(
    (0): Linear(in_features=3136, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=128, out_features=10, bias=True)
  )
)
Total number of trainable parameters: 421642
# Cross-entropy applied to raw logits (the network's last layer is a plain Linear,
# with no softmax).
loss_fn = nn.CrossEntropyLoss()

# Plain stochastic gradient descent: no momentum, small learning rate.
# NOTE(review): the logged training loss barely moves during the first epoch, so a
# larger learning rate or momentum would likely converge much faster.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001)
def train(dataloader, model, loss_fn, optimizer, *, log_interval=100):
    """Train ``model`` for one pass over ``dataloader``, updating its weights in place.

    Args:
        dataloader: Yields ``(X, y)`` batches. ``len(dataloader.dataset)`` is shown as
            the total sample count in the progress log.
        model: The module to train. Each batch is moved to the device of the model's
            parameters, so the caller does not need a matching global ``device``.
        loss_fn: Criterion called as ``loss_fn(pred, y)``. It must return a single-element
            tensor (it is back-propagated and converted with ``.item()``).
        optimizer: Optimizer holding the parameters of ``model``.
        log_interval: Print the current loss every ``log_interval`` batches. ``0`` or
            ``None`` disables logging. Defaults to 100.
    """
    size = len(dataloader.dataset)
    # Follow the model's device; fall back to CPU for a parameterless module.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    seen = 0  # samples processed so far; correct even for a short final batch
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(target_device), y.to(target_device)
        seen += len(X)

        # Clear gradients left over from any earlier backward pass before accumulating
        # new ones.
        optimizer.zero_grad()

        # compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # backpropagation
        loss.backward()
        optimizer.step()

        if log_interval and batch % log_interval == 0:
            print(f"loss:{loss.item():>7f} [{seen:>5d}/{size:>5d}]")
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error:\n Accuracy: {(100 * correct):>0.1f}%, Avg loss:{test_loss:>8f} \n")
epochs = 5

# Alternate one training pass with one evaluation pass per epoch (1-based numbering).
for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}\n------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")
Epoch 1
------------
loss:2.300693 [   64/60000]
loss:2.296715 [ 6464/60000]
loss:2.292114 [12864/60000]
loss:2.296987 [19264/60000]
loss:2.287122 [25664/60000]
loss:2.296292 [32064/60000]
loss:2.275176 [38464/60000]
loss:2.281695 [44864/60000]
loss:2.272089 [51264/60000]
loss:2.269891 [57664/60000]
Test Error:
 Accuracy: 32.9%, Avg loss:2.262706 

Epoch 2
------------
loss:2.260585 [   64/60000]
loss:2.266152 [ 6464/60000]
loss:2.244620 [12864/60000]
loss:2.257603 [19264/60000]
loss:2.221899 [25664/60000]
loss:2.232457 [32064/60000]
loss:2.185710 [38464/60000]
loss:2.191071 [44864/60000]
loss:2.138950 [51264/60000]
loss:2.141331 [57664/60000]
Test Error:
 Accuracy: 60.5%, Avg loss:2.120909 

Epoch 3
------------
loss:2.143372 [   64/60000]
loss:2.093894 [ 6464/60000]
loss:2.035286 [12864/60000]
loss:2.015930 [19264/60000]
loss:1.920137 [25664/60000]
loss:1.847994 [32064/60000]
loss:1.763599 [38464/60000]
loss:1.740850 [44864/60000]
loss:1.695572 [51264/60000]
loss:1.518745 [57664/60000]
Test Error:
 Accuracy: 61.5%, Avg loss:1.457181 

Epoch 4
------------
loss:1.577464 [   64/60000]
loss:1.530314 [ 6464/60000]
loss:1.369598 [12864/60000]
loss:1.366476 [19264/60000]
loss:1.176187 [25664/60000]
loss:1.276275 [32064/60000]
loss:1.235990 [38464/60000]
loss:1.147604 [44864/60000]
loss:1.251951 [51264/60000]
loss:1.101347 [57664/60000]
Test Error:
 Accuracy: 65.7%, Avg loss:1.025680 

Epoch 5
------------
loss:1.195766 [   64/60000]
loss:1.140486 [ 6464/60000]
loss:1.020737 [12864/60000]
loss:1.140923 [19264/60000]
loss:1.015276 [25664/60000]
loss:1.144631 [32064/60000]
loss:1.056138 [38464/60000]
loss:1.021924 [44864/60000]
loss:1.052670 [51264/60000]
loss:1.067314 [57664/60000]
Test Error:
 Accuracy: 70.0%, Avg loss:0.868605 

Done!
torch.save(model.state_dict(), "zalando_cnn.pth")
# Load the trained weights into a fresh instance of the model.
# map_location places the tensors on the current device, so a checkpoint written on
# a GPU also loads on a CPU-only machine. weights_only=True limits unpickling to
# tensors and plain containers, which is all a state_dict needs.
model = NeuralNetwork().to(device)
model.load_state_dict(torch.load("zalando_cnn.pth", map_location=device, weights_only=True))
model.eval()
NeuralNetwork( (conv_layers): Sequential( (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (1): ReLU() (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) (4): ReLU() (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) ) (flatten): Flatten(start_dim=1, end_dim=-1) (fc_layers): Sequential( (0): Linear(in_features=3136, out_features=128, bias=True) (1): ReLU() (2): Dropout(p=0.5, inplace=False) (3): Linear(in_features=128, out_features=10, bias=True) ) )
# Get predictions for the entire test set (dropout disabled, no gradient tracking).
model.eval()
all_preds = []
all_labels = []

with torch.no_grad():
    for X, y in test_dataloader:
        X = X.to(device)
        pred = model(X)
        all_preds.extend(pred.argmax(1).cpu().numpy())
        all_labels.extend(y.cpu().numpy())

all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

# Create confusion matrix (rows = true label, columns = predicted label).
# Passing the full label range keeps the matrix len(classes) x len(classes) and
# aligned with the `classes` axis labels, even if some class never appears in the
# labels or in the predictions.
cm = confusion_matrix(all_labels, all_preds, labels=np.arange(len(classes)))

# Plot confusion matrix
fig = go.Figure(
    data=go.Heatmap(
        z=cm,
        x=classes,
        y=classes,
        colorscale="Blues",
        text=cm,
        texttemplate="%{text}",
        textfont={"size": 10},
        hoverongaps=False,
    )
)

fig.update_layout(
    title="Confusion Matrix",
    xaxis_title="Predicted Label",
    yaxis_title="True Label",
    width=800,
    height=700,
)

fig.show()
Loading...
# Per-class precision / recall / F1 / support.
# labels= fixes the output to one entry per class, in `classes` order.
# zero_division=0 reports 0, instead of emitting a warning, for a class that is
# never predicted.
label_ids = np.arange(len(classes))
precision, recall, f1, support = precision_recall_fscore_support(
    all_labels, all_preds, labels=label_ids, average=None, zero_division=0
)

# Per-class accuracy: correct predictions divided by the number of true samples of
# that class (the row sums of the confusion matrix).
# NOTE: this is exactly the definition of recall, so the "Accuracy" and "Recall"
# bars below coincide.
class_correct = np.diag(cm)
class_total = cm.sum(axis=1)
# A class with no true samples would give 0 / 0; report 0 instead of NaN.
accuracy = np.divide(
    class_correct,
    class_total,
    out=np.zeros(len(class_total), dtype=float),
    where=class_total > 0,
)

# Create grouped bar chart: one bar series per metric.
fig = go.Figure()
for metric_name, values in (
    ("Accuracy", accuracy),
    ("Precision", precision),
    ("Recall", recall),
    ("F1-Score", f1),
):
    fig.add_trace(
        go.Bar(
            name=metric_name,
            x=classes,
            y=values,
            text=[f"{v:.3f}" for v in values],
            textposition="auto",
        )
    )

fig.update_layout(
    title="Classification Metrics by Class",
    xaxis_title="Class",
    yaxis_title="Score",
    yaxis_range=[0, 1],
    barmode="group",
    width=1000,
    height=500,
)

fig.show()

# Print overall metrics. The averages are unweighted (macro) means over the classes.
overall_accuracy = (all_preds == all_labels).mean()
print(f"\nOverall Accuracy: {overall_accuracy:.2%}")
print(f"Average Precision: {precision.mean():.3f}")
print(f"Average Recall: {recall.mean():.3f}")
print(f"Average F1-Score: {f1.mean():.3f}")
Loading...

Overall Accuracy: 70.00%
Average Precision: 0.694
Average Recall: 0.700
Average F1-Score: 0.679
# Find the most confused (true, predicted) class pairs. Correct predictions on the
# diagonal are zeroed out so that only misclassifications remain.
cm_no_diag = cm.copy()
np.fill_diagonal(cm_no_diag, 0)

# Every off-diagonal cell with at least one misclassification. np.nonzero walks the
# matrix in row-major order, and sorted() is stable, so ties keep that order. Keep
# the top 10.
confusion_pairs = [
    {"true": classes[i], "predicted": classes[j], "count": int(cm_no_diag[i, j])}
    for i, j in zip(*np.nonzero(cm_no_diag))
]
confusion_pairs = sorted(confusion_pairs, key=lambda pair: pair["count"], reverse=True)[:10]

# Create horizontal bar chart
fig = go.Figure(
    data=[
        go.Bar(
            y=[f"{p['true']} → {p['predicted']}" for p in confusion_pairs],
            x=[p["count"] for p in confusion_pairs],
            orientation="h",
            text=[p["count"] for p in confusion_pairs],
            textposition="auto",
            marker_color="coral",
        )
    ]
)

fig.update_layout(
    title="Top 10 Most Confused Class Pairs",
    xaxis_title="Number of Misclassifications",
    yaxis_title="True → Predicted",
    width=900,
    height=600,
    yaxis={"categoryorder": "total ascending"},  # largest count drawn at the top
)

fig.show()
Loading...