VAE for Motion Sequence Generation - Convergence Issue with Scheduled Sampling

I have implemented a Variational Autoencoder (VAE) in PyTorch for motion sequence generation using human pose data (joint angles and angular velocities, in radians) from the CMU dataset. The VAE consists of an encoder and a decoder, each with two layers, where each layer is a Conv1D followed by an ELU activation.

During training, I input a sequence of 121 poses (the 60 previous poses + the current pose p(n) + the 60 next poses from the dataset), and the VAE generates the next pose p_hat(n+1).
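
For clarity, the window/target indexing works roughly like this (a minimal sketch; `poses` is a hypothetical [frames, features] array standing in for one motion clip, not my actual data loader):

import torch

frameSequence = 121
N = (frameSequence - 1) // 2            # 60 poses before and after the current pose

poses = torch.randn(1000, 64)           # hypothetical clip: [frames, features]
n = 500                                  # index of the current pose p(n)

window = poses[n - N : n + N + 1]        # 121-pose input window centred on p(n)
target = poses[n + 1]                    # ground-truth next pose p(n+1)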

I have also tried normalized joint angles and angular velocities, but that worsens convergence.

Here’s an overview of my training process:

Loss Function:

For the first 30 epochs I train using only the Mean Squared Error (MSE) loss, comparing the generated next pose with the ground-truth next pose from the CMU dataset.

loss = MSE(p(n+1), p_hat(n+1))

From epoch 31 to 60, I add the KL divergence to the loss function.

loss = MSE(p(n+1), p_hat(n+1)) + KL
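
In code, the staging is just a gate on the KL term (a minimal sketch; mse_loss and kl stand for the per-batch terms computed in the training loop further down):

def total_loss(mse_loss, kl, epoch, epochs_before_KL=30):
    # MSE only for the first 30 epochs, MSE + KL afterwards
    alpha = 1.0 if epoch >= epochs_before_KL else 0.0
    return mse_loss + alpha * kl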

Scheduled Sampling:

Starting from epoch 61, I apply scheduled sampling, gradually increasing the probability p of feeding the model its own output from 0.0 to 1.0 over 20 epochs (epochs 61 to 80).
From epoch 81 onwards, p is fixed at 1.0, meaning the generated next pose is always fed back into the model as the current pose when generating the following pose.
The scheduled-sampling length is 8, i.e. I autoregressively generate the next 8 poses, feeding each generated pose of the VAE back in as input (see the sketch below).
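
Per rollout step, the input selection boils down to the following (a minimal sketch; rollout_step, window, ground_truth_pose, prev_prediction and p_generated are illustrative names, and the real sampling uses torch.multinomial in the training loop below):

import torch

def rollout_step(vae, window, ground_truth_pose, prev_prediction, p_generated):
    # with probability p_generated feed the model its own previous output,
    # otherwise feed the ground-truth current pose (teacher forcing)
    if torch.rand(1).item() < p_generated:
        cur_input = prev_prediction.detach()
    else:
        cur_input = ground_truth_pose
    return vae(window, cur_input)

# p_generated ramps linearly from 0 to 1 during the second training stage, e.g.
# p_generated = min(max((epoch - 60) / 20, 0.0), 1.0)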

The Issue:
The network converges nicely with the MSE loss alone, somewhat more slowly with MSE + KL, but it fails to converge once scheduled sampling is applied.

My Questions:

Is there a potential reason why the model doesn’t converge during the scheduled sampling phase?
Are there any adjustments or insights regarding the VAE structure or training parameters that could help resolve this issue and improve convergence during scheduled sampling?

Any insights or guidance would be greatly appreciated. Thank you in advance!

Here are the model and the parameters:

import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, latentDim, inputFeatDim, frameSequence, intermediate_channels):
        super(Encoder, self).__init__()
        
        #intermediate_channels = 256
        # layer 1
        self.convLayer1 = nn.Conv1d(in_channels = inputFeatDim,
                                    out_channels = intermediate_channels, 
                                    kernel_size = 1, 
                                    padding = 0, 
                                    padding_mode = 'zeros', 
                                    bias = True)
        

        # layer 2
        self.convLayer2 = nn.Conv1d(in_channels = intermediate_channels + inputFeatDim, 
                                    out_channels = intermediate_channels, 
                                    kernel_size = 1, 
                                    padding = 0, 
                                    padding_mode = 'zeros', 
                                    bias = True)
        

        # collapse the time axis to a single frame: [batch, channels, frameSequence] -> [batch, channels, 1]
        self.downSampleLayer = nn.Linear(in_features=frameSequence, out_features=1, bias=True)

        self.muLayer = nn.Conv1d(in_channels=intermediate_channels, out_channels=latentDim, kernel_size=1, padding=0, padding_mode='reflect')
        self.logVarLayer = nn.Conv1d(in_channels=intermediate_channels, out_channels=latentDim, kernel_size=1, padding=0, padding_mode='reflect')
        
        self.normalDist = torch.distributions.Normal(0, 1)
        self.normalDist.loc = self.normalDist.loc.cuda()
        self.normalDist.scale = self.normalDist.scale.cuda()
        self.kullbackLeibler = 0
        self.latent = torch.zeros(1).cuda()

        
        #self.print_f = True
        

    def forward(self, x):
        input = x
        x = self.convLayer1(x)
        l1_output = x
        x = torch.relu(x)
        
        x = self.convLayer2(torch.cat((input, x),dim=1))
        x = torch.relu(x)
        
        x = self.downSampleLayer(x)
        
        mu = self.muLayer(x)          # [batch, intermediate_channels, 1] -> [batch, latentDim, 1]
        logVar = self.logVarLayer(x)
        
        # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, 1)
        self.latent = mu + torch.exp(0.5 * logVar) * self.normalDist.sample(mu.shape)
        # KL divergence to N(0, I), summed over latent dims and averaged over the batch; logVar shape ----> [batch_size, latentDim, 1]
        self.kullbackLeibler = ((torch.exp(logVar) + mu**2) / 2 - 0.5 * logVar - 0.5).sum() / logVar.size(0)
        return self.latent, self.kullbackLeibler


class Decoder(nn.Module):
    def __init__(self, latentDim, inputFeatDim, poseFeatDim, frameSequence, intermediate_channels):
        super(Decoder, self).__init__()
        self.LatentExpander = nn.Linear(in_features=latentDim, out_features=poseFeatDim)

        # entry layer
        entry_in_channels = latentDim + poseFeatDim
        self.entryLayer = nn.Conv1d(in_channels = entry_in_channels, 
                                    out_channels = intermediate_channels, 
                                    kernel_size = 1, 
                                    padding = 0, 
                                    padding_mode = 'zeros', 
                                    bias = True)

        # hidden layer 1
        self.convLayer1 = nn.Conv1d(in_channels = intermediate_channels + entry_in_channels,
                                    out_channels = intermediate_channels, 
                                    kernel_size = 1, 
                                    padding = 0, 
                                    padding_mode = 'zeros', 
                                    bias = True)

        # final layer (missing from the original snippet; assumed to map back to the pose dimension)
        self.finalLayer = nn.Conv1d(in_channels = intermediate_channels,
                                    out_channels = poseFeatDim,
                                    kernel_size = 1,
                                    padding = 0,
                                    bias = True)

    def forward(self, latent, cur_pose):
        
        cur_pose = cur_pose.unsqueeze(2)
        
        x = torch.cat([latent, cur_pose], dim = 1)
        input = x
        

        x = self.entryLayer(x)
        x = torch.relu(x)
        
        
        x = self.convLayer1(torch.cat((input, x), dim=1))
        x = torch.relu(x)
        

        x = self.finalLayer(x)
        return x


class VAE(nn.Module):
    def __init__(self, Encoder, Decoder):
        super(VAE, self).__init__()
        self.encoder = Encoder
        self.decoder = Decoder

    def forward(self, seq, cur_pose):
        latent, kullbackLeibler = self.encoder(seq)
        
        X_hat = self.decoder(latent, cur_pose)
        return X_hat, latent, kullbackLeibler
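
For reference, a quick shape check with the parameters below (this assumes the finalLayer defined in Decoder.__init__ above, and a CUDA device, since the encoder hard-codes .cuda()):

enc = Encoder(latentDim=10, inputFeatDim=64, frameSequence=121, intermediate_channels=256).cuda()
dec = Decoder(latentDim=10, inputFeatDim=64, poseFeatDim=64, frameSequence=121, intermediate_channels=256).cuda()
model = VAE(enc, dec)

seq = torch.randn(4, 64, 121).cuda()     # [batch, features, frames]
cur_pose = torch.randn(4, 64).cuda()     # current pose p(n)

X_hat, latent, kl = model(seq, cur_pose)
print(X_hat.shape, latent.shape)         # torch.Size([4, 64, 1]) torch.Size([4, 10, 1])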

Here are the parameters:

args = {
        "data_unit": "radian",
        "dataset": "cmu_dataset",
        "frameSequence": 121,
        "latentDim": 10,
        "inputFeatDim": 64,
        "poseFeatDim": 64,
        "lr": 1e-4,
        "lr_init": 1e-4,
        "lr_final": 1e-7,
        "intermediate_channels": 256,
        "epochs_before_KL": 30,
        "batch_size": 256,
        "epochs": 140,
        "load_saved_model": False,
        "scheduled_sampling_length": 8,
        "first_train_stage_epochs": 60,  # before exposing the model to scheduled sampling
        "second_train_stage_epochs": 20  # ramp scheduled sampling up until epoch 80, then autoregressively generate 8 consecutive poses
        }
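
For reference, the learning rate stays at lr_init through the first two stages and then decays linearly towards lr_final once scheduled sampling is fully on (a minimal sketch of the same schedule used inside train below):

lr_init, lr_final, epochs = 1e-4, 1e-7, 140
decay_start = 60 + 20                     # first_train_stage_epochs + second_train_stage_epochs

for epoch in range(epochs):
    if epoch >= decay_start:
        lr = lr_init - (lr_init - lr_final) * (epoch - decay_start) / (epochs - decay_start)
    else:
        lr = lr_init

And here is the training function: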
def train(VAE, data, device, optimizer, load_saved_model, epochs_before_KL, draw_pose, poseFeatDim, scheduled_sampling_length, epochs, lr, lr_init, lr_final, latentDim, frameSequence, first_train_stage_epochs, second_train_stage_epochs):
    N = int((frameSequence-1)/2) # pose sequences before and after current pose
    KL_accum = 0
    loss_accum = 0
    
    epoch_loss_log = []
    if load_saved_model:
        alpha = 1.0
    else:
        alpha = 0.0
    latent = 0
    cur_pose = 0
    print("number of batches: " + str(len(data)))
    for epoch in range(epochs):
        epoch_KL = 0
        # switch on the KL term once the MSE-only warm-up is over
        if epoch == epochs_before_KL:
            alpha = 1.0

        epoch_loss= 0
        cur_frame_idx = 0
        X_hat = 0


        # linearly decay the lr once scheduled sampling is fully on (epoch >= 80)
        if epoch > (first_train_stage_epochs+second_train_stage_epochs-1):
            lr = lr_init - (lr_init-lr_final)*(epoch-(first_train_stage_epochs+second_train_stage_epochs))/(epochs-(first_train_stage_epochs+second_train_stage_epochs))
        # note: re-creating the optimizer every epoch also resets Adam's moment estimates
        optimizer = torch.optim.Adam(VAE.parameters(), lr=lr)
        for X, target, cur_frame_idx in data:
            KL_accum = 0
            loss_accum = 0
            
            #X = X.to(device) # GPU    
            # [batch, frames, features] -> [batch, features, frames]
            X = torch.permute(X, (0, 2, 1)).type(torch.FloatTensor).to(device=device)
            
            # initialise X_hat with the ground-truth current pose p(n)
            X_hat = (X[:, :, N].clone()).unsqueeze(2).cuda()
            l = 0
            for l in range(scheduled_sampling_length):
                train_loss = 0
                optimizer.zero_grad()
                
                
                # sliding 121-frame window, shifted by l frames into the future
                cur_X = (X[:, :, l:frameSequence + l].clone()).to(device=device)
                
                # ground-truth current pose for this window
                cur_pose = (cur_X[:, 0:poseFeatDim, N].clone()).to(device=device)
                
                # ground-truth next pose for this window
                GT = (cur_X[:, 0:poseFeatDim, N + 1].clone()).unsqueeze(2).cuda()
                
                # scheduled sampling: p = 1 feeds the ground-truth pose, p = 0 feeds the model's own previous output
                if load_saved_model:
                    p = 0
                elif epoch < first_train_stage_epochs:
                    p = 1
                elif epoch < first_train_stage_epochs + second_train_stage_epochs:
                    # chance of drawing p = 0 (own output) grows linearly during the second stage
                    w1 = (epoch - first_train_stage_epochs + 1) / second_train_stage_epochs
                    w2 = 1 - w1
                    
                    weights = torch.tensor([w1, w2], dtype=torch.float).cuda()
                    p = torch.multinomial(weights, 1, replacement=True).cuda().item()
                else:
                    p = 0

                input = p * cur_pose.detach().cuda() + (1 - p) * X_hat[:, 0:poseFeatDim, :].detach().squeeze(2).cuda()
                
                X_hat, latent, KL_loss = VAE(cur_X, input)
                KL_accum += KL_loss
                loss_accum +=(((GT - X_hat[:,0:poseFeatDim,:]).clone()**2).sum())/(GT.size(dim=0))
                
                recon_loss = (((GT - X_hat[:,0:poseFeatDim,:])**2).sum())/(GT.size(dim=0)) # GT shape ----> [batch_size, poseFeatDim, 1]
                train_loss = recon_loss + KL_loss * alpha
                    
                train_loss.backward()
                    
                epoch_loss += train_loss
                epoch_KL  += KL_loss
                optimizer.step()
            
        print("epoch: " + str(epoch)+" loss: " + str(epoch_loss.item()) + " KL: " + str(epoch_KL.item() * alpha))
        epoch_loss_log.append(epoch_loss)
    
    return VAE, X, latent[0,:,:].unsqueeze(0), X[0,:,N].unsqueeze(0), epoch_loss, optimizer