VAE for Motion Sequence Generation - Convergence Issue when using Scheduled Sampling

I implemented a Variational Autoencoder (VAE) in PyTorch for motion sequence generation using human pose data (joint angles and angular velocities, in radians) from the CMU dataset. The VAE consists of an encoder and a decoder, each with two layers, where each layer is a Conv1D followed by a ReLU activation.


During training, I input a sequence of 121 poses (the 60 previous poses + the current pose p(n) + the 60 next poses from the dataset), and the VAE generates the next pose p_hat(n+1).
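
Concretely, the window around the current frame is assembled roughly like this (a minimal sketch with placeholder sizes, not my exact data-loading code):

import torch

# illustrative only: building the 121-frame window around frame n
batch, featDim, T, N = 4, 132, 500, 60        # placeholder sizes
poses = torch.randn(batch, featDim, T)        # [batch, features, frames]
n = 200                                       # current frame index
window = poses[:, :, n - N : n + N + 1]       # 60 prev + current + 60 next -> 121 frames
target = poses[:, :, n + 1]                   # ground-truth next pose p(n+1)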

I have also tried normalized joint angles and angular velocities, but that worsens convergence.

Here’s an overview of my training process:

Loss Function:

I initially trained for 30 epochs using only the Mean Squared Error (MSE) loss, comparing the generated next pose with the ground truth from the CMU dataset.
loss = MSE(p(n+1), p_hat(n+1))

From epoch 31 to 60, I added the KL divergence term to the loss function.
loss = MSE(p(n+1), p_hat(n+1)) + KL
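
For reference, the two stages can be summarized like this (a minimal sketch with illustrative names; alpha is the KL weight I switch on at epoch 31):

def vae_loss(p_hat_next, p_next, kl, epoch, epochs_before_KL=30):
    # sum of squared errors, averaged over the batch (as in my train loop below)
    recon = ((p_next - p_hat_next) ** 2).sum() / p_next.size(0)
    alpha = 1.0 if epoch >= epochs_before_KL else 0.0   # KL term switched on from epoch 31
    return recon + alpha * kl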

Scheduled Sampling:

Starting from epoch 61, I applied scheduled sampling, gradually increasing the probability p from 0.0 to 1.0 over 20 epochs (epochs 61 to 80).
From epoch 81 onwards, p is fixed at 1, meaning the generated pose is always fed back into the model as the current pose to generate the next one.
The scheduled sampling rollout length is 8 (I autoregressively generate the next 8 poses, feeding each generated pose back into the VAE as the current pose).
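
In code, the schedule looks roughly like this (an illustrative helper, not my exact implementation); here p is the probability of feeding the model's own generated pose back in, ramping from 0 to 1 over epochs 61 to 80:

import random

def generated_pose_prob(epoch, first_stage=60, ramp=20):
    # 0.0 before epoch 61, ramps linearly to 1.0 over epochs 61-80, then stays at 1.0
    if epoch < first_stage:
        return 0.0
    return min(1.0, (epoch - first_stage + 1) / ramp)

# per autoregressive step: feed back the generated pose with probability p,
# otherwise feed the ground-truth current pose (teacher forcing)
use_generated = random.random() < generated_pose_prob(epoch=65)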

The Issue:
The network converges nicely on the MSE loss, somewhat more slowly on MSE+KL, but it fails to converge once scheduled sampling is applied.

My Questions:

Is there a potential reason why the model doesn’t converge during the scheduled sampling phase?
Are there any adjustments or insights regarding the VAE structure or training parameters that could help resolve this issue and improve convergence during scheduled sampling?

VAE Structure and Parameters:

Encoder and Decoder: Each with two layers (Conv1D + ReLU activation)
Loss: MSE initially, then MSE+KL
Scheduled Sampling: Gradual increase of sampling probability p from 0.0 to 1.0 over epochs 61 to 80, then p set to 1 from epoch 81.


import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, latentDim, inputFeatDim, frameSequence, intermediate_channels):
        super(Encoder, self).__init__()
        
        #intermediate_channels = 256
        # layer 1
        self.convLayer1 = nn.Conv1d(in_channels = inputFeatDim,
                                    out_channels = intermediate_channels, 
                                    kernel_size = 1, 
                                    padding = 0, 
                                    padding_mode = 'zeros', 
                                    bias = True)
        

        # layer 2
        self.convLayer2 = nn.Conv1d(in_channels = intermediate_channels + inputFeatDim, 
                                    out_channels = intermediate_channels, 
                                    kernel_size = 1, 
                                    padding = 0, 
                                    padding_mode = 'zeros', 
                                    bias = True)
        

        self.downSampleLayer = nn.Linear(in_features=frameSequence, out_features=1, bias=True)  # collapse the 121 frames to a single frame

        self.muLayer = nn.Conv1d(in_channels=intermediate_channels, out_channels=latentDim, kernel_size=1, padding=0, padding_mode='reflect')
        self.logVarLayer = nn.Conv1d(in_channels=intermediate_channels, out_channels=latentDim, kernel_size=1, padding=0, padding_mode='reflect')
        
        self.normalDist = torch.distributions.Normal(0, 1)
        self.normalDist.loc = self.normalDist.loc.cuda()
        self.normalDist.scale = self.normalDist.scale.cuda()
        self.kullbackLeibler = 0
        self.latent = torch.zeros(1).cuda()

        
        #self.print_f = True
        

    def forward(self, x):
        input = x
        x = self.convLayer1(x)
        l1_output = x
        x = torch.relu(x)
        
        x = self.convLayer2(torch.cat((input, x),dim=1))
        x = torch.relu(x)
        
        x = self.downSampleLayer(x)

        mu = self.muLayer(x)          # [batch, latentDim, 1]
        logVar = self.logVarLayer(x)  # [batch, latentDim, 1]

        # reparameterization trick
        self.latent = mu + torch.exp(0.5 * logVar) * self.normalDist.sample(mu.shape)
        # KL divergence to N(0, I), summed over dimensions and averaged over the batch
        self.kullbackLeibler = ((torch.exp(logVar) + mu**2) / 2 - 0.5 * logVar - 0.5).sum() / logVar.size(0)
        return self.latent, self.kullbackLeibler

class Decoder(nn.Module):
    def __init__(self, latentDim, inputFeatDim, poseFeatDim, frameSequence, intermediate_channels):
        super(Decoder, self).__init__()
        self.LatentExpander = nn.Linear(in_features=latentDim, out_features=poseFeatDim)

        # entry layer
        entry_in_channels = latentDim + poseFeatDim
        self.entryLayer = nn.Conv1d(in_channels = entry_in_channels, 
                                    out_channels = intermediate_channels, 
                                    kernel_size = 1, 
                                    padding = 0, 
                                    padding_mode = 'zeros', 
                                    bias = True)

        # hidden layer 1
        self.convLayer1 = nn.Conv1d(in_channels = intermediate_channels+entry_in_channels,
                                    out_channels = intermediate_channels, 
                                    kernel_size = 1, 
                                    padding = 0, 
                                    padding_mode = 'zeros', 
                                    bias = True)

        # output layer referenced in forward(); assuming it maps back to the pose feature dimension
        self.finalLayer = nn.Conv1d(in_channels = intermediate_channels,
                                    out_channels = poseFeatDim,
                                    kernel_size = 1,
                                    padding = 0,
                                    bias = True)

    def forward(self, latent, cur_pose):
        
        cur_pose = cur_pose.unsqueeze(2)
        
        x = torch.cat([latent, cur_pose], dim = 1)
        input = x
        

        x = self.entryLayer(x)
        x = torch.relu(x)
        
        
        x = self.convLayer1(torch.cat((input, x), dim=1))
        x = torch.relu(x)
        

        x = self.finalLayer(x)
        return x

class VAE(nn.Module):
    def __init__(self, Encoder, Decoder):
        super(VAE, self).__init__()
        self.encoder = Encoder
        self.decoder = Decoder

    def forward(self, seq, cur_pose):
        latent, kullbackLeibler = self.encoder(seq)
        
        X_hat = self.decoder(latent, cur_pose)
        return X_hat, latent, kullbackLeibler
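
For completeness, this is roughly how the model is put together (the dimensions below are placeholders for illustration, not my actual values):

# rough instantiation; dimensions are placeholders
latentDim, inputFeatDim, poseFeatDim = 32, 132, 66
frameSequence, intermediate_channels = 121, 256

encoder = Encoder(latentDim, inputFeatDim, frameSequence, intermediate_channels)
decoder = Decoder(latentDim, inputFeatDim, poseFeatDim, frameSequence, intermediate_channels)
model = VAE(encoder, decoder).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)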

Here is my train function:

def train(VAE, data, device, optimizer, load_saved_model, epochs_before_KL, draw_pose, poseFeatDim, scheduled_sampling_length, epochs, lr, lr_init, lr_final, latentDim, frameSequence, first_train_stage_epochs, second_train_stage_epochs):
    N = int((frameSequence-1)/2) # pose sequences before and after current pose
    
    if load_saved_model==True:
        alpha = 1.0
    else:
        alpha = 0.0

    for epoch in range(epochs):
        epoch_KL = 0
        if epoch == epochs_before_KL:
            alpha = 1.0
        epoch_loss= 0



        # linear learning-rate decay after the first two training stages
        if epoch > (first_train_stage_epochs+second_train_stage_epochs-1):
            lr = lr_init - (lr_init-lr_final)*(epoch-(first_train_stage_epochs+second_train_stage_epochs))/(epochs-(first_train_stage_epochs+second_train_stage_epochs))
        # update the optimizer's learning rate via its param_groups
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        for X, target, cur_frame_idx in data:
            
            X = X.permute(0, 2, 1).float().to(device=device)
            
            X_hat = (X[:, :, N].clone()).unsqueeze(2).cuda()  # seed the rollout with the current pose
            for l in range(scheduled_sampling_length):
                train_loss = 0
                optimizer.zero_grad()
                
                
                cur_X = (X[:,:,l:frameSequence+l].clone()).to(device=device)
                
                cur_pose = (cur_X[:, 0:poseFeatDim, N].clone()).to(device=device)
                
                GT = (cur_X[:,0:poseFeatDim, N+1].clone()).unsqueeze(2).cuda()
                
                # scheduled sampling: p = 1 -> feed the ground-truth current pose,
                # p = 0 -> feed the pose generated at the previous rollout step
                if load_saved_model==True:
                    p=0
                elif epoch<first_train_stage_epochs:
                    p = 1
                elif epoch<first_train_stage_epochs+second_train_stage_epochs:
                    w1 = (epoch - first_train_stage_epochs+1)/second_train_stage_epochs
                    w2= 1-w1
                    
                    weights = torch.tensor([w1, w2], dtype=torch.float).cuda()
                    p = torch.multinomial(weights, 1, replacement=True).cuda().item()

                else:
                    p = 0

                input_pose = p * cur_pose.detach().cuda() + (1-p)* X_hat[:,0:poseFeatDim,:].detach().squeeze(2).cuda()
                
                X_hat, latent, KL_loss = VAE(cur_X, input_pose)
                
                
                recon_loss = (((GT - X_hat[:, 0:poseFeatDim, :])**2).sum()) / GT.size(dim=0)  # GT shape ----> [batch_size, poseFeatDim, 1]
                train_loss = recon_loss + KL_loss * alpha

                train_loss.backward()

                epoch_loss += train_loss.detach()
                epoch_KL += KL_loss.detach()
                optimizer.step()
            
        print("epoch: " + str(epoch)+" loss: " + str(epoch_loss.item()) + " KL: " + str(epoch_KL.item() * alpha))

    
    return VAE, X, latent[0,:,:].unsqueeze(0), X[0,:,N].unsqueeze(0), epoch_loss, optimizer

And here is the output of the training stage:

cuda available
data unit is radian
epoch: 0 loss: 30499.1171875 KL: 0.0
epoch: 1 loss: 4208.41015625 KL: 0.0
epoch: 2 loss: 498.6940002441406 KL: 0.0
epoch: 3 loss: 158.99220275878906 KL: 0.0
epoch: 4 loss: 78.93453216552734 KL: 0.0
epoch: 5 loss: 53.533302307128906 KL: 0.0
epoch: 6 loss: 38.02873611450195 KL: 0.0
epoch: 7 loss: 28.048128128051758 KL: 0.0
epoch: 8 loss: 23.194978713989258 KL: 0.0
epoch: 9 loss: 21.458599090576172 KL: 0.0
epoch: 10 loss: 20.632036209106445 KL: 0.0
epoch: 11 loss: 20.297395706176758 KL: 0.0
epoch: 12 loss: 18.75624656677246 KL: 0.0
epoch: 13 loss: 17.753822326660156 KL: 0.0
epoch: 14 loss: 16.912155151367188 KL: 0.0
epoch: 15 loss: 16.498188018798828 KL: 0.0
epoch: 16 loss: 15.184914588928223 KL: 0.0
epoch: 17 loss: 14.235843658447266 KL: 0.0
epoch: 18 loss: 13.30086898803711 KL: 0.0
epoch: 19 loss: 12.536004066467285 KL: 0.0
epoch: 20 loss: 11.863930702209473 KL: 0.0
epoch: 21 loss: 10.70985221862793 KL: 0.0
epoch: 22 loss: 10.140275001525879 KL: 0.0
epoch: 23 loss: 9.719818115234375 KL: 0.0
epoch: 24 loss: 7.877124309539795 KL: 0.0
epoch: 25 loss: 6.41648006439209 KL: 0.0
epoch: 26 loss: 5.2640767097473145 KL: 0.0
epoch: 27 loss: 4.675246238708496 KL: 0.0
epoch: 28 loss: 4.752994060516357 KL: 0.0
epoch: 29 loss: 4.260623455047607 KL: 0.0
epoch: 30 loss: 208.68763732910156 KL: 1771.3675537109375
epoch: 31 loss: 18.421226501464844 KL: 7.0814619064331055
epoch: 32 loss: 16.831327438354492 KL: 0.2619243860244751
epoch: 33 loss: 16.36933708190918 KL: 0.22026295959949493
epoch: 34 loss: 16.225860595703125 KL: 0.1161663681268692
epoch: 35 loss: 16.09817123413086 KL: 0.14859028160572052
epoch: 36 loss: 16.100046157836914 KL: 0.164580836892128
epoch: 37 loss: 15.891282081604004 KL: 0.13011851906776428
epoch: 38 loss: 15.863426208496094 KL: 0.1438782811164856
epoch: 39 loss: 15.77467155456543 KL: 0.0739947035908699
epoch: 40 loss: 15.756997108459473 KL: 0.1154341995716095
epoch: 41 loss: 15.682149887084961 KL: 0.13609440624713898
epoch: 42 loss: 15.646101951599121 KL: 0.14060918986797333
epoch: 43 loss: 15.596468925476074 KL: 0.06942499428987503
epoch: 44 loss: 15.487974166870117 KL: 0.13864728808403015
epoch: 45 loss: 15.456522941589355 KL: 0.09747464954853058
epoch: 46 loss: 15.596013069152832 KL: 0.10960092395544052
epoch: 47 loss: 15.446678161621094 KL: 0.09400694817304611
epoch: 48 loss: 15.414061546325684 KL: 0.07403453439474106
epoch: 49 loss: 15.446662902832031 KL: 0.07924196124076843
epoch: 50 loss: 15.337182998657227 KL: 0.07696129381656647
epoch: 51 loss: 15.423378944396973 KL: 0.1136254072189331
epoch: 52 loss: 15.3486967086792 KL: 0.09196256101131439
epoch: 53 loss: 15.432474136352539 KL: 0.11669618636369705
epoch: 54 loss: 15.23315143585205 KL: 0.08362749963998795
epoch: 55 loss: 15.270442962646484 KL: 0.0592842772603035
epoch: 56 loss: 15.257233619689941 KL: 0.08109745383262634
epoch: 57 loss: 15.207656860351562 KL: 0.058704279363155365
epoch: 58 loss: 15.246068954467773 KL: 0.08804851025342941
epoch: 59 loss: 15.179248809814453 KL: 0.06591930240392685
epoch: 60 loss: 16.24458122253418 KL: 0.05520284175872803
epoch: 61 loss: 18.20315170288086 KL: 0.07300713658332825
epoch: 62 loss: 20.9660701751709 KL: 0.10368426144123077
epoch: 63 loss: 26.014833450317383 KL: 0.1356126070022583
epoch: 64 loss: 35.390743255615234 KL: 0.1684873253107071
epoch: 65 loss: 32.68571090698242 KL: 0.14424605667591095
epoch: 66 loss: 52.215614318847656 KL: 0.26431578397750854
epoch: 67 loss: 189.5343017578125 KL: 1.0707039833068848
epoch: 68 loss: 75.52210235595703 KL: 0.23325027525424957
epoch: 69 loss: 143.2079620361328 KL: 0.38768690824508667
epoch: 70 loss: 157.3100128173828 KL: 0.49191996455192566
epoch: 71 loss: 192.56976318359375 KL: 0.829379677772522
epoch: 72 loss: 258.619873046875 KL: 0.6730182766914368
epoch: 73 loss: 521.1996459960938 KL: 3.7076361179351807
epoch: 74 loss: 330.8260803222656 KL: 0.9579944014549255
epoch: 75 loss: 604.3058471679688 KL: 1.2703361511230469
epoch: 76 loss: 475.0205078125 KL: 0.9360959529876709
epoch: 77 loss: 731.9593505859375 KL: 2.7841150760650635
epoch: 78 loss: 975.5214233398438 KL: 1.2265475988388062
epoch: 79 loss: 924.7633056640625 KL: 0.873565673828125
epoch: 80 loss: 940.7155151367188 KL: 0.5359449982643127
epoch: 81 loss: 855.8935546875 KL: 0.9077990651130676
epoch: 82 loss: 849.4100952148438 KL: 0.7129514813423157
epoch: 83 loss: 743.1096801757812 KL: 0.5308371782302856
epoch: 84 loss: 849.7276611328125 KL: 0.9092111587524414
epoch: 85 loss: 806.3848876953125 KL: 0.49240317940711975
epoch: 86 loss: 773.6209716796875 KL: 0.35794520378112793
epoch: 87 loss: 714.7335815429688 KL: 0.36182066798210144
epoch: 88 loss: 725.5518188476562 KL: 0.6665423512458801
epoch: 89 loss: 725.10498046875 KL: 0.3123415410518646
epoch: 90 loss: 749.900634765625 KL: 0.5664316415786743
epoch: 91 loss: 746.6582641601562 KL: 0.8775449395179749
epoch: 92 loss: 740.4017944335938 KL: 0.4976818263530731
epoch: 93 loss: 709.8568115234375 KL: 0.34913212060928345
epoch: 94 loss: 716.6048583984375 KL: 0.7065077424049377
epoch: 95 loss: 681.2711181640625 KL: 0.36696088314056396
epoch: 96 loss: 740.9374389648438 KL: 0.803412675857544
epoch: 97 loss: 646.1436767578125 KL: 0.2696443796157837
epoch: 98 loss: 664.8652954101562 KL: 0.37316083908081055
epoch: 99 loss: 614.1035766601562 KL: 0.2937750816345215
epoch: 100 loss: 703.1944580078125 KL: 0.4119395315647125
epoch: 101 loss: 644.4376220703125 KL: 0.36282405257225037
epoch: 102 loss: 673.5081176757812 KL: 0.35550656914711
epoch: 103 loss: 599.3011474609375 KL: 0.18692539632320404
epoch: 104 loss: 589.5043334960938 KL: 0.33308255672454834
epoch: 105 loss: 589.5310668945312 KL: 0.20958860218524933
epoch: 106 loss: 633.5597534179688 KL: 0.3015775978565216
epoch: 107 loss: 587.228271484375 KL: 0.2859556972980499
epoch: 108 loss: 633.8538818359375 KL: 0.3062727153301239
epoch: 109 loss: 576.3986206054688 KL: 0.3453579843044281
epoch: 110 loss: 605.309814453125 KL: 0.7614783048629761
epoch: 111 loss: 559.1953735351562 KL: 0.43579205870628357
epoch: 112 loss: 601.722412109375 KL: 0.31123608350753784
epoch: 113 loss: 591.31494140625 KL: 0.38346976041793823
epoch: 114 loss: 677.573974609375 KL: 1.5325040817260742
epoch: 115 loss: 535.7906494140625 KL: 0.2391374409198761
epoch: 116 loss: 550.9417114257812 KL: 0.5806562900543213
epoch: 117 loss: 565.160400390625 KL: 0.31043145060539246
epoch: 118 loss: 584.8384399414062 KL: 0.8044378757476807
epoch: 119 loss: 616.1946411132812 KL: 0.9010312557220459
epoch: 120 loss: 589.0029907226562 KL: 0.5001609325408936
epoch: 121 loss: 558.1272583007812 KL: 0.36073750257492065
epoch: 122 loss: 522.8496704101562 KL: 0.4064602553844452
epoch: 123 loss: 563.9342651367188 KL: 0.2904842495918274
epoch: 124 loss: 562.810791015625 KL: 0.5313525199890137
epoch: 125 loss: 608.248046875 KL: 0.7063066363334656
epoch: 126 loss: 517.7711791992188 KL: 0.2636258602142334
epoch: 127 loss: 525.2127075195312 KL: 0.2245425432920456
epoch: 128 loss: 576.1654663085938 KL: 0.6417035460472107
epoch: 129 loss: 583.733642578125 KL: 0.47674331068992615
epoch: 130 loss: 522.4052124023438 KL: 0.34901681542396545
epoch: 131 loss: 565.4308471679688 KL: 0.232156440615654
epoch: 132 loss: 553.7698364257812 KL: 0.323140025138855
epoch: 133 loss: 586.9306640625 KL: 1.2630860805511475
epoch: 134 loss: 488.27557373046875 KL: 0.43516507744789124
epoch: 135 loss: 527.9531860351562 KL: 0.3459720313549042
epoch: 136 loss: 548.0935668945312 KL: 0.4123835861682892
epoch: 137 loss: 543.787841796875 KL: 0.2853831350803375
epoch: 138 loss: 536.0159912109375 KL: 0.27312254905700684
epoch: 139 loss: 546.4530639648438 KL: 0.5541123151779175