PyTorch Bilinear messing with HuggingFace BERT?!

Hello!

I am trying to train embeddings. My model contains two BERT encoders, and their outputs are fed into a Bilinear layer. I am training the model with a triplet loss. I am pasting my code below:

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel


class MyModel(nn.Module):
    def __init__(self, pretrained_bert_model: str, out_dim: int):

        super(MyModel, self).__init__()

        self.context_encoder = BertEmbedding(pretrained=pretrained_bert_model)

        self.entity_encoder = BertEmbedding(pretrained=pretrained_bert_model)

        # Project the 768-dim BERT [CLS] output down to 100 dims
        self.linear = nn.Linear(768, 100)

        self.bilinear = nn.Bilinear(in1_features=out_dim, in2_features=out_dim, out_features=out_dim)


    def forward(
            self,
            context_input_ids,
            context_attention_mask,
            desc_input_ids,
            desc_attention_mask,
    ) -> torch.Tensor:
        # Get context embedding
        context_embedding: torch.Tensor = self.context_encoder(
            input_ids=context_input_ids,
            attention_mask=context_attention_mask,
        )
        # Reduce dimension from 768 to 100
        context_embedding = self.linear(context_embedding)

        entity_embedding: torch.Tensor = self.entity_encoder(
            input_ids=desc_input_ids,
            attention_mask=desc_attention_mask,
        )
        entity_embedding = self.linear(entity_embedding)

        out_embedding = torch.Tensor = self.bilinear(context_embedding, entity_embedding)
        return out_embedding



class BertEmbedding(nn.Module):
    def __init__(self, pretrained: str) -> None:

        super(BertEmbedding, self).__init__()
        self.pretrained = pretrained
        self.config = AutoConfig.from_pretrained(self.pretrained)
        self.bert = AutoModel.from_pretrained(self.pretrained, config=self.config)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return output[0][:, 0, :]
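
For context, I build the model roughly like this; in my trainer the model and device actually live on self, and the checkpoint name and out_dim value here are only placeholders:

# placeholder setup, not my exact configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyModel(pretrained_bert_model='bert-base-uncased', out_dim=100)
model = model.to(device)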

In the training loop:

            anchor_embedding = self.model(
                context_input_ids=batch['context_input_ids'].to(self.device),
                context_attention_mask=batch['context_attention_mask'].to(self.device),
                desc_input_ids=batch['anchor_desc_input_ids'].to(self.device),
                desc_attention_mask=batch['anchor_desc_attention_mask'].to(self.device),
            )
            pos_embedding = self.model(
                context_input_ids=batch['context_input_ids'].to(self.device),
                context_attention_mask=batch['context_attention_mask'].to(self.device),
                desc_input_ids=batch['pos_desc_input_ids'].to(self.device),
                desc_attention_mask=batch['pos_desc_attention_mask'].to(self.device),
            )
            neg_embedding = self.model(
                context_input_ids=batch['context_input_ids'].to(self.device),
                context_attention_mask=batch['context_attention_mask'].to(self.device),
                desc_input_ids=batch['neg_desc_input_ids'].to(self.device),
                desc_attention_mask=batch['neg_desc_attention_mask'].to(self.device),
            )


            batch_loss = self.loss_fn(anchor_embedding, pos_embedding, neg_embedding)
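
Here, self.loss_fn is the triplet loss mentioned above; it is essentially the standard PyTorch triplet loss, set up along these lines (the margin value is illustrative):

# created in the trainer's __init__; margin value is illustrative
self.loss_fn = nn.TripletMarginLoss(margin=1.0)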

The first forward pass runs fine and gives me anchor_embedding. But when the second pass tries to produce pos_embedding, it fails with: TypeError: isinstance() arg 2 must be a type or tuple of types. The error is raised at the line output = self.bert(input_ids=input_ids, attention_mask=attention_mask) in BertEmbedding.forward.

Just as a check, I replaced the Bilinear layer with a Linear layer applied to the concatenation of context_embedding and entity_embedding (see the sketch below), and the code runs fine!
I have been breaking my head over this but cannot figure out what the issue is. If anyone has any insights, I would be most grateful.
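
The working variant looked roughly like this (the name combined is just a placeholder, and 100 is the reduced embedding size coming out of self.linear):

# in __init__, replacing self.bilinear
self.combined = nn.Linear(100 + 100, out_dim)

# in forward, replacing the bilinear call
out_embedding = self.combined(torch.cat([context_embedding, entity_embedding], dim=-1))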

Thank you!