Hi,
It’s my first time training a model. I fine-tuned a T5 model to create sentences that contain a specified word. The initial training went well: the model generates coherent sentences, and about 75% of the generated sentences include the target word.
Then, I thought I could modify the loss function to add a penalty for not including the target word. This way, I could improve the precision of the target word usage.
I came up with the following code and integrated it into a custom trainer. The code is supposed to evaluate the probabilities of the target sequence (the target word may consist of more than one token) in the predicted logits and then incorporate it into the base loss:
def custom_loss(model, inputs, target_word_ids, alpha=5.0):
# 1. Forward pass
outputs = model(**inputs)
logits = outputs.logits
labels = inputs["labels"]
# 2. Loss base de cross entropy
loss_fct = CrossEntropyLoss(ignore_index=model.config.pad_token_id)
base_loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
# 3. Calcular probabilidades
probs = torch.softmax(logits, dim=-1) # [batch_size, seq_len, vocab_size]
batch_size = logits.size(0)
target_probs = torch.zeros(batch_size, device=logits.device)
# 4. Calcular penalización para cada secuencia en el batch
for i in range(batch_size):
target_seq = target_word_ids[i][target_word_ids[i] != 0] # Eliminar padding
if len(target_seq) > 0:
# Tomar los primeros N tokens
n_tokens = min(2, len(target_seq))
tokens_to_check = target_seq[:n_tokens]
# Calcular probabilidad para cada token
token_probs = []
for token in tokens_to_check:
# Obtener probabilidades para el token en todas las posiciones
probs_for_token = probs[i, :, token]
""" Prob tiene la probabilidad suave de que la secuencia de tokens aparezca en toda la frase, le da más
importancia a los valores altos.
Se podría coger solo la probabilidad de la posición más probable, pero entonces los gradientes solo
se propagarían hacia los pesos que generaron esa posición de la secuencia, con esto se supone que se
propaga más "suave" (caliente/frío) hacia todos los pesos
"""
prob = torch.logsumexp(probs_for_token * 5.0, dim=0) / 5.0
token_probs.append(prob)
# Promedio de las probabilidades de los tokens
target_probs[i] = sum(token_probs) / len(token_probs)
# 5. Calcular penalización final, usar log para que una probabilidad alta penalice poco
penalty = -torch.log(target_probs + 1e-10).mean()
# 6. Combinar losses
total_loss = base_loss + alpha * penalty
outputs.loss = total_loss
return total_loss, outputs
class CustomTrainer(Seq2SeqTrainer):
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
if "target_word" in inputs: # Para training y validation loss
target_word = inputs.pop("target_word")
loss, outputs = custom_loss(model, inputs, target_word, alpha=ALPHA_LOSS)
return (loss, outputs) if return_outputs else loss
else: # Para generación
outputs = model(**inputs)
loss = outputs.loss
return (loss, outputs) if return_outputs else loss
def prediction_step(
self,
model,
inputs,
prediction_loss_only,
ignore_keys=None,
):
inputs_copy = {}
expected_keys = ["input_ids", "attention_mask"]
for key in expected_keys:
if key in inputs:
inputs_copy[key] = inputs[key]
return super().prediction_step(
model,
inputs_copy,
prediction_loss_only,
ignore_keys=ignore_keys,
)
I trained a model using different combinations of Alpha, but it doesn’t seem to improve the target word accuracy.
Does anyone know what might be wrong? Could it be that the gradients are not propagating correctly? Or maybe the loss function doesn’t make sense at all?
Any help would be appreciated.