Customizing model architecture from predefined models

Suppose that we want to modify some architecture in the transformers library. Where should we start? Are there any tutorials or notebook examples?

For instance, I wanted to modify an encoder-decoder model and use only the decoder part.
What should be taken into account in terms of input/output data pre- and post-processing?

I tried subclassing a model such as BART or T5 and then using only the encoder part, but that didn't work.

class T5DecoderOnly(T5ForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        self.model_dim = config.d_model
        self.base_model = self.encoder

model = T5DecoderOnly(config)  # config from a pretrained T5

When I print the model with print(model), it still shows both the encoder and the decoder in the architecture.
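I think I see why (a minimal pure-PyTorch sketch with hypothetical Parent/EncoderOnly classes, not the real T5 code): super().__init__(config) has already built and registered both sub-modules, and assigning self.base_model = self.encoder only adds an alias; it does not remove self.decoder. Explicitly deleting the unused sub-module does remove it:

```python
import torch

class Parent(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # stand-ins for the encoder/decoder stacks built by the parent __init__
        self.encoder = torch.nn.Linear(4, 4)
        self.decoder = torch.nn.Linear(4, 4)

class EncoderOnly(Parent):
    def __init__(self):
        super().__init__()
        self.base_model = self.encoder  # alias only; self.decoder is still registered

class EncoderOnlyFixed(Parent):
    def __init__(self):
        super().__init__()
        del self.decoder  # nn.Module.__delattr__ removes it from _modules

m = EncoderOnly()
print(hasattr(m, "decoder"))                  # True  - decoder still part of the module
print(hasattr(EncoderOnlyFixed(), "decoder")) # False - decoder actually gone
```

So in the T5 subclass one would presumably need del self.decoder (or similar) rather than just an attribute alias, though I'm not sure what else in T5ForConditionalGeneration assumes the decoder exists.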

I also tried the following:

model = T5ForConditionalGeneration(config)
model = torch.nn.Sequential(*[model.decoder, model.lm_head])

But this time it throws an error because the decoder's forward does not recognise the arguments input_ids and labels, so calling model(input_ids=input_ids, labels=label_ids) does not work.
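Part of the problem seems to be how nn.Sequential works: it pipes each module's return value positionally into the next, and the HF decoder returns a ModelOutput object rather than a raw tensor, so the Linear lm_head receives an object it cannot handle. A minimal sketch of both the failure and a wrapper that unwraps last_hidden_state (FakeDecoder/FakeDecoderOutput are hypothetical stand-ins for the real T5 stack; with real T5 one would also need attention masks and, depending on the config, the d_model**-0.5 rescaling before the tied lm_head):

```python
import torch
from dataclasses import dataclass

@dataclass
class FakeDecoderOutput:
    # stands in for BaseModelOutputWithPastAndCrossAttentions
    last_hidden_state: torch.Tensor

class FakeDecoder(torch.nn.Module):
    def forward(self, input_ids):
        batch, seq = input_ids.shape
        return FakeDecoderOutput(last_hidden_state=torch.randn(batch, seq, 8))

lm_head = torch.nn.Linear(8, 100)  # hidden size 8, vocab 100 for the sketch
ids = torch.zeros(2, 5, dtype=torch.long)

# nn.Sequential feeds the output object straight into the Linear -> TypeError
seq = torch.nn.Sequential(FakeDecoder(), lm_head)
try:
    seq(ids)
except TypeError:
    print("Sequential fails: Linear received an output object, not a Tensor")

# a small wrapper that unwraps last_hidden_state before the lm_head
class DecoderWithHead(torch.nn.Module):
    def __init__(self, decoder, lm_head):
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head

    def forward(self, input_ids):
        hidden = self.decoder(input_ids).last_hidden_state
        return self.lm_head(hidden)

logits = DecoderWithHead(FakeDecoder(), lm_head)(ids)
print(logits.shape)  # torch.Size([2, 5, 100])
```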

I then tried concatenating the inputs and labels to see what happens: model(torch.cat([input_ids, label_ids], dim=1))

This time I get:
TypeError: linear(): argument 'input' (position 1) must be Tensor, not BaseModelOutputWithPastAndCrossAttentions

Verbose error output:

TypeError                                 Traceback (most recent call last)
Cell In[148], line 77
     73 label_ids = tokenizer.batch_encode_plus(labels, padding=True, return_tensors="pt", truncation=True)["input_ids"].to(device)
     75 # Forward pass
     76 # outputs = model(input_ids=input_ids, labels=label_ids)
---> 77 outputs = model(torch.cat([input_ids, label_ids], dim=1)) # decoder maybe doesn't take labels arguments and need to be concat with inputs
     78 loss = outputs.loss
     80 # Backward pass

File /idiap/temp/imitro/miniconda3/envs/length-gen/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /idiap/temp/imitro/miniconda3/envs/length-gen/lib/python3.12/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /idiap/temp/imitro/miniconda3/envs/length-gen/lib/python3.12/site-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input)
    215 def forward(self, input):
    216     for module in self:
--> 217         input = module(input)
    218     return input

File /idiap/temp/imitro/miniconda3/envs/length-gen/lib/python3.12/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File /idiap/temp/imitro/miniconda3/envs/length-gen/lib/python3.12/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File /idiap/temp/imitro/miniconda3/envs/length-gen/lib/python3.12/site-packages/torch/nn/modules/linear.py:116, in Linear.forward(self, input)
    115 def forward(self, input: Tensor) -> Tensor:
--> 116     return F.linear(input, self.weight, self.bias)

Any ideas on how to resolve this?