Conditional diffusion model from diffusion-models-class

Good afternoon everyone. I have a question about the class-conditioned diffusion model covered in the diffusion-models-class course (diffusion-models-class/02_class_conditioned_diffusion_model_example.ipynb at main · huggingface/diffusion-models-class · GitHub). I am using my own dataset with 3 class labels and 9000 images.
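
For completeness, the code below assumes roughly the following imports and device setup. I am reconstructing this cell from what the later snippets use, so treat it as a sketch rather than the exact notebook cell:

import numpy as np
import torch
import torch.nn as nn
import torchvision
import wandb
from datasets import load_dataset
from diffusers import DDIMScheduler, UNet2DModel
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.auto import tqdm

device = 'cuda' if torch.cuda.is_available() else 'cpu'
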
Import dataset

dataset = load_dataset("WiNE-iNEFF/label_minecraft_skin_data", split="train")

# Define data augmentations
preprocess = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])

def transform(examples):
    images = [preprocess(image.convert("RGBA")) for image in examples["image"]]
    labels = [label for label in examples['label']]
    return {"images": images, "labels": labels}


dataset.set_transform(transform)
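
Just to confirm the transform behaves as expected, here is a quick sanity check (not part of the course notebook, only something I run for debugging):

# Sanity check: each batch should contain 4-channel 64x64 RGBA tensors and integer class ids
check_batch = next(iter(DataLoader(dataset, batch_size=4)))
print(check_batch["images"].shape)  # expected: torch.Size([4, 4, 64, 64])
print(check_batch["labels"])        # expected: tensor of class ids in {0, 1, 2}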

Conditional UNet - apart from changing the number of input and output channels (the images are RGBA, so 4 channels plus the class embedding), nothing has changed compared to the course notebook. A quick shape check follows the class definition below.

class ClassConditionedUnet(nn.Module):
  def __init__(self, num_classes=3, class_emb_size=3):
    super().__init__()
    
    # The embedding layer will map the class label to a vector of size class_emb_size
    self.class_emb = nn.Embedding(num_classes, class_emb_size)

    # Self.model is an unconditional UNet with extra input channels to accept the conditioning information (the class embedding)
    self.model = UNet2DModel(
        sample_size=64,           # the target image resolution
        in_channels=4 + class_emb_size, # Additional input channels for class cond.
        out_channels=4,           # the number of output channels
        layers_per_block=2,       # how many ResNet layers to use per UNet block
        block_out_channels=(64, 128, 128, 256, 512),
        down_block_types=(
            "DownBlock2D",        # a regular ResNet downsampling block
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",    # a ResNet downsampling block with spatial self-attention
            "AttnDownBlock2D",
        ),
        up_block_types=(
            "AttnUpBlock2D",
            "AttnUpBlock2D",      # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",          # a regular ResNet upsampling block
            "UpBlock2D",
            "UpBlock2D",
          ),
    )

  # Our forward method now takes the class labels as an additional argument
  def forward(self, x, t, class_labels):
    # Shape of x:
    bs, ch, w, h = x.shape
    
    # Class conditioning in the right shape to add as additional input channels
    class_cond = self.class_emb(class_labels) # Map to embedding dimension
    #print(class_cond)
    class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
    #print(class_cond)
    # x is shape (bs, 4, 64, 64) and class_cond is now (bs, class_emb_size, 64, 64)

    # Net input is now x and class_cond concatenated together along dimension 1
    net_input = torch.cat((x, class_cond), 1) # (bs, 4 + class_emb_size, 64, 64)

    # Feed this to the unet alongside the timestep and return the prediction
    return self.model(net_input, t).sample
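
Not part of the course notebook, but here is a small shape check I use to confirm the channel bookkeeping: with class_emb_size=3 the UNet receives 4 + 3 = 7 input channels and should return 4 output channels at 64x64.

# Hypothetical shape check for the conditional UNet (debugging only)
net_test = ClassConditionedUnet().to(device)
x_test = torch.randn(8, 4, 64, 64).to(device)            # a batch of 8 "noisy" RGBA images
t_test = torch.randint(0, 1000, (8,)).long().to(device)  # one timestep per sample
y_test = torch.randint(0, 3, (8,)).to(device)            # one class label per sample
print(net_test(x_test, t_test, y_test).shape)            # expected: torch.Size([8, 4, 64, 64])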

Training Loop

noise_scheduler = DDIMScheduler(num_train_timesteps=1000)
noise_scheduler.set_timesteps(num_inference_steps=40)
net = ClassConditionedUnet().to(device)

train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
n_epochs = 5

# Our loss function
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []

# The training loop
for epoch in range(n_epochs):
    wandb.log({'epoch':epoch})
    for batch in tqdm(train_dataloader):
        x = batch['images'].to(device) * 2 - 1 # Data on the GPU (mapped to (-1, 1))
        y = batch['labels'].to(device)

        noise = torch.randn_like(x)
        timesteps = torch.randint(0, 999, (x.shape[0],)).long().to(device)
        noisy_x = noise_scheduler.add_noise(x, noise, timesteps)

        # Get the model prediction
        pred = net(noisy_x, timesteps, y) # Note that we pass in the labels y

        # Calculate the loss
        loss = loss_fn(pred, noise) # How close is the output to the noise
        wandb.log({'loss':loss.item()})

        # Backprop and update the params:
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Store the loss for later
        losses.append(loss.item())

    # Print out the average of the last 100 loss values to get an idea of progress:
    avg_loss = sum(losses[-100:])/100
    print(f'Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}')
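
Not shown above, but after training I look at the stored loss values roughly like this (assumes matplotlib is available):

# Plot the per-step losses collected during training
import matplotlib.pyplot as plt
plt.plot(losses)
plt.xlabel("training step")
plt.ylabel("MSE loss")
plt.show()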

Image generation

def show_images(x):
    """Given a batch of images x, make a grid and convert to PIL"""
    #x = x * 0.5 + 0.5  # Map from (-1, 1) back to (0, 1)
    grid = torchvision.utils.make_grid(x)
    grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
    grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))
    return grid_im

x = torch.randn(40, 4, 64, 64).to(device)
y = torch.tensor([[i]*4 for i in range(3)]).flatten().to(device)

# Sampling loop
for i, t in tqdm(enumerate(noise_scheduler.timesteps)):
    # Get model pred
    with torch.no_grad():
        residual = net(x, t, y)  # Again, note that we pass in our labels y

    # Update sample with step
    x = noise_scheduler.step(residual, t, x).prev_sample

show_images(x)

Error

RuntimeError                              Traceback (most recent call last)
<ipython-input-20-e6b9df05a01b> in <module>
      9     # Get model pred
     10     with torch.no_grad():
---> 11         residual = net(x, t, y)  # Again, note that we pass in our labels y
     12 
     13     # Update sample with step

1 frames
<ipython-input-14-107df7faa9a0> in forward(self, x, t, class_labels)
     37     class_cond = self.class_emb(class_labels) # Map to embedding dimension
     38     #print(class_cond)
---> 39     class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
     40     #print(class_cond)
     41     # x is shape (bs, 4, 64, 64) and class_cond is now (bs, class_emb_size, 64, 64)

RuntimeError: shape '[40, 1, 1, 1]' is invalid for input of size 36