Finetune T5 with T5ForConditionalGeneration to multitask for Q&A and Summarization

Hi everyone, my end goal is to have a fine-tuned T5 model that can perform Q&A as well as summarization. I can train each of these tasks independently using the various AutoModels (e.g. AutoModelForQuestionAnswering), but when I train the model using T5ForConditionalGeneration I don’t think I am formatting the Q&A inputs correctly in the pre-process function.

Question 1: is T5ForConditionalGeneration appropriate for Q/A, keeping in mind that I also need to support summarization?

Question 2: is there a barebones example / explanation of how you would format the model inputs / labels when using T5ForConditionalGeneration for Q&A? The summarizer works when I train on the combined datasets, but the Q/A gives terrible / random results. I am sure it’s how I am tokenizing the Q&A dataset, so if anyone has an example using T5ForConditionalGeneration I would appreciate it.

Here is my preprocess function for the Q/A dataset:

def encode_qa(example,
              encoder_max_len=max_input_length, decoder_max_len=max_target_length):
  
    context = example['context']
    question = example['question']
    answer = example['answers']['text']
  
    question_plus = f"{str(question)}"
    question_plus += f" context: {str(context)} </s>"
    
    answer_plus = ', '.join(answer)
    answer_plus = f"{answer_plus} </s>"
    
    encoder_inputs = tokenizer(question_plus, truncation=True,
                               return_tensors='pt', max_length=encoder_max_len,
                               padding='max_length')

    decoder_inputs = tokenizer(answer_plus, truncation=True,
                               return_tensors='pt', max_length=decoder_max_len,
                               padding='max_length')
    
    input_ids = encoder_inputs['input_ids'][0]
    input_attention = encoder_inputs['attention_mask'][0]
    target_ids = decoder_inputs['input_ids'][0]
    target_attention = decoder_inputs['attention_mask'][0]
    
    outputs = {'input_ids':input_ids, 'attention_mask': input_attention, 
               'labels':target_ids, 'decoder_attention_mask':target_attention}
    return outputs
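
For comparison, here is a minimal sketch of a variant of this preprocessing that may be worth trying. It is not a confirmed fix. It rests on three assumptions: that the input uses the "question: ... context: ..." text format from the original T5 SQuAD setup, that a single reference answer is a better target than all answers joined with commas, and that padding positions in the labels should be set to -100 so the loss ignores them (T5ForConditionalGeneration computes its loss with ignore_index=-100). The helper names build_qa_text, mask_padding and encode_qa_sketch are hypothetical, and the default lengths are placeholders. The tokenizer is passed in as an argument, and the manual "</s>" is dropped because recent versions of the T5 tokenizer append the EOS token themselves.

```python
def build_qa_text(example):
    """Format one SQuAD-style example as T5 source and target strings."""
    source = f"question: {example['question']} context: {example['context']}"
    answers = example['answers']['text']
    # Assumption: use the first reference answer as the target
    # instead of joining every answer with ", ".
    target = answers[0] if answers else ""
    return source, target


def mask_padding(label_ids, pad_token_id, ignore_index=-100):
    """Replace padding ids so the cross-entropy loss skips those positions."""
    return [tok if tok != pad_token_id else ignore_index for tok in label_ids]


def encode_qa_sketch(example, tokenizer, encoder_max_len=512, decoder_max_len=32):
    """Tokenize source and target; the tokenizer is expected to add EOS itself."""
    source, target = build_qa_text(example)

    model_inputs = tokenizer(source, truncation=True,
                             padding='max_length', max_length=encoder_max_len)
    targets = tokenizer(target, truncation=True,
                        padding='max_length', max_length=decoder_max_len)

    # Labels are the target ids with padding masked out.
    model_inputs['labels'] = mask_padding(targets['input_ids'],
                                          tokenizer.pad_token_id)
    return model_inputs
```

With a real tokenizer this could be applied through something like dataset.map(lambda ex: encode_qa_sketch(ex, tokenizer)). For the multitask setup, the summarization examples would keep their own prefix (for example "summarize: "), so that the model can tell the two tasks apart from the input text alone.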