How to combine two models' logits


I want to generate text by combining the logits of two existing language models in various ways (both models have causal LM heads). What is the best way to do this? I’ve tried subclassing PreTrainedModel to contain the two models and output concatenations of their logits, but its configuration and initialization methods are geared towards saving and loading existing models rather than combining them, so this hasn’t worked well. This kind of task is easy in standard PyTorch for vision models; is there a simple way to do it in Hugging Face that I’m missing?

Thank you for the help!

You should be able to create a plain PyTorch model (an nn.Module) that holds each Hugging Face model as a submodule. In its forward function, pass the inputs through self.model_a and self.model_b to get an output vector from each, concatenate the two, and pass the result through the rest of the model. Note that AutoModel loads the base model without an LM head, so these outputs are hidden states rather than logits. I’ve written a sketch of the general idea below ('model_a' and 'model_b' are placeholder checkpoint names):

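import torch
import torch.nn as nn
from transformers import AutoModel

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # 'model_a' and 'model_b' are placeholder checkpoint names; both models
        # receive the same input_ids, so they must share a tokenizer
        self.model_a = AutoModel.from_pretrained('model_a')
        self.model_b = AutoModel.from_pretrained('model_b')

        # The concatenated vector has hidden_size_a + hidden_size_b features
        in_features = self.model_a.config.hidden_size + self.model_b.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 768, bias=True),
            nn.ReLU(),  # without a non-linearity the two Linear layers collapse into one
            nn.Linear(768, 3, bias=True),  # 3 output classes
        )

    def forward(self, input_ids, attention_mask):
        # First-token hidden state of each model, shape (batch, hidden_size).
        # For decoder-only models the last non-padding token is usually a better summary.
        hidden_a = self.model_a(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        hidden_b = self.model_b(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        # torch.cat takes a sequence of tensors; join along the feature dimension
        concatenated_vectors = torch.cat([hidden_a, hidden_b], dim=-1)
        output = self.classifier(concatenated_vectors)
        return output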

model = Net()

You can then train this model like any regular PyTorch model.
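For reference, here is a minimal sketch of one ordinary PyTorch training step. It assumes a classification setup with integer labels and nn.CrossEntropyLoss. The hypothetical ToyNet only stands in for Net (same forward(input_ids, attention_mask) signature) so the snippet runs without downloading any checkpoints; train_step itself should not need to change for Net.

```python
import torch
import torch.nn as nn


class ToyNet(nn.Module):
    """Hypothetical stand-in for Net: same forward signature, no downloads."""

    def __init__(self, vocab_size=100, dim=16, num_labels=3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.classifier = nn.Linear(dim, num_labels)

    def forward(self, input_ids, attention_mask):
        # Mean-pool the embeddings of non-padding tokens
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (self.embed(input_ids) * mask).sum(1) / mask.sum(1).clamp(min=1.0)
        return self.classifier(pooled)


def train_step(model, optimizer, loss_fn, input_ids, attention_mask, labels):
    """One ordinary optimisation step: forward, loss, backward, update."""
    model.train()
    optimizer.zero_grad()
    logits = model(input_ids, attention_mask)
    loss = loss_fn(logits, labels)
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = ToyNet()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

# A fixed random batch: 8 sequences of length 12, labels from 3 classes
input_ids = torch.randint(0, 100, (8, 12))
attention_mask = torch.ones_like(input_ids)
labels = torch.randint(0, 3, (8,))

losses = [train_step(model, optimizer, loss_fn, input_ids, attention_mask, labels)
          for _ in range(50)]
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Repeatedly fitting one fixed batch like this is only a quick check that gradients flow through the wrapper; a real run would iterate over a DataLoader.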

Edit: Fixed a small error in the code, which passed x to the classifier instead of concatenated_vectors.
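For the generation case in the original question (two models with causal LM heads), one possible sketch is to wrap both models in an nn.Module, mix their next-token distributions at each step, and decode greedily. This assumes both models share one tokenizer and vocabulary, and that each returns an object with a .logits tensor of shape (batch, seq_len, vocab_size), as Hugging Face *ForCausalLM models do. LogitMixer, greedy_generate, weight_a and ToyCausalLM are hypothetical names; ToyCausalLM only exists so the sketch runs offline. The mixing rule here is a weighted sum of log-probabilities (a product of experts), which is just one of the "various ways" you could combine them.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn
import torch.nn.functional as F


class LogitMixer(nn.Module):
    """Combine two causal LMs that share one vocabulary."""

    def __init__(self, model_a, model_b, weight_a=0.5):
        super().__init__()
        self.model_a = model_a
        self.model_b = model_b
        self.weight_a = weight_a  # hypothetical mixing weight in [0, 1]

    def forward(self, input_ids, attention_mask=None):
        logits_a = self.model_a(input_ids, attention_mask=attention_mask).logits
        logits_b = self.model_b(input_ids, attention_mask=attention_mask).logits
        # Weighted sum of log-probabilities (product of experts), then
        # renormalise so the result is a proper log-distribution again
        mixed = (self.weight_a * F.log_softmax(logits_a, dim=-1)
                 + (1.0 - self.weight_a) * F.log_softmax(logits_b, dim=-1))
        return F.log_softmax(mixed, dim=-1)


@torch.no_grad()
def greedy_generate(mixer, input_ids, max_new_tokens=10):
    """Plain greedy decoding without a KV cache: re-runs the full prefix each step."""
    for _ in range(max_new_tokens):
        attention_mask = torch.ones_like(input_ids)
        next_log_probs = mixer(input_ids, attention_mask=attention_mask)[:, -1, :]
        next_token = next_log_probs.argmax(dim=-1, keepdim=True)
        input_ids = torch.cat([input_ids, next_token], dim=-1)
    return input_ids


class ToyCausalLM(nn.Module):
    """Hypothetical stand-in for an AutoModelForCausalLM model."""

    def __init__(self, vocab_size=50, dim=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.lm_head = nn.Linear(dim, vocab_size)

    def forward(self, input_ids, attention_mask=None):
        return SimpleNamespace(logits=self.lm_head(self.embed(input_ids)))


torch.manual_seed(0)
mixer = LogitMixer(ToyCausalLM(), ToyCausalLM(), weight_a=0.7)
prompt = torch.tensor([[1, 2, 3]])
out = greedy_generate(mixer, prompt, max_new_tokens=5)
print(out.shape)  # torch.Size([1, 8])
```

With real checkpoints you would load each model with AutoModelForCausalLM.from_pretrained(...) (which returns the model in eval mode) and pass the shared tokenizer's input_ids to greedy_generate. If the two models use different tokenizers, their logits are not aligned position-by-position and this approach does not apply directly.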