How to combine two models' logits

Hi,

I want to perform text generation by combining the logits of two existing language models in various ways (both models have causal LM heads). What is the best way to do this? I’ve tried subclassing PreTrainedModel to hold the two models and output concatenations of their logits, but the configuration and initialization methods are geared more towards saving and loading existing models than towards combining them, so this hasn’t worked out well. It’s easy to do this kind of thing in plain PyTorch for vision models; is there a simple way to do it in Hugging Face that I’m missing?

Thank you for the help!

You should be able to create a PyTorch model with each of the Hugging Face models initialized as submodules of that model. Then, in the forward function, pass the inputs through self.model_a and self.model_b to get an output from each. You can concatenate these outputs (the sketch below uses the [CLS] hidden state from each encoder) and pass the result through the rest of the model. I’ve written PSEUDOCODE for the general idea below (it won’t run as-is, since the checkpoint names are placeholders):

import torch
import torch.nn as nn
from transformers import AutoModel

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # 'model_a' and 'model_b' are placeholder names for the two checkpoints to combine
        self.model_a = AutoModel.from_pretrained('model_a')
        self.model_b = AutoModel.from_pretrained('model_b')

        # the two pooled vectors (768 each) are concatenated, hence 768 * 2 input features
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.1),
            nn.Linear(768 * 2, 768, bias=True),
            nn.Tanh(),
            nn.Dropout(p=0.1),
            nn.Linear(768, 3, bias=True)
        )

    def forward(self, input_ids, attention_mask):
        # take the [CLS] hidden state of each encoder as a pooled representation
        hidden_a = self.model_a(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        hidden_b = self.model_b(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        concatenated_vectors = torch.cat((hidden_a, hidden_b), dim=1)
        output = self.classifier(concatenated_vectors)
        return output

model = Net()

You can then train this model just like a regular PyTorch model.
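For example, a single training step could look something like this (assuming both checkpoints share a tokenizer; 'model_a' is a placeholder name and the labels are dummies):

from transformers import AutoTokenizer

# placeholder checkpoint name; both encoders are assumed to share this tokenizer
tokenizer = AutoTokenizer.from_pretrained('model_a')

batch = tokenizer(["an example sentence", "another example"], padding=True, return_tensors="pt")
labels = torch.tensor([0, 2])  # dummy labels for the 3-way classifier above

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()

logits = model(batch["input_ids"], batch["attention_mask"])  # shape [2, 3]
loss = loss_fn(logits, labels)
loss.backward()
optimizer.step()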

Edit: Made a small error in the code by passing x to classifier instead of concatenated_vectors.
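And if your end goal really is generation with two causal LMs (as in the original question) rather than classification, one rough sketch is to average the next-token logits of the two models at each decoding step. This assumes both models share a tokenizer and vocabulary; gpt2 and distilgpt2 below are just stand-ins:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# stand-in checkpoints; both models must share the same tokenizer/vocabulary
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model_a = AutoModelForCausalLM.from_pretrained("gpt2")
model_b = AutoModelForCausalLM.from_pretrained("distilgpt2")

input_ids = tokenizer("The movie was", return_tensors="pt").input_ids

for _ in range(20):
    with torch.no_grad():
        logits_a = model_a(input_ids).logits[:, -1, :]
        logits_b = model_b(input_ids).logits[:, -1, :]
    combined = (logits_a + logits_b) / 2   # or any other combination you like
    next_token = combined.argmax(dim=-1, keepdim=True)
    input_ids = torch.cat([input_ids, next_token], dim=-1)

print(tokenizer.decode(input_ids[0]))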

I want to similarly combine output logits to build a new model that basically compares two inputs. Right now my code is something like this:

examples = movie_double.iloc[0]
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
input_0 = tokenizer(examples["text_0"], truncation=True)
input_1 = tokenizer(examples["text_1"], truncation=True)
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
a = model(input_0)

but this already fails:

---------------------------------------------------------------------------

KeyError                                  Traceback (most recent call last)

/usr/local/lib/python3.8/dist-packages/transformers/tokenization_utils_base.py in __getattr__(self, item)
    247         try:
--> 248             return self.data[item]
    249         except KeyError:

KeyError: 'size'


During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)

5 frames

<ipython-input-45-22a73547024f> in <module>
     37 print(input_0)
     38 model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
---> 39 a = model(input_0)
     40 print(a)

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.8/dist-packages/transformers/models/distilbert/modeling_distilbert.py in forward(self, input_ids, attention_mask, head_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict)
    759         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    760 
--> 761         distilbert_output = self.distilbert(
    762             input_ids=input_ids,
    763             attention_mask=attention_mask,

/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1188         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1189                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1190             return forward_call(*input, **kwargs)
   1191         # Do not call functions when jit is used
   1192         full_backward_hooks, non_full_backward_hooks = [], []

/usr/local/lib/python3.8/dist-packages/transformers/models/distilbert/modeling_distilbert.py in forward(self, input_ids, attention_mask, head_mask, inputs_embeds, output_attentions, output_hidden_states, return_dict)
    561             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    562         elif input_ids is not None:
--> 563             input_shape = input_ids.size()
    564         elif inputs_embeds is not None:
    565             input_shape = inputs_embeds.size()[:-1]

/usr/local/lib/python3.8/dist-packages/transformers/tokenization_utils_base.py in __getattr__(self, item)
    248             return self.data[item]
    249         except KeyError:
--> 250             raise AttributeError
    251 
    252     def __getstate__(self):

AttributeError: 

My objective would be to get something like this model:

import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)

    def forward(self, input_0, input_1):
        # the sequence-classification head returns class logits of shape [batch, 2];
        # assume batch size 1 here, hence the [0]
        logits_a = self.model(**input_0).logits[0]
        logits_b = self.model(**input_1).logits[0]
        # outer product of the two 2-dim logit vectors, flattened to 4 values
        logits = torch.einsum('i,j->ij', logits_a, logits_b)
        logits = torch.reshape(logits, (-1,))
        # three scores: "different one way", "equal", "different the other way";
        # torch.stack (rather than torch.tensor) keeps the computation graph intact
        output = torch.stack([logits[1], logits[0] + logits[3], logits[2]])
        return output

torch_model = Net()

which basically outputs scores for whether the two outputs a and b are equal or different, and if different, in which direction.
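For what it’s worth, I suspect the AttributeError above comes from calling the model on the raw tokenizer output: without return_tensors="pt" the tokenizer returns Python lists, and passing the whole BatchEncoding positionally makes the model treat it as input_ids, which has no .size(). Tokenizing to tensors and unpacking inside the forward (as with **input_0 above) should avoid this:

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# return_tensors="pt" gives tensors instead of Python lists
input_0 = tokenizer(examples["text_0"], truncation=True, return_tensors="pt")
input_1 = tokenizer(examples["text_1"], truncation=True, return_tensors="pt")

scores = torch_model(input_0, input_1)  # the three combined scores from the forward above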

+1, it would be interesting if there is any reference for this.
