Hi everyone,
I am trying to fine tune the MLM of “roberta-base” on a famous transcription which is the one of the IAM dataset. This is because I would like to use the LM to post process the transcription generated by the handwritten text recogniser which is a classical CRNN. The cardinality of the dataset is 6482 text lines for training, 976 for validation and 2915 for testing.
I’ve trained the model “roberta-base” for 30 epochs with also early stopping. I used the max_length
equal to the mean and median length of the lines of text in the transcription (the mean and median number of words, though I know that RoBERTA works with subwords, but I was seeing many <pad>
tokens predicted using a longer max_len
. The code follows.
BATCH_SIZE = 8
EPOCHS = 30
PATIENCE = 3
MAX_LENGTH = 9
LR = 1E-05
model_name = "roberta-base"
tokenizer = RobertaTokenizer.from_pretrained(model_name)
model = RobertaForMaskedLM.from_pretrained(model_name)
import torch, random
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaForMaskedLM, RobertaTokenizer, AdamW
class IamDataset(Dataset):
def __init__(self, texts, tokenizer, max_length=MAX_LENGTH, mask_prob : float=0.15):
self.texts = texts
self.tokenizer = tokenizer
self.max_length = max_length
self.mask_prob = mask_prob
def mask_tokens(self, inputs):
# Identify non-special tokens
non_special_tokens = [i for i, token_id in enumerate(inputs) if token_id != self.tokenizer.pad_token_id]
# Calculate the number of tokens to mask
num_to_mask = int(len(non_special_tokens) * self.mask_prob)
# Randomly select tokens to mask
tokens_to_mask = random.sample(non_special_tokens, num_to_mask)
# Replace selected tokens with [MASK]
for index in tokens_to_mask:
inputs[index] = self.tokenizer.mask_token_id
return inputs
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = self.texts[idx]
# Tokenize the text
encoding = self.tokenizer(
text,
return_tensors="pt",
truncation=True,
max_length=self.max_length,
padding="max_length"
)
# Labels for MLM training (masked tokens replaced by -100)
labels = encoding["input_ids"].clone()
# Mask tokens with the specified probability
encoding["input_ids"] = self.mask_tokens(encoding["input_ids"].squeeze())
return {
"input_ids": encoding["input_ids"].squeeze(),
"attention_mask": encoding["attention_mask"].squeeze(),
"labels": encoding["input_ids"].squeeze(), # MLM training requires labels
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
train_dataset = IamDataset(train_transcriptions, tokenizer)
val_dataset = IamDataset(val_transcriptions, tokenizer)
# define the dataloaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
# Define optimizer and loss function
optimizer = AdamW(model.parameters(), lr=LR)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=0,
num_training_steps=len(train_loader)*EPOCHS
)
best_val_loss = float('inf')
patience = PATIENCE
early_stopping_counter = 0
# Training loop
for epoch in range(EPOCHS):
model.train()
total_loss = 0.0
for batch in tqdm(train_loader, desc=f"Epoch {epoch} (Train)"):
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)
optimizer.zero_grad()
outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()
total_loss += loss.item()
average_train_loss = total_loss / len(train_loader)
print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {average_train_loss}")
# Validation
model.eval()
total_val_loss = 0
with torch.no_grad():
# Validation
model.eval()
val_total_loss = 0.0
with torch.no_grad():
for batch_val in val_loader:
input_ids = batch_val["input_ids"].to(device)
attention_mask = batch_val["attention_mask"].to(device)
labels = batch_val["labels"].to(device)
val_outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
val_loss = val_outputs.loss
val_total_loss += val_loss.item()
average_val_loss = val_total_loss / len(val_loader)
print(f"Epoch {epoch + 1}, Train Loss: {average_train_loss}, Validation Loss: {average_val_loss}")
# Save the model if the validation loss has decreased
if average_val_loss < best_val_loss:
best_val_loss = average_val_loss
# torch.save(model.state_dict(), 'roberta_mlm_model.pth')
model.save_pretrained("fine_tuned_mlm_roberta")
tokenizer.save_pretrained("fine_tuned_mlm_roberta")
early_stopping_counter = 0
else:
early_stopping_counter += 1
if early_stopping_counter >= patience:
print(f"Early stopping after {epoch + 1} EPOCHS.")
break
Moreover, then I post process the transcription of the htr model considering the words that are out of a general english vocabulary and the training vocabulary and substitute them which the one with the highest probability for the MLM.
# Load the trained model
model = RobertaForMaskedLM.from_pretrained('fine_tuned_mlm_roberta')
tokenizer = RobertaTokenizer.from_pretrained('fine_tuned_mlm_roberta')
print(tokenizer)
# model.load_state_dict(torch.load('roberta_mlm_model.pth'))
model.eval().to('cuda')
th = 0.0
with open(htr_transcription, 'r') as f:
pred_transcriptions = f.readlines()
post_processed_lines, test_transcriptions_mod = [], []
i = 0
for l in tqdm(pred_transcriptions):
test_tr_mod_line = tokenizer.encode(test_transcriptions[i])
test_tr_mod_line = tokenizer.decode(test_tr_mod_line)
test_tr_mod_line = test_tr_mod_line.replace('<s>', '').replace('</s>', '')
print(test_tr_mod_line)
print(l.strip('\n'))
test_transcriptions_mod.append(test_tr_mod_line)
proc_tokens = tokenizer.encode(l.strip('\n'), return_tensors='pt')[0]
tokenization = tokenizer.tokenize(l.strip('\n'))
masked_list = generate_masked_err_strings(l.strip('\n'), train_set_words)
for input_text in masked_list:
tokens = tokenizer.encode(input_text, return_tensors='pt').to('cuda')
with torch.no_grad():
outputs = model(tokens)
predictions = outputs.logits # The predicted logits for each token
try:
masked_index = tokens[0].tolist().index(tokenizer.mask_token_id)
matrix_logits = predictions[0, masked_index]
probabilities = torch.nn.functional.softmax(predictions, -1)
matrix_probabilities = probabilities[0, masked_index]
indixes_top_k_predictions = torch.topk(matrix_logits, k=5)[1]
for el in indixes_top_k_predictions:
if el not in [0,2,50264]:
index = el
p_substitute = matrix_probabilities[index]
# print(f'p. of the substitute {p_substitute}')
if matrix_probabilities[index] > th:
proc_tokens[masked_index] = index
break
except:
# print('no masked token')
pass
proc_line = tokenizer.decode(proc_tokens)
proc_line = proc_line.replace('<s>', '').replace('</s>', '')
print(proc_line)
post_processed_lines.append(proc_line)
i += 1
However, I have decreased performance.