RoBERTa fine-tuning on a dataset of short sentences and low cardinality

Hi everyone,
I am trying to fine tune the MLM of ā€œroberta-baseā€ on a famous transcription which is the one of the IAM dataset. This is because I would like to use the LM to post process the transcription generated by the handwritten text recogniser which is a classical CRNN. The cardinality of the dataset is 6482 text lines for training, 976 for validation and 2915 for testing.
Iā€™ve trained the model ā€œroberta-baseā€ for 30 epochs with also early stopping. I used the max_length equal to the mean and median length of the lines of text in the transcription (the mean and median number of words, though I know that RoBERTA works with subwords, but I was seeing many <pad> tokens predicted using a longer max_len. The code follows.

BATCH_SIZE = 8
EPOCHS = 30
PATIENCE = 3
MAX_LENGTH = 9
LR = 1E-05

model_name = "roberta-base"
tokenizer = RobertaTokenizer.from_pretrained(model_name)
model = RobertaForMaskedLM.from_pretrained(model_name)

import torch, random
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaForMaskedLM, RobertaTokenizer, AdamW

class IamDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=MAX_LENGTH, mask_prob : float=0.15):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mask_prob = mask_prob
    
    def mask_tokens(self, inputs):
        # Identify non-special tokens
        non_special_tokens = [i for i, token_id in enumerate(inputs) if token_id != self.tokenizer.pad_token_id]

        # Calculate the number of tokens to mask
        num_to_mask = int(len(non_special_tokens) * self.mask_prob)

        # Randomly select tokens to mask
        tokens_to_mask = random.sample(non_special_tokens, num_to_mask)

        # Replace selected tokens with [MASK]
        for index in tokens_to_mask:
            inputs[index] = self.tokenizer.mask_token_id

        return inputs
    
    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]

        # Tokenize the text
        encoding = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_length,
            padding="max_length"
        )

        # Labels for MLM training (masked tokens replaced by -100)
        labels = encoding["input_ids"].clone()
        # Mask tokens with the specified probability
        encoding["input_ids"] = self.mask_tokens(encoding["input_ids"].squeeze())


        return {
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            "labels": encoding["input_ids"].squeeze(),  # MLM training requires labels
        }


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

train_dataset = IamDataset(train_transcriptions, tokenizer)
val_dataset = IamDataset(val_transcriptions, tokenizer)
# define the dataloaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Define optimizer and loss function
optimizer = AdamW(model.parameters(), lr=LR)

scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=0, 
    num_training_steps=len(train_loader)*EPOCHS
    )

best_val_loss = float('inf')
patience = PATIENCE

early_stopping_counter = 0
# Training loop
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch} (Train)"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    average_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {average_train_loss}")

    # Validation
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        # Validation
        model.eval()
        val_total_loss = 0.0
        with torch.no_grad():
            for batch_val in val_loader:
                input_ids = batch_val["input_ids"].to(device)
                attention_mask = batch_val["attention_mask"].to(device)
                labels = batch_val["labels"].to(device)
                val_outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                val_loss = val_outputs.loss
                val_total_loss += val_loss.item()

    average_val_loss = val_total_loss / len(val_loader)

    print(f"Epoch {epoch + 1}, Train Loss: {average_train_loss}, Validation Loss: {average_val_loss}")


    # Save the model if the validation loss has decreased
    if average_val_loss < best_val_loss:
        best_val_loss = average_val_loss
        # torch.save(model.state_dict(), 'roberta_mlm_model.pth')
        model.save_pretrained("fine_tuned_mlm_roberta")
        tokenizer.save_pretrained("fine_tuned_mlm_roberta")
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1
        if early_stopping_counter >= patience:
            print(f"Early stopping after {epoch + 1} EPOCHS.")
            break

Moreover, then I post process the transcription of the htr model considering the words that are out of a general english vocabulary and the training vocabulary and substitute them which the one with the highest probability for the MLM.

# Load the trained model
model = RobertaForMaskedLM.from_pretrained('fine_tuned_mlm_roberta')
tokenizer = RobertaTokenizer.from_pretrained('fine_tuned_mlm_roberta')

print(tokenizer)
# model.load_state_dict(torch.load('roberta_mlm_model.pth'))
model.eval().to('cuda')

th = 0.0

with open(htr_transcription, 'r') as f:
    pred_transcriptions = f.readlines()

post_processed_lines, test_transcriptions_mod = [], []
i = 0
for l in tqdm(pred_transcriptions):
    test_tr_mod_line = tokenizer.encode(test_transcriptions[i])
    test_tr_mod_line = tokenizer.decode(test_tr_mod_line)
    test_tr_mod_line = test_tr_mod_line.replace('<s>', '').replace('</s>', '')
    print(test_tr_mod_line)
    print(l.strip('\n'))
    test_transcriptions_mod.append(test_tr_mod_line)
    
    proc_tokens = tokenizer.encode(l.strip('\n'), return_tensors='pt')[0]
    tokenization = tokenizer.tokenize(l.strip('\n'))
    masked_list = generate_masked_err_strings(l.strip('\n'), train_set_words)

    for input_text in masked_list:
        tokens = tokenizer.encode(input_text, return_tensors='pt').to('cuda')
        with torch.no_grad():
            outputs = model(tokens)
            predictions = outputs.logits  # The predicted logits for each token
            
        try:
            masked_index = tokens[0].tolist().index(tokenizer.mask_token_id)

            matrix_logits = predictions[0, masked_index]
            probabilities = torch.nn.functional.softmax(predictions, -1)
            matrix_probabilities = probabilities[0, masked_index]

            indixes_top_k_predictions = torch.topk(matrix_logits, k=5)[1]
            for el in indixes_top_k_predictions:
                if el not in [0,2,50264]:
                    index = el
                    p_substitute = matrix_probabilities[index]
                    # print(f'p. of the substitute {p_substitute}')
                    if matrix_probabilities[index] > th:
                        proc_tokens[masked_index] = index
                    break
        except:
            # print('no masked token')
            pass
    proc_line = tokenizer.decode(proc_tokens)
    proc_line = proc_line.replace('<s>', '').replace('</s>', '')
    
    
    print(proc_line)
    post_processed_lines.append(proc_line)
    i += 1

However, I have decreased performance.