Training Bart as a VAE for interpolation

I’ve built on the BartModel and BartForConditionalGeneration classes to try to turn them into a variational autoencoder (VAE). My end goal is to generate text from interpolated embeddings. I’m having a hard time wrapping my head around how to make Bart work for this purpose, though. Beyond the plain VAE changes (a modified loss and extra hidden layers for the mean and standard deviation), I’m also using this janky method:

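```python
import torch
from torch.nn.functional import gelu

def decode(self, hidden, labels):
    # BartDecoder takes `inputs_embeds`; read the LAST position, since under causal
    # masking position 0 sees only the first embedding and never changes.
    out = self.decoder(inputs_embeds=hidden)[0][:, -1, :].unsqueeze(1)
    out = gelu(out)
    out = self.decoder_output(out)  # project back to the embedding size
    hidden = torch.cat([hidden, out], dim=1)  # append the new step
    if hidden.shape[1] == labels.shape[1]:  # stop at the target length
        return hidden
    return self.decode(hidden, labels)
```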
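The plain VAE changes mentioned at the start (extra layers for the mean and standard deviation plus a modified loss) and the interpolation end goal could look roughly like the following. This is a minimal sketch under stated assumptions, not the code used in this post: `VAEBottleneck` and `interpolate` are hypothetical names, and mean pooling over the encoder states and a log-variance layer (instead of a raw std-dev layer) are illustrative choices.

```python
import torch
import torch.nn as nn


class VAEBottleneck(nn.Module):
    """Hypothetical VAE head between a Bart encoder and decoder.

    Pools the encoder states, maps them to a diagonal Gaussian q(z|x), and
    samples z with the reparameterization trick.
    """

    def __init__(self, d_model: int, d_latent: int):
        super().__init__()
        self.to_mean = nn.Linear(d_model, d_latent)
        self.to_logvar = nn.Linear(d_model, d_latent)  # log-variance instead of std-dev
        self.to_decoder = nn.Linear(d_latent, d_model)  # map z back to decoder width

    def forward(self, encoder_states: torch.Tensor):
        # encoder_states: (batch, seq_len, d_model); mean-pool over tokens
        # (padding is ignored here for brevity).
        pooled = encoder_states.mean(dim=1)
        mean, logvar = self.to_mean(pooled), self.to_logvar(pooled)
        std = torch.exp(0.5 * logvar)
        z = mean + std * torch.randn_like(std)  # reparameterization trick
        # Closed-form KL(q(z|x) || N(0, I)), summed over latent dims, averaged over batch.
        kl = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp(), dim=-1).mean()
        # (batch, 1, d_model): a single embedding to start decoding from.
        return self.to_decoder(z).unsqueeze(1), mean, kl


def interpolate(z_a: torch.Tensor, z_b: torch.Tensor, steps: int) -> torch.Tensor:
    """Evenly spaced points on the straight line from z_a to z_b, endpoints included."""
    alphas = torch.linspace(0.0, 1.0, steps).unsqueeze(-1)  # (steps, 1)
    return (1 - alphas) * z_a + alphas * z_b  # (steps, d_latent)


# Usage with random tensors standing in for real encoder output.
torch.manual_seed(0)
bottleneck = VAEBottleneck(d_model=768, d_latent=32)
encoder_states = torch.randn(2, 20, 768)
start_embeds, mean, kl = bottleneck(encoder_states)
path = interpolate(mean[0], mean[1], steps=5)  # latents between the two inputs
path_embeds = bottleneck.to_decoder(path).unsqueeze(1)  # (5, 1, 768): one start embedding per point
```

In this setup the training loss would be the token-level reconstruction cross-entropy plus a weight times `kl`; that weight is often annealed up from 0 so the decoder does not learn to ignore z. For generation, each row of `path_embeds` would be the single starting embedding handed to the decoder.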

The goal of this decode method is just to “simplify” the problem into something a simpleton like myself can understand. I’ve tried to make the decoder purely autoregressive: I take a single embedding from the encoder, VAE-transform it into the latent space, feed it into the decoder, and have the decoder autoregress over it, incrementally predicting the output (in my case, the labels are the same as the input tokens). This is obviously dreadfully slow to train. Fortunately, my current use case has fairly short sequences (under 150 tokens on average), so I can still fit small batches on the 32 GB GPU I’m working on.
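Much of the slowness comes from rerunning the decoder over the whole prefix at every step. Starting from one embedding and growing to n positions, the recursive decode makes n - 1 decoder calls over prefixes of length 1, 2, ..., n - 1, which is n(n - 1)/2 positions in total (11,175 for n = 150), and autograd has to keep every step's activations for the backward pass. Standard seq2seq training (teacher forcing) instead feeds the shifted ground-truth tokens and scores all n positions in one causally masked pass. The sketch below assumes the latent, already projected to the decoder width, simply takes the place of the first decoder input. `teacher_forced_inputs` and `causal_mask` are hypothetical helpers, and a small `nn.TransformerEncoder` with a causal mask stands in for Bart's decoder so the example runs without downloading weights.

```python
import torch
import torch.nn as nn


def teacher_forced_inputs(z_embed: torch.Tensor, labels: torch.Tensor,
                          embed_tokens: nn.Embedding) -> torch.Tensor:
    """Decoder inputs for one parallel pass: [z, emb(y_0), ..., emb(y_{n-2})].

    z_embed: (batch, 1, d_model) latent already projected to decoder width.
    labels:  (batch, n) target ids (here identical to the input ids).
    """
    shifted = embed_tokens(labels[:, :-1])  # y_{n-1} has no next token to predict
    return torch.cat([z_embed, shifted], dim=1)  # (batch, n, d_model)


def causal_mask(n: int) -> torch.Tensor:
    # Additive mask: -inf above the diagonal, so position t only attends to <= t.
    return torch.triu(torch.full((n, n), float("-inf")), diagonal=1)


# Toy stand-in for Bart's decoder (no positional embeddings, no cross-attention).
torch.manual_seed(0)
d_model, vocab, n, batch = 32, 100, 10, 3
embed = nn.Embedding(vocab, d_model)
layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
decoder = nn.TransformerEncoder(layer, num_layers=2)
lm_head = nn.Linear(d_model, vocab)

labels = torch.randint(0, vocab, (batch, n))
z_embed = torch.randn(batch, 1, d_model)  # stand-in for the projected latent

inputs = teacher_forced_inputs(z_embed, labels, embed)
logits = lm_head(decoder(inputs, mask=causal_mask(n)))  # (batch, n, vocab)
# Position t (which has seen z and y_0..y_{t-1}) is scored against y_t.
loss = nn.functional.cross_entropy(logits.reshape(-1, vocab), labels.reshape(-1))
loss.backward()  # one backward pass covers all n positions
```

With Bart itself, the concatenated embeddings would go to the decoder through its `inputs_embeds` argument, which builds the causal mask and adds positional embeddings internally. At inference time generation is still autoregressive; teacher forcing only removes the step-by-step loop during training.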

Is this the best way to do this? Could I just make do with the standard Bart training method and still generate sequences the way I want from a single embedding using only the decoder? Am I even doing this right? Is there a better option that still leverages pretrained transformers to make a VAE?