Hi! I’m working with the following code:
```python
from datasets import load_dataset
from torch.utils.data import DataLoader

# device and batch_size are assumed to be defined earlier in the script
dataset = load_dataset("yuntian-deng/im2latex-100k")
dataset = dataset.with_format("torch", device=device)
train_dataset = dataset["train"]
test_dataset = dataset["test"]
val_dataset = dataset["val"]
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
```
I want to feed these dataloaders into my model, but I can’t, because the `with_format` method converts the `image` column of the dataset into torch tensors with `dtype=uint8`. I need the dtype to be `float`. I’ve been trying to figure out how to cast it to float, but I can’t find a way to do it. I’d appreciate some help with the intended way to do this.
In short, I want `with_format("torch")` to convert the `image` column to tensors with `dtype=torch.float`, but it currently converts it to tensors with `dtype=torch.uint8`, which is not what I want.
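Another possible approach is to leave the format alone and do the cast at batching time with a custom `collate_fn`, which also makes it easy to rescale to [0, 1]. This is a minimal sketch, assuming each example is a dict with a uint8 `image` tensor and a `formula` string, and that all images within a batch share one shape (pad or resize first if they don't). The function name `float_image_collate` is hypothetical, and a plain list of dicts stands in for the real dataset.

```python
import torch
from torch.utils.data import DataLoader


def float_image_collate(batch):
    """Collate examples, casting uint8 images to float32 scaled to [0, 1].

    Assumes each example has an "image" uint8 tensor and a "formula" string,
    and that every image in the batch has the same shape (needed by torch.stack).
    """
    images = torch.stack([ex["image"] for ex in batch]).to(torch.float32) / 255.0
    formulas = [ex["formula"] for ex in batch]
    return {"image": images, "formula": formulas}


# Stand-in for the torch-formatted dataset: a list of dicts works as a
# map-style dataset for DataLoader.
fake_examples = [
    {"image": torch.full((3, 2, 2), 255, dtype=torch.uint8), "formula": "a"},
    {"image": torch.zeros((3, 2, 2), dtype=torch.uint8), "formula": "b"},
]
loader = DataLoader(fake_examples, batch_size=2, shuffle=False, collate_fn=float_image_collate)
batch = next(iter(loader))
print(batch["image"].dtype, batch["image"].shape)  # torch.float32 torch.Size([2, 3, 2, 2])
print(batch["image"].max().item(), batch["formula"])  # 1.0 ['a', 'b']
```

With the original code this would be `DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=float_image_collate)`. An even simpler variant is to call `batch["image"].float() / 255` inside the training loop.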