The code below is taken from the tutorial
from datasets import load_metric
import torch

metric = load_metric("glue", "mrpc")
model.eval()
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()
Inside the for batch in eval_dataloader: loop, how can I know which indices from the dataset this batch includes?
The DataLoader is created earlier with:

eval_dataloader = DataLoader(
    tokenized_datasets["validation"], batch_size=8, collate_fn=data_collator
)
Note that it is created without the shuffling flag, so it is possible to count manually using the batch size (see the first sketch below), but how can I do it with shuffling? Is it possible to make the index a field of the batch when creating the dataset and dataloader? The second sketch below is roughly what I have in mind.
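For the unshuffled case, this is the manual counting I mean (a rough sketch; it assumes the batch size of 8 from above and that the order of the dataset is preserved):

for batch_idx, batch in enumerate(eval_dataloader):
    # With shuffle off, batches come in dataset order, so the rows covered
    # by this batch can be reconstructed from a running count.
    start = batch_idx * 8  # same batch_size as in the DataLoader above
    indices = list(range(start, start + len(batch["labels"])))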
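And for the second question, this is roughly what I mean by making the index a field of the batch (untested sketch; the sample_idx name and the collate wrapper are my own, and I'm assuming data_collator is the DataCollatorWithPadding from the tutorial):

import torch
from torch.utils.data import DataLoader

# Attach each example's position in the dataset as an extra column.
val_ds = tokenized_datasets["validation"]
val_ds = val_ds.add_column("sample_idx", list(range(len(val_ds))))

def collate_with_idx(features):
    # Pull the index out before the usual collator pads the rest of the batch.
    idxs = torch.as_tensor([int(f.pop("sample_idx")) for f in features])
    batch = data_collator(features)
    batch["sample_idx"] = idxs  # would need to be popped again before model(**batch)
    return batch

eval_dataloader = DataLoader(
    val_ds, batch_size=8, shuffle=True, collate_fn=collate_with_idx
)

Would something like this work, or is there a built-in way to get the original indices per batch?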