Tokenization on dataset produces invalid pytorch tensor shape

I’m attempting to tokenize text in a datasets object and then perform model inference on the tokenized data. I’m using code very similar to this example in the datasets documentation.

The issue I’m running into is that the PyTorch tensors for the tokens end up in the wrong shape. I have a reproducible example below where I compare tokenizing a single string with tokenizing the dataset via map(). These approaches produce tensor shapes of torch.Size([1, 23]) and torch.Size([23]) respectively.

How can I adjust the dataset processing so that I get the expected tensor shape (torch.Size([1, 23]))?

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset

model_arch="distilbert-base-uncased"

# import raw text data
raw_data = load_dataset("glue","cola",split='train').shuffle(seed=42).select(range(10))
text_field='sentence'

model=AutoModelForSequenceClassification.from_pretrained(model_arch)
tokenizer=AutoTokenizer.from_pretrained(model_arch)

# tokenize a single row
adhoc_tokenized_data=tokenizer(raw_data[0][text_field],return_tensors='pt')

# tokenize the entire raw text dataset
tokenized_data = raw_data.map(
    lambda x: tokenizer(x[text_field]),
    remove_columns=[text_field, 'idx']
)
tokenized_data=tokenized_data.rename_column('label', 'labels') # forward() expects 'labels'
tokenized_data.set_format('torch')
print(tokenized_data.format) 

# look at shape of tensors
print(adhoc_tokenized_data['input_ids'].shape)  #   torch.Size([1, 23])
print(tokenized_data[0]['input_ids'].shape)     #   torch.Size([23])

# Attempt to run inference on a single row of data --------------------------------------------

# inference works on the single row of tokenized data
print(model(**adhoc_tokenized_data))

# inference fails on the first row of the tokenized dataset object
# Error from forward(): IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
print(model(**tokenized_data[0]))

A model expects a batch as input, but you are passing tokenized_data[0], which is a single row.

Instead, you can pass a batch of one or several rows with tokenized_data[:1] or tokenized_data[:n_rows]:

print(model(**tokenized_data[:1]))
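
Note that tokenized_data[:n_rows] only stacks cleanly into a single tensor if the selected rows happen to have the same token length; in general the variable-length sequences need to be padded first. Here is a minimal sketch of batched inference using DataCollatorWithPadding and a DataLoader (the batch size of 4 is just for illustration):

from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

# pad each batch to the length of its longest sequence
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
dataloader = DataLoader(tokenized_data, batch_size=4, collate_fn=data_collator)

for batch in dataloader:
    print(batch['input_ids'].shape)  # e.g. torch.Size([4, longest_sequence_in_batch])
    print(model(**batch))
    break

For a single row, another option is to add the batch dimension yourself:

print(model(**{k: v.unsqueeze(0) for k, v in tokenized_data[0].items()}))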