@dipanjann I’ve wondered the same thing. Here’s my solution for the case where the output text corresponds to multiple tokens. The basic idea is to use generate()
to get the prediction logits for each step of the generated output sequence, and then roughly calculate the cross-entropy loss for each of the possible output classes.
I’m not sure this is exactly mathematically correct, so I’d love it if someone could show a better way! In my tests it produces reasonable results.
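To make the idea a bit more explicit: writing $z_t$ for the logits at generation step $t$ and $y_{c,1}, \dots, y_{c,T}$ for the token IDs of class label $c$ (my notation, not anything from the library), the code below roughly computes

$$s_c = \frac{1}{T} \sum_{t=1}^{T} z_t[y_{c,t}], \qquad p_c = \operatorname{softmax}(s)_c,$$

i.e. the average logit of each label’s tokens (ignoring padding positions), followed by a softmax across the labels.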
import typing as T
import pytorch_lightning
import torch
from more_itertools import chunked
class MyLitModule(pytorch_lightning.LightningModule):
    ...

    def predict_proba(self, text: T.Iterable[str], labels: T.Iterable[str]):
        """Predict the class probabilities"""
        # Compute the tokens corresponding to the text labels:
        class_ids = torch.LongTensor(self.tokenizer(list(labels), padding=True).input_ids)
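        # `class_ids` has size [num_labels, label_seq_length]; shorter labels are padded.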
        logits = []
        for chunk in chunked(text, 16):
            # Tokenize the input text:
            encoding = self.tokenizer(
                list(chunk),
                max_length=self.hparams.model_max_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
            )
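            # Force the generated sequence to have exactly `class_ids.shape[1]` scored steps;
            # the +1 accounts for the decoder start token (assuming an encoder-decoder model like T5).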
            output_sequences = self.model.generate(
                input_ids=encoding.input_ids.to(self.device),
                attention_mask=encoding.attention_mask.to(self.device),
                do_sample=False,
                return_dict_in_generate=True,
                output_scores=True,
                min_length=class_ids.shape[1] + 1,
                max_length=class_ids.shape[1] + 1,
            )
            # Generate the logits for each token in the generated output sequence.
            # `scores` has size [batch, seq_length, vocab_size]
            scores = torch.stack(output_sequences.scores, dim=1).to("cpu")
            # We don't care about the logits of special tokens:
            scores[:, :, self.tokenizer.all_special_ids] = torch.nan
            # Index the logits in `scores` based on the class token IDs.
            # For example, if class_ids[0, :] is [10, 30], then the prediction logits
            # are scores[:, 0, 10] and scores[:, 1, 30].
            # Finally, we average the logits, which is similar to how the cross entropy loss is calculated.
            logits.append(scores.gather(dim=2, index=class_ids.T.expand(len(chunk), -1, -1)))
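        # Padding positions in `class_ids` index into special-token columns, which were set to
        # NaN above, so `nanmean` effectively skips them for labels with fewer tokens.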
        return torch.concat(logits, dim=0).nanmean(dim=1).softmax(1)
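As a rough usage sketch (the checkpoint path, texts, and label strings here are placeholders, and I’m assuming the module’s __init__ sets up self.model and self.tokenizer so it can be restored with load_from_checkpoint):

# Hypothetical usage -- checkpoint path, texts, and labels are made up:
lit_module = MyLitModule.load_from_checkpoint("path/to/checkpoint.ckpt")
lit_module.eval()
with torch.no_grad():
    probs = lit_module.predict_proba(
        text=["the movie was great", "the movie was terrible"],
        labels=["positive", "negative"],
    )
# `probs` has size [num_texts, num_labels] and each row sums to 1.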