Get batch indices when iterating a DataLoader over a Dataset

The code below is taken from the tutorial:

import torch
from datasets import load_metric

metric = load_metric("glue", "mrpc")
model.eval()
for batch in eval_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model(**batch)
    
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()

Inside the loop (for batch in eval_dataloader), how can I tell which indices of the dataset the current batch contains?

The DataLoader is created earlier with:

eval_dataloader = DataLoader(
    tokenized_datasets["validation"], batch_size=8, collate_fn=data_collator
)

Note that shuffle is not set here, so the indices could be counted manually from the batch size. But how can this be done when shuffling is enabled? Is it possible to make the index a field of the batch when creating the dataset and dataloader?

Hi! I think it’s possible to get the indices used by the sampler (see, for example, “Indices of a dataset sampled by DataLoader” on the PyTorch Forums).