How to sample and save an image for afhq #2

@cutoken

Hi @robert-graf,
Thank you for this wonderful reference implementation. I'm trying to run inference and save the sampled images for the AFHQ dataset. Below is my code, adapted from your MNIST inference code. Somehow the saved image is just blank:

```python
import torch, torchvision
from loader.arguments import get_latest_Checkpoint
from diffusion import Diffusion
from loader import load_dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

name = "afhq"  # exp_name
version = "*"  # "*" selects the newest version; use a number for a specific version
batch_size = 3
checkpoint = get_latest_Checkpoint(name, log_dir_name="logs_diffusion", best=False, version=version)
assert checkpoint is not None, "did not find checkpoint"
model = Diffusion.load_from_checkpoint(checkpoint, strict=False)
# model.cpu()  # alternative: run on the CPU
model.cuda()

label = torch.tensor([0, 1, 2]).cuda()

# DDPM
image_ddpm = model.forward(batch_size, 1000, label=label)
# DDIM
image_ddim, inter = model.forward_ddim(batch_size, [i for i in range(0, 1000, 20)], label=label, eta=0.0)  # type: ignore
assert isinstance(image_ddpm, torch.Tensor)
grid = torchvision.utils.make_grid(torch.cat([image_ddpm, image_ddim], dim=0), nrow=5).cpu()
a = grid.permute(1, 2, 0)
a = torch.clamp(a, 0, 1)
plt.figure(figsize=(40, 40))
plt.savefig('sample.jpg')
```

Could you let me know if I'm missing something here.
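A likely cause of the blank file: `plt.figure(figsize=(40, 40))` only creates an empty figure, and nothing is drawn into it before `plt.savefig` is called, so the clamped grid `a` never reaches the file. The sketch below adds the missing `imshow` step. It assumes the grid is an H x W x C array already in [0, 1]; `save_grid` is a hypothetical helper name, not part of the repository.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render straight to files
import matplotlib.pyplot as plt
import numpy as np


def save_grid(grid_hwc, path, figsize=(10, 10)):
    """Draw an H x W x C image array into a figure and write it to `path`.

    Hypothetical helper. Assumes `grid_hwc` holds values in [0, 1], e.g. the
    clamped make_grid(...).permute(1, 2, 0) tensor converted with .numpy().
    """
    img = np.clip(np.asarray(grid_hwc, dtype=np.float32), 0.0, 1.0)
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]  # imshow rejects H x W x 1, so drop the channel axis
    fig, ax = plt.subplots(figsize=figsize)
    # This is the step missing from the snippet: actually draw the pixels.
    ax.imshow(img, cmap="gray" if img.ndim == 2 else None)
    ax.axis("off")  # hide ticks and frame around the samples
    fig.savefig(path, bbox_inches="tight")
    plt.close(fig)  # release the figure so repeated calls don't accumulate
```

With the variables from the snippet above, this would be called as `save_grid(a.detach().numpy(), "sample.jpg")` (`.detach()` is only needed if the samples still carry gradients). As an alternative, `torchvision.utils.save_image(torch.cat([image_ddpm, image_ddim], dim=0), "sample.png", nrow=5)` builds the grid and writes it in one call; it forwards `nrow` to `make_grid` and clamps to [0, 1] itself. In either approach, if the model's outputs are in [-1, 1] rather than [0, 1], they would need rescaling first, e.g. `(x + 1) / 2`.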
