Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

Navigeer naar deze notebook op GitHub: book/ml_principles/labs/imggen_butterfly

Via bovenstaande link kan je deze notebook openen in Google Colaboratory. In die omgeving kunnen we gebruik maken van gratis quota voor GPUs (en TPUs). GPU acceleratie is hier sterk aanbevolen voor zowel model training als model inference.

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from datasets import load_dataset
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Data

In dit labo trainen we een (mini) diffusionmodel met een dataset van foto’s van vlinders die afkomstig is van het Amerikaanse Smithsonian Institute.

De dataset is beschikbaar op Hugging Face en kan van daar geladen worden via het datasets package.

dataset = load_dataset("huggan/smithsonian_butterflies_subset", split="train")

Data preprocessing

We transformeren de afbeeldingen naar een uniform formaat van 32x32 pixels, voeren data augmentatie uit (horizontaal flippen) en normaliseren de pixelwaarden naar het bereik (-1, 1).

image_size = 32
batch_size = 32

preprocess = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),  # Resize
        transforms.RandomHorizontalFlip(),  # Randomly flip (data augmentation)
        transforms.ToTensor(),  # Convert to tensor (0, 1)
        transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)
    ]
)


def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}


dataset.set_transform(transform)

# Create a dataloader from the dataset to serve up the transformed images in batches
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

Visualisatie functie

Een helper functie om batches van afbeeldingen te visualiseren als een grid.

def show_images(x):
    """Given a batch of images x, make a grid and convert to PIL."""
    x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)
    grid = torchvision.utils.make_grid(x)
    grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
    grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
    return grid_im

Voorbeeld afbeeldingen

We visualiseren enkele voorbeelden uit de training dataset.

xb = next(iter(train_dataloader))["images"].to(device)[:8]
print("X shape:", xb.shape)
show_images(xb).resize((8 * 64, 64), resample=Image.NEAREST)

Noise scheduler

De DDPM (Denoising Diffusion Probabilistic Models) scheduler bepaalt hoe ruis gradueel wordt toegevoegd tijdens training en verwijderd tijdens generatie. We gebruiken 1000 timesteps.

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

Noise schedule visualisatie

Deze grafiek toont hoe het signaal (originele afbeelding) en de ruis zich verhouden over de verschillende timesteps. Aan het einde is de afbeelding bijna volledig ruis.

plt.plot(noise_scheduler.alphas_cumprod.cpu() ** 0.5, label=r"${\sqrt{\bar{\alpha}_t}}$")
plt.plot((1 - noise_scheduler.alphas_cumprod.cpu()) ** 0.5, label=r"$\sqrt{(1 - \bar{\alpha}_t)}$")
plt.legend(fontsize="x-large");

Alternatieve noise schedules

Deze cel toont andere mogelijke configuraties voor de noise scheduler.

# One with too little noise added:
# noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=0.001, beta_end=0.004)
# The 'cosine' schedule, which may be better for small image sizes:
# noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')

Ruis toevoegen aan afbeeldingen

We demonstreren het forward diffusion proces: afbeeldingen krijgen steeds meer ruis naarmate de timestep hoger wordt.

timesteps = torch.linspace(0, 999, 8).long().to(device)
noise = torch.randn_like(xb)
noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)
print("Noisy X shape", noisy_xb.shape)
show_images(noisy_xb).resize((8 * 64, 64), resample=Image.NEAREST)

UNet model

We definiëren een UNet architectuur met attention layers. Dit model leert om de toegevoegde ruis te voorspellen en te verwijderen.

# Create a model
model = UNet2DModel(
    sample_size=image_size,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(64, 128, 128, 256),  # More channels -> more parameters
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",  # a regular ResNet upsampling block
    ),
)
model.to(device);

Model test

We testen of het model de juiste output shape produceert voordat we beginnen met trainen.

with torch.no_grad():
    model_prediction = model(noisy_xb, timesteps).sample
model_prediction.shape

Training

Het model wordt getraind om de toegevoegde ruis te voorspellen. Voor elke batch:

  1. Voeg willekeurige ruis toe aan schone afbeeldingen

  2. Laat het model de ruis voorspellen

  3. Bereken het verschil (MSE loss) tussen echte en voorspelde ruis

  4. Update de model parameters

# Set the noise scheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")

# Training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)

losses = []

for epoch in range(30):
    for step, batch in enumerate(train_dataloader):
        clean_images = batch["images"].to(device)
        # Sample noise to add to the images
        noise = torch.randn(clean_images.shape).to(clean_images.device)
        bs = clean_images.shape[0]

        # Sample a random timestep for each image
        timesteps = torch.randint(
            0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device
        ).long()

        # Add noise to the clean images according to the noise magnitude at each timestep
        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

        # Get the model prediction
        noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

        # Calculate the loss
        loss = F.mse_loss(noise_pred, noise)
        loss.backward(loss)
        losses.append(loss.item())

        # Update the model parameters with the optimizer
        optimizer.step()
        optimizer.zero_grad()

    if (epoch + 1) % 5 == 0:
        loss_last_epoch = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
        print(f"Epoch:{epoch + 1}, loss: {loss_last_epoch}")

Training verloop

Visualisatie van de loss over tijd. Een dalende trend toont dat het model leert.

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
axs[0].plot(losses)
axs[1].plot(np.log(losses))
plt.show()

Generatie

Nu gebruiken we het getrainde model om nieuwe vlinder afbeeldingen te genereren. We starten met pure ruis en verwijderen stap voor stap de ruis volgens het reverse diffusion proces.

# Random starting point (8 random images):
sample = torch.randn(8, 3, 32, 32).to(device)

for i, t in enumerate(noise_scheduler.timesteps):
    # Get model pred
    with torch.no_grad():
        residual = model(sample, t).sample

    # Update sample with step
    sample = noise_scheduler.step(residual, t, sample).prev_sample

show_images(sample)