Pegasus tokenizer for batch processing

Hi everyone, I’m trying to play with some transformers to learn the basics of this amazing world, but in the last few days I got stuck with the Pegasus model. I’m trying to summarize the text feature of a dataset so the summary is short enough for the BERT tokenizer, using Pegasus as the summarization model. When I run the code with map and batched = False everything works fine, but if I switch to batched = True I get:

TypeError: TextEncodeInput must be Union[TextInputSequence, Tuple[InputSequence, InputSequence]]

I already:

  • checked for None values and removed them
  • checked that I was passing a list of strings as the tokenizer input
  • tried to generate a list of strings from batch_samples and pass it as input
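
For context, this is roughly the shape of what map hands to the mapped function in the two modes (a minimal sketch with made-up values; the column names follow my dataset):

# batched=False -> one example at a time, "text" is a single string
example = {"text": "some case text ...", "label": "cited", "lb_num_token": 700}

# batched=True -> a slice of the dataset, "text" is a list of strings
batch = {"text": ["case text 1 ...", "case text 2 ..."],
         "label": ["cited", "applied"],
         "lb_num_token": [700, 650]}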

This is my code, and you can find the dataset here:

from datasets import Dataset, DatasetDict, concatenate_datasets
import pandas as pd
from transformers import AutoModel, AutoTokenizer, AutoModelForSeq2SeqLM
import torch

def preprocess_data(df):
    # adding a column with the BERT token length of each sample
    df["lb_num_token"] = d_len
    
    # Dropping NaN values
    df = df.dropna(subset=['case_text'])

    # Dropping unused features and renaming columns
    df = df.drop(columns =['case_id', 'case_title'])
    df.rename(columns={"case_text":"text", "case_outcome":"label"}, inplace= True)

    # Get the number of unique labels
    labels_list = df["label"].unique().tolist()
    
    # Splitting Dataset
    df = Dataset.from_pandas(df)
    df = df.map(lambda example: {'text': str(example['text'])})
    train_valid = df.train_test_split(test_size= 0.2, seed= 42)
    valid_test  = train_valid["test"].train_test_split(test_size= 0.5, seed= 42)
    
    df_split = DatasetDict({
    'train': train_valid['train'],
    'valid': valid_test['train'],
    'test': valid_test['test']
    })
    
    return df_split, labels_list

#Loading Dataset
df = pd.read_csv("./datasets/legal_text_classification.csv")

# number of BERT tokens for each sample
model_ckpt = "nlpaueb/legal-bert-small-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)  
d_len = [len(tokenizer.encode(str(s))) for s in df["case_text"]]

# preprocessing dataset
df, labels_list = preprocess_data(df)

train = Dataset.from_dict(df["train"][0:9])

def pegasus_summary(batch_samples, model, tokenizer):
    # This function takes a batch of samples as input and returns a summary for each sample.
    # The summary length is capped at 400 tokens, because the output summary will be used as BERT tokenizer input.
    # LLM used: legal-pegasus
    # It is better to call this function with the model and tokenizer already defined in the main code.

    summary = ""
    # summary
    input_tokenized = tokenizer.encode(batch_samples["text"], return_tensors='pt', max_length=1024, truncation=True).to(device)
    with torch.no_grad():
        summary_ids = model.generate(input_tokenized,
                                     num_beams=9,
                                     no_repeat_ngram_size=3,
                                     length_penalty=2.0,
                                     min_length=150,
                                     max_length=400,
                                     early_stopping=True)

    summary = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids][0]
    return {"text": summary}

def summarizing_samples(df):
    model_ckpt_sum = "nsi319/legal-pegasus"
    tokenizer_sum = AutoTokenizer.from_pretrained(model_ckpt_sum)
    model_sum = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt_sum).to(device)
    
    df_long = df.filter(lambda example: example["lb_num_token"] > 512)
    df_short= df.filter(lambda example: example["lb_num_token"] <= 512)

    df_long = df_long.map(lambda example: pegasus_summary(example, model_sum, tokenizer_sum), batched = True)
                                                                                          
    df = concatenate_datasets([df_long, df_short])
    return df

device = "cuda" if torch.cuda.is_available() else "cpu"
train = summarizing_samples(train)

for it in train["text"]:
    print(it, "\n\n\n")

Thanks a lot for your time, and I hope my English is understandable.

After a couple of days of research and documentation reading I discovered my mistake. In this code I was using tokenizer.encode() to tokenize the inputs, but that function only works with one sample at a time. I simply updated the function above to use tokenizer() as the encoding function and everything works now.
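
For anyone hitting the same error, this is roughly how the updated function looks now (a sketch: padding the batch to the longest text and decoding with batch_decode are my choices, the generation arguments are unchanged):

def pegasus_summary(batch_samples, model, tokenizer):
    # tokenizer() accepts a list of strings, so it works with batched=True
    inputs = tokenizer(batch_samples["text"],
                       return_tensors='pt',
                       max_length=1024,
                       padding=True,        # pad to the longest text in the batch
                       truncation=True).to(device)
    with torch.no_grad():
        summary_ids = model.generate(inputs["input_ids"],
                                     attention_mask=inputs["attention_mask"],
                                     num_beams=9,
                                     no_repeat_ngram_size=3,
                                     length_penalty=2.0,
                                     min_length=150,
                                     max_length=400,
                                     early_stopping=True)
    # one decoded summary per sample in the batch
    summaries = tokenizer.batch_decode(summary_ids,
                                       skip_special_tokens=True,
                                       clean_up_tokenization_spaces=False)
    return {"text": summaries}

With batched = True, map expects the returned dict values to be lists with one entry per sample, which is why the function now returns the whole list instead of summaries[0].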