Using ViTMAEModel as an encoder for a UNet decoder for semantic segmentation

I want to fine-tune a Masked Autoencoder (ViTMAE) on a downstream task. The task is semantic segmentation, and I have been trying to use the ViTMAEModel class as the encoder. I just don't know how to properly pass the last hidden state (and the intermediate hidden states) into the decoder. Any ideas?

class UNetDecoder(nn.Module):
    """UNet-style decoder head for a ViT-MAE encoder.

    Upsamples the deepest encoder feature map with transposed convolutions,
    concatenating an encoder skip connection at each level, and projects the
    result to ``config.num_channels`` output channels (e.g. segmentation
    classes).

    NOTE(review): this assumes each skip feature map is channels-first,
    ``(B, skip_channels, H, W)``, and that consecutive skips double in
    spatial resolution. ViTMAE hidden states are token sequences at a single
    resolution, so the caller must reshape (and, for a true pyramid,
    resample) them before passing them in — TODO confirm against the caller.

    Args:
        config: model config; only ``config.num_channels`` is read here
            (number of output channels of the final 1x1 projection).
        skip_channels: channel count of the three skip feature maps that get
            concatenated, deepest-but-one first. Defaults to ViT-Base's
            hidden size (768) for each.
    """

    def __init__(self, config, skip_channels=(768, 768, 768)):
        # Fix: the original said `init` / `super(...).init()` — the double
        # underscores of `__init__` were lost, so nn.Module was never set up.
        super().__init__()

        # Each ConvTranspose2d(kernel_size=2, stride=2) doubles H and W.
        # After each upsampling we concatenate a skip connection along the
        # channel axis, so the NEXT layer's in_channels must include the
        # skip's channels (the original omitted this and crashed on forward).
        self.up1 = nn.ConvTranspose2d(in_channels=768, out_channels=512,
                                      kernel_size=2, stride=2)
        self.up2 = nn.ConvTranspose2d(512 + skip_channels[0], 256,
                                      kernel_size=2, stride=2)
        self.up3 = nn.ConvTranspose2d(256 + skip_channels[1], 128,
                                      kernel_size=2, stride=2)
        self.up4 = nn.ConvTranspose2d(128 + skip_channels[2], 64,
                                      kernel_size=2, stride=2)
        # 1x1 projection keeps the spatial size; the original kernel_size=2
        # (no padding) shrank the map by one pixel, misaligning the output
        # with the input for dense prediction.
        self.out_conv = nn.Conv2d(64, config.num_channels, kernel_size=1)

    def forward(self, skip_connection):
        """Decode a list of encoder feature maps into a dense prediction.

        Args:
            skip_connection: sequence of 4 feature maps, shallowest first —
                ``skip_connection[-1]`` is the deepest (smallest) map that
                seeding the decoder, ``skip_connection[-2] .. [-4]`` are
                concatenated at successively higher resolutions.

        Returns:
            Tensor of shape ``(B, config.num_channels, 16*H, 16*W)`` where
            ``(H, W)`` is the spatial size of ``skip_connection[-1]``.
        """
        x = self.up1(skip_connection[-1])
        # Fix: the original cat here had no dim= and concatenated along the
        # batch axis; channels are dim=1, matching the other two cats.
        x = torch.cat([x, skip_connection[-2]], dim=1)

        x = self.up2(x)
        x = torch.cat([x, skip_connection[-3]], dim=1)

        x = self.up3(x)
        x = torch.cat([x, skip_connection[-4]], dim=1)

        x = self.up4(x)
        return self.out_conv(x)