Suppose that we want to modify some architecture in the transformers library. Where should we start?
Are there any tutorials or notebook examples?
For instance, I wanted to modify an encoder-decoder model and use only the decoder part. What should be taken into account in terms of input/output data pre- and post-processing?
I tried subclassing a model such as BART or T5 and keeping only the decoder part, but that didn't work:
```python
class T5DecoderOnly(T5ForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        self.model_dim = config.d_model
        self.base_model = self.encoder

model = T5DecoderOnly(config)  # config from a pretrained T5
```
When I print the model with print(model), it still shows both the encoder and the decoder in the architecture.
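I suspect subclassing alone removes nothing, since super().__init__(config) still builds the full model; the unwanted submodule would presumably have to be deleted explicitly. A minimal, untested sketch of what I mean (keeping the decoder, not the encoder):

```python
class T5DecoderOnly(T5ForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)   # builds encoder, decoder and lm_head
        self.model_dim = config.d_model
        del self.encoder           # drop the encoder submodule entirely
```

Although I assume the inherited forward and generate would then fail, since they expect self.encoder to exist, so a custom forward would still be needed.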
I also tried the following:

```python
model = T5ForConditionalGeneration(config)
model = torch.nn.Sequential(model.decoder, model.lm_head)
```
But this again throws an error, because nn.Sequential's forward does not recognise the keyword arguments input_ids and labels, so calling model(input_ids=input_ids, labels=label_ids) does not work.
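If I understand nn.Sequential correctly, its forward takes a single positional input and threads it through the modules, so keyword arguments can never reach the decoder, and nothing in the chain knows what to do with labels. A tiny illustration of the routing problem (hypothetical, untested):

```python
import torch
import torch.nn as nn

seq = nn.Sequential(nn.Linear(4, 4))
x = torch.randn(2, 4)
seq(x)  # fine: Sequential.forward takes one positional input
# seq(input_ids=x)  # TypeError: forward() got an unexpected keyword argument 'input_ids'
```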
I then tried concatenating the inputs and labels to see what happens, model(torch.cat([input_ids, label_ids], dim=1)), and this time I get:

TypeError: linear(): argument 'input' (position 1) must be Tensor, not BaseModelOutputWithPastAndCrossAttentions
Verbose error output:
```
TypeError Traceback (most recent call last)
Cell In[148], line 77
73 label_ids = tokenizer.batch_encode_plus(labels, padding=True, return_tensors="pt", truncation=True)["input_ids"].to(device)
75 # Forward pass
76 # outputs = model(input_ids=input_ids, labels=label_ids)
---> 77 outputs = model(torch.cat([input_ids, label_ids], dim=1)) # decoder maybe doesn't take labels arguments and need to be concat with inputs
78 loss = outputs.loss
80 # Backward pass
File /idiap/temp/imitro/miniconda3/envs/length-gen/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File /idiap/temp/imitro/miniconda3/envs/length-gen/lib/python3.12/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File /idiap/temp/imitro/miniconda3/envs/length-gen/lib/python3.12/site-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input)
215 def forward(self, input):
216 for module in self:
--> 217 input = module(input)
218 return input
File /idiap/temp/imitro/miniconda3/envs/length-gen/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File /idiap/temp/imitro/miniconda3/envs/length-gen/lib/python3.12/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File /idiap/temp/imitro/miniconda3/envs/length-gen/lib/python3.12/site-packages/torch/nn/modules/linear.py:116, in Linear.forward(self, input)
115 def forward(self, input: Tensor) -> Tensor:
--> 116 return F.linear(input, self.weight, self.bias)
```
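My current guess: nn.Sequential feeds the decoder's output object, a BaseModelOutputWithPastAndCrossAttentions, straight into the lm_head Linear, which expects a Tensor. So presumably I need a small wrapper that pulls out last_hidden_state and handles the labels itself. A rough, untested sketch of what I have in mind (T5DecoderLM is a name I made up, and I'm assuming the T5 decoder stack runs standalone, i.e. skips cross-attention when no encoder_hidden_states are passed):

```python
import torch
import torch.nn as nn
from transformers import T5ForConditionalGeneration

class T5DecoderLM(nn.Module):
    """Hypothetical wrapper: T5 decoder stack + lm_head with a callable API."""

    def __init__(self, decoder, lm_head, d_model):
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.d_model = d_model

    def forward(self, input_ids, labels=None):
        # The decoder stack returns an output object, not a tensor,
        # so extract the hidden states before applying the LM head.
        hidden = self.decoder(input_ids=input_ids).last_hidden_state
        # T5ForConditionalGeneration rescales the hidden states like this
        # when the input/output embeddings are tied.
        hidden = hidden * (self.d_model ** -0.5)
        logits = self.lm_head(hidden)
        loss = None
        if labels is not None:
            # Standard causal-LM shift: position t predicts token t + 1.
            # (Padding tokens should additionally be masked out in practice.)
            loss = nn.functional.cross_entropy(
                logits[:, :-1, :].reshape(-1, logits.size(-1)),
                labels[:, 1:].reshape(-1),
            )
        return logits, loss

base = T5ForConditionalGeneration(config)  # config as in my snippet above
model = T5DecoderLM(base.decoder, base.lm_head, config.d_model)

# Reusing the concatenation idea: the sequence is both input and target.
concat = torch.cat([input_ids, label_ids], dim=1)
logits, loss = model(input_ids=concat, labels=concat)
```

But even if this runs, I'm not sure a decoder pretrained with cross-attention behaves sensibly as a standalone language model, which is part of what I'm asking.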
Any ideas on how to resolve this?