RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got MPSFloatType instead (while checking arguments for embedding)

I am trying to train a multi-modal model that takes an image and text as input and outputs text.

Here is my architecture (assuming batch size = 1):

  • I use a ViT (from Hugging Face) to convert an image of shape (1, 3, 224, 224) into a feature vector of shape (1, 588) → float dtype
  • I have a text tokenizer which produces token ids of shape (1, 512) → int dtype

To use both sets of features I feed them into a T5 model. When I concatenate the two outputs, my actual_input to T5 has shape (1, 1100) → float dtype (the concatenation promotes the int token ids to float).
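To make the dtype flow concrete, here is roughly what happens with dummy tensors (illustrative only, not the real model outputs):

import torch

img_feat = torch.randn(1, 588)                    # ViT pooler output -> float32
txt_ids  = torch.randint(0, 32000, (1, 512))      # dummy token ids -> int64
fused    = torch.cat((img_feat, txt_ids), dim=1)  # (1, 1100)
print(fused.dtype)                                # torch.float32 on my setup: the int ids get promoted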

But outputs = model2(input_ids=actual_input, labels=labels) throws the following error:

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got MPSFloatType instead (while checking arguments for embedding)

stating that input_ids should be of int/long type. My doubt is: if I typecast the actual_input tensor to torch.int64, will it still be differentiable?
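A quick check with a dummy tensor seems to confirm my worry that the cast breaks the graph:

import torch

x = torch.randn(4, requires_grad=True)
y = x.to(torch.int64)
print(y.requires_grad)   # False: integer tensors cannot carry gradients, so backprop stops at the cast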

Here are some code snippets for clarity:

  • Model declaration
# T5
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer2 = T5Tokenizer.from_pretrained("google-t5/t5-small")
model2 = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")

max_source_length = 35
max_target_length = 512

# ViT
from transformers import ViTConfig, ViTModel

# Initializing a ViT vit-base-patch16-224 style configuration
configuration = ViTConfig()
configuration.hidden_size = 588

# Initializing a model (with random weights) from the vit-base-patch16-224 style configuration
model3 = ViTModel(configuration)

  • Train loop

import torch
from tqdm.auto import tqdm

# optimizer, device, num_epochs, num_training_steps and train_dataloader
# come from my setup code (not shown here)
progress_bar = tqdm(range(num_training_steps))
train_loss = [0] * num_epochs
model2.train()
model3.train()
for epoch in range(num_epochs):
    for it, batch in enumerate(train_dataloader):
        pixel_values = batch['pixel_values']
        input2 = batch['input_ids']
        labels = batch['labels']
        pixel_values = pixel_values.to(device)
        input2 = input2.to(device)
        labels = labels.to(device)
        
        # fuse image features (float) with text token ids (int)
        out = model3(pixel_values)['pooler_output']                      # (batch, 588), float
        actual_input = torch.cat((out, input2), dim=1).to(torch.int64)   # (batch, 1100)

        outputs = model2(input_ids=actual_input, labels=labels)    # this line gives error if I don't typecast

        loss = outputs.loss
        loss.backward()
        optimizer.step()

        optimizer.zero_grad()
        progress_bar.update(1)
        train_loss[epoch] += loss.item() * pixel_values.shape[0]

    train_loss[epoch] = train_loss[epoch]/len(train_dataloader.dataset)
    
    if epoch % 2 == 0:
        # save both models (T5 and ViT)
        torch.save(model2.state_dict(), f'(fused)model2_state_epoch_{epoch}.pth')
        torch.save(model3.state_dict(), f'(fused)model3_state_epoch_{epoch}.pth')
    print(f'Epoch {epoch+1} Loss: {train_loss[epoch]:.4f}')

Also, with this approach, will the ViT model's weights be tuned as well?
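Here is how I am checking it (assuming optimizer is the one from my training setup):

vit_param_ids = {id(p) for p in model3.parameters()}
in_optimizer = any(id(p) in vit_param_ids
                   for group in optimizer.param_groups
                   for p in group['params'])
print("ViT params require grad:", all(p.requires_grad for p in model3.parameters()))
print("ViT params in optimizer :", in_optimizer)
# even if both are True, the .to(torch.int64) cast above detaches the ViT output,
# so no gradients would reach model3 and its weights would not actually change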
Any suggestions are welcome, even for a better architectural approach for my task.
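One alternative I have been considering (not sure if it is the right pattern, so please correct me) is to skip input_ids altogether and feed fused embeddings through T5's inputs_embeds, roughly like this:

import torch
import torch.nn as nn

d_model = model2.config.d_model                       # 512 for t5-small
proj = nn.Linear(configuration.hidden_size, d_model)  # learnable projection: 588 -> 512
# (proj would need .to(device) and its parameters added to the optimizer)

img_feat  = model3(pixel_values)['pooler_output']     # (batch, 588), float
img_embed = proj(img_feat).unsqueeze(1)               # (batch, 1, 512): a single "visual token"
txt_embed = model2.get_input_embeddings()(input2)     # (batch, seq_len, 512); input2 stays int
fused     = torch.cat((img_embed, txt_embed), dim=1)  # (batch, 1 + seq_len, 512), float

outputs = model2(inputs_embeds=fused, labels=labels)  # no int cast, so everything stays differentiable

I think this would also let gradients flow back into the ViT, but I am not sure whether collapsing the image into a single pooled token is a sound design.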

I'm having the same issue. Did you ever figure it out?