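Hi! I’m working with the following code:

```python
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"  # assumed; not shown in the original post
batch_size = 32  # placeholder value; not given in the original post

dataset = load_dataset("yuntian-deng/im2latex-100k")
dataset = dataset.with_format("torch", device=device)  # rows come back as torch tensors on `device`
train_dataset = dataset["train"]
test_dataset = dataset["test"]
val_dataset = dataset["val"]
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
```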
I want to feed these dataloaders into my model, but I can’t, because the `with_format` method converts the `image` column of the dataset into torch tensors with `dtype=uint8`. I need the dtype to be float. I’ve been trying to figure out how to cast it to float, but I can’t find a way to do it. I’d appreciate some help on the intended way to do this.
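One workaround that does not depend on how `with_format` is configured is to cast each batch after it comes out of the DataLoader. This is a minimal sketch, assuming each batch is a dict whose `image` entry is a uint8 tensor. The helper name `images_to_float` is hypothetical, and the division by 255 (rescaling pixel values to [0, 1]) is my own assumption; drop it if the model expects raw 0 to 255 values.

```python
import torch


def images_to_float(batch: dict, key: str = "image") -> dict:
    """Hypothetical helper: return a copy of `batch` with batch[key] cast to float32.

    Assumes batch[key] is a uint8 tensor of pixel values in 0..255.
    """
    out = dict(batch)  # shallow copy so the original batch is left untouched
    # .float() converts to torch.float32; dividing by 255 rescales 0..255 to 0..1
    out[key] = batch[key].float() / 255.0
    return out


# Stand-in for one DataLoader batch: two single-channel 3x4 uint8 images plus formulas
fake_batch = {
    "image": torch.randint(0, 256, (2, 1, 3, 4), dtype=torch.uint8),
    "formula": ["x^2", "\\frac{a}{b}"],
}
batch = images_to_float(fake_batch)
print(batch["image"].dtype)  # torch.float32
```

In a training loop this would be called as `batch = images_to_float(batch)` right before the forward pass; the cast happens on whatever device the tensor already lives on.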
I.e., I want `with_format("torch")` to convert the `image` column of the data to tensors with `dtype=torch.float`. It currently converts it to tensors with `dtype=torch.uint8`, which is not what I want.
Thanks!