How do I change the dtype of the image column after converting a Hugging Face dataset to torch format?

Hi! I’m working with the following code:

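```python
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 32  # example value; the original post does not specify it

dataset = load_dataset("yuntian-deng/im2latex-100k")
dataset = dataset.with_format("torch", device=device)
train_dataset = dataset["train"]
test_dataset = dataset["test"]
val_dataset = dataset["val"]
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
```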

I want to feed these dataloaders into my model, but I can’t, because with_format converts the image column of the dataset into torch tensors with dtype=torch.uint8, and my model needs float tensors. I’ve tried to find a way to cast this column to float but haven’t found one, so I’d appreciate some help with the intended way to do this.

I.e., I want with_format("torch") to convert the image column to tensors with dtype=torch.float instead of dtype=torch.uint8.

Thanks!

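You can achieve this by passing dtype=torch.float to the with_format call, e.g. dataset.with_format("torch", device=device, dtype=torch.float). Extra keyword arguments to with_format are passed on to torch.tensor when examples are formatted, so the dtype applies to every column that gets converted to a tensor, not only the image column. Note that this only casts the values: pixels stay in the range 0 to 255, so divide by 255 yourself if your model expects inputs in [0, 1].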