Hi there! I am wondering if someone more experienced could take a look at this setup and tell me whether the gradients will flow properly through both the UNet and the Cond_First_Stage that is essentially attached to the model.
I am using the Accelerate library for distributed training with 10 nodes and 4 GPUs per node. I want both the UNet weights and the conditional first-stage weights to be trained (effectively aligning the image embeddings with an unconventional input data type for the model).
Happy to share more of the code (i.e. the training script) to provide a better understanding of the workflow.
Thanks in advance!!!
Below is the code:
import torch
from torch import nn as nn
from transformers import CLIPVisionModelWithProjection
import torch.nn.functional as F
class WrapperModel(nn.Module):
    """Bundle the trainable UNet with its frozen helpers for one diffusion training step.

    Gradient flow, as configured in ``__init__``:

    * ``vae``, ``encoder`` and ``frozen_image_embedder`` are frozen via
      ``requires_grad_(False)`` and kept in eval mode; they never receive
      gradients.
    * ``unet`` is trainable. The ``First_Stage`` module is attached as
      ``unet.cond_first_stage``, so it is registered as a sub-module of the
      UNet and its trainable layers are part of ``unet.parameters()`` (and of
      ``self.parameters()``).

    NOTE(review): because ``cond_first_stage`` is created here, the optimizer
    must be built from parameters collected *after* this wrapper is
    constructed, otherwise the first-stage weights are not optimised.
    NOTE(review): the layers used only by ``get_clip`` get a gradient only if
    the returned ``clip_loss`` is part of the loss that is back-propagated;
    under DDP, parameters that never receive a gradient raise an error unless
    ``find_unused_parameters`` is enabled.
    """

    def __init__(self, unet, vae, noise_scheduler, encoder, config, frozen_image_embedder, accelerator):
        """Store the components, attach the first stage to the UNet and freeze the helpers.

        Args:
            unet: Denoising network to be trained.
            vae: Autoencoder used to turn pixel values into latents (frozen).
            noise_scheduler: Scheduler providing ``add_noise`` and
                ``config.num_train_timesteps``.
            encoder: Conditioning encoder wrapped by ``First_Stage`` (frozen).
            config: Object exposing ``noise_offset`` and ``input_perturbation``.
            frozen_image_embedder: Image embedder used for the alignment loss (frozen).
            accelerator: Object exposing the target ``device``.
        """
        super().__init__()
        self.unet = unet
        self.vae = vae
        self.noise_scheduler = noise_scheduler
        self.encoder = encoder
        self.config = config
        self.frozen_image_embedder = frozen_image_embedder
        self.accelerator = accelerator
        self.weight_dtype = torch.float32

        # Attaching the first stage as an attribute of the UNet registers it as
        # a UNet sub-module, so its parameters travel with the UNet.
        self.unet.cond_first_stage = First_Stage(encoder=self.encoder)

        # Frozen modules stay in eval mode; only the UNet is in train mode.
        self.vae.eval()
        self.encoder.eval()
        self.frozen_image_embedder.eval()
        self.unet.train()

        # Ensure the VAE, encoder and image embedder are not updated during training.
        self.vae.requires_grad_(False)
        self.encoder.requires_grad_(False)
        self.frozen_image_embedder.requires_grad_(False)

        # Move each model to the appropriate device.
        self.vae.to(self.accelerator.device, dtype=self.weight_dtype)
        self.unet.to(self.accelerator.device)
        self.encoder.to(self.accelerator.device)
        self.frozen_image_embedder.to(self.accelerator.device)

    def train(self, mode=True):
        """Set the training mode while keeping the frozen modules in eval mode.

        ``nn.Module.train`` recurses into every child, which would otherwise
        flip the frozen VAE / encoder / image embedder back to train mode the
        first time the training loop calls ``model.train()``.
        """
        super().train(mode)
        self.vae.eval()
        self.encoder.eval()
        self.frozen_image_embedder.eval()
        return self

    def forward(self, batch, accelerator):
        """Run one noising + denoising pass and compute the alignment loss.

        Args:
            batch: Mapping with ``'pixel_values'`` (images) and ``'eeg'``
                (conditioning signal) tensors.
            accelerator: Object exposing the ``device`` inputs are moved to.

        Returns:
            Tuple ``(model_pred, clip_loss, timesteps, target)`` where
            ``target`` is the un-perturbed noise the UNet should predict.
        """
        device = accelerator.device
        pixel_values = batch['pixel_values']
        # Move the conditioning signal once and reuse it for both the UNet
        # conditioning and the alignment loss.
        eeg = batch['eeg'].to(device)

        # The VAE is frozen, so its encoding is kept out of the autograd graph.
        with torch.no_grad():
            image_embedding = self.vae.encode(
                pixel_values.to(device=device, dtype=self.weight_dtype)
            ).latent_dist.sample()
            image_embedding = image_embedding * self.vae.config.scaling_factor

        # Base noise with the same shape as the latents, plus a per-channel offset.
        noise = torch.randn_like(image_embedding)
        noise += self.config.noise_offset * torch.randn(
            (image_embedding.shape[0], image_embedding.shape[1], 1, 1),
            device=image_embedding.device,
        )
        # Perturbed copy of the noise; only this copy is added to the latents.
        new_noise = noise + self.config.input_perturbation * torch.randn_like(noise)

        # One random diffusion timestep per sample in the batch.
        bsz = image_embedding.shape[0]
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (bsz,),
            device=image_embedding.device,
        ).long()

        # Forward diffusion: noise the latents at the sampled timesteps.
        noisy_latents = self.noise_scheduler.add_noise(image_embedding, new_noise, timesteps)

        # Conditioning for the UNet comes from the (trainable) first stage.
        encoder_hidden_states, _latent = self.unet.cond_first_stage(eeg)

        # The UNet is trained to predict the original, un-perturbed noise.
        target = noise
        model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

        # Alignment loss between the mapped conditioning signal and the image embedding.
        with torch.no_grad():
            img_embd = self.frozen_image_embedder(pixel_values.to(device))
        # Hugging Face CLIP vision models return an output object; use its
        # ``image_embeds`` when present, otherwise assume a plain tensor.
        img_embd = getattr(img_embd, "image_embeds", img_embd)
        clip_loss = self.unet.cond_first_stage.get_clip(eeg, img_embd)

        return model_pred, clip_loss, timesteps, target
class Dim_Mapper(nn.Module):
    """Collapse a ``(batch, seq_len, input_dim)`` tensor to ``(batch, output_dim)``.

    A 1x1 ``Conv1d`` treats the sequence axis as channels and reduces it to a
    single channel; a linear layer then maps the feature axis from
    ``input_dim`` to ``output_dim``.

    Args:
        seq_len: Size of dimension 1 of the input (conv input channels).
        input_dim: Size of the last input dimension; set this to the
            encoder's last-dimension size.
        output_dim: Size of the produced feature vector.
    """

    def __init__(self, seq_len=128, input_dim=1024, output_dim=768):
        super().__init__()
        # Weighted sum over the sequence axis: seq_len channels -> 1 channel.
        self.conv1 = nn.Conv1d(seq_len, 1, 1, stride=1)
        # Fully connected layer to transform the feature vector.
        self.fc1 = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Map ``x`` of shape ``(batch, seq_len, input_dim)`` to ``(batch, output_dim)``."""
        # (batch, seq_len, input_dim) -> (batch, 1, input_dim)
        x = self.conv1(x)
        # Drop the singleton channel left by the convolution.
        x = x.squeeze(1)
        # (batch, input_dim) -> (batch, output_dim)
        return self.fc1(x)
# The point of having a separate Dim_Mapper class is so that we can switch CLIP alignment on or off.
class First_Stage(nn.Module):
    """Turn the frozen encoder's output into UNet conditioning and an alignment loss.

    The wrapped ``encoder`` is frozen; the layers created here (``conv_block``,
    ``fc`` and ``mapper``) are the trainable part.

    NOTE(review): ``conv_block`` expects the encoder's last dimension to be
    ``input_dim`` (1024), and ``Dim_Mapper()`` is built with its defaults, so
    ``get_clip`` presumably requires the encoder to emit 128 tokens — confirm
    against the encoder in use.
    """

    def __init__(self, encoder):
        """Wrap ``encoder`` (must expose ``num_patches``) and build the trainable heads."""
        super().__init__()
        self.encoder = encoder
        # Freeze the encoder weights to prevent updates during training.
        self.encoder.requires_grad_(False)
        # Sequence length produced by the encoder.
        self.seq_len = encoder.num_patches
        self.input_dim = 1024
        self.output_dim = 768

        # Dimensionality mapper used only by the alignment loss.
        self.mapper = Dim_Mapper()

        # The UNet expects (batch, sequence_length, feature_dim); the convs
        # work on (batch, feature_dim, sequence_length) and halve the length
        # (rounding up) at every layer.
        self.conv_block = nn.Sequential(
            nn.Conv1d(in_channels=self.input_dim, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.Conv1d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
            nn.Conv1d(in_channels=256, out_channels=self.output_dim, kernel_size=3, stride=2, padding=1),
        )

        # Use the exact Conv1d length formula rather than ``seq_len // 8``,
        # which is only right when seq_len is a multiple of 8.
        conv_seq_len = self._conv_out_len(self.seq_len, num_layers=len(self.conv_block))
        self.fc = nn.Linear(in_features=conv_seq_len * self.output_dim, out_features=self.output_dim)

    @staticmethod
    def _conv_out_len(seq_len, num_layers=3, kernel_size=3, stride=2, padding=1):
        """Return the length left after ``num_layers`` identical ``Conv1d`` layers."""
        for _ in range(num_layers):
            seq_len = (seq_len + 2 * padding - kernel_size) // stride + 1
        return seq_len

    def forward(self, x):
        """Encode ``x`` and project it to a single conditioning token.

        Returns:
            Tuple ``(cond, latent)``: ``cond`` has shape
            ``(batch, 1, output_dim)`` and ``latent`` is the raw encoder output.
        """
        # Call the module (not ``.forward``) so registered hooks still run.
        latent = self.encoder(x)
        # Rearrange to (batch, feature_dim, sequence_length) for Conv1d.
        features = latent.transpose(1, 2)
        features = F.relu(self.conv_block(features))
        features = torch.flatten(features, start_dim=1)
        # Add a length-1 sequence axis for the UNet's cross-attention input.
        cond = self.fc(features).unsqueeze(1)
        return cond, latent

    def get_clip(self, x, image_embeddings):
        """Return ``1 - mean cosine similarity`` between mapped ``x`` and ``image_embeddings``."""
        # Encode and map the input so it lines up with the image embeddings.
        mapped = self.mapper(self.encoder(x))
        similarity = F.cosine_similarity(mapped, image_embeddings, dim=-1)
        return 1 - similarity.mean()