TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead

I’m trying to run the most basic example of the trl package:

# imports
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch

# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)

tokenizer = AutoTokenizer.from_pretrained('gpt2')

# initialize trainer
ppo_config = PPOConfig(
    batch_size=1,
)

# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

# get model response
response_tensor = respond_to_batch(model, query_tensor)

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)

# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]

# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

I am having trouble running it on a Mac because it doesn’t seem to be compatible with Apple’s MPS backend.

I have tried using device = torch.device("cpu") instead of MPS, loading the model with:
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2').to(device)
and setting all tensors and the tokenizer to run with device = "cpu".

But I’m still getting an error at the end:

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

If none of the tensors & models are running on MPS, why does it still say it’s on an MPS tensor?
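One way to narrow this down is to print where the parameters and inputs actually live. The snippet below is only a minimal sketch: a small nn.Linear stands in for the gpt2 model (an assumption, to avoid a download), and the names param_devices and x are illustrative. Running the same check on the real model before and after PPOTrainer(...) is constructed would show whether it has been moved off the CPU at some point.

```python
import torch
from torch import nn

# Reports whether this PyTorch build can see an MPS device at all.
print("MPS available:", torch.backends.mps.is_available())

# Small stand-in for the gpt2 model (assumption: any nn.Module behaves the same here).
model = nn.Linear(4, 2).to(torch.device("cpu"))
x = torch.ones(1, 4)  # stand-in for query_tensor

# Collect the device type of every parameter; a single entry means the model sits on one device.
param_devices = {p.device.type for p in model.parameters()}
print("parameter devices:", param_devices)
print("input device:", x.device.type)
```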

I’m not sure why it’s always running on the MPS device, but I just had the same problem.

A fix that at least gets rid of the error is to change .double() to .float() on lines 1031 and 1039 of trl/trainer/ppo_trainer.py. So change

vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).double(), mask)

to

vf_clipfrac = masked_mean(torch.gt(vf_losses2, vf_losses1).float(), mask)

and the same for line 1039.

I am not sure if this will cause any problems further down the line though.
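For what it's worth, here is a rough sketch of what the changed line computes. The masked_mean below is a simplified stand-in written for this example (not copied from trl), and the loss values are made up. The comparison yields only True/False, so after the cast the values are exactly 0.0 or 1.0, which float32 represents exactly.

```python
import torch

def masked_mean(values, mask):
    """Mean of `values` over positions where `mask` is 1 (simplified stand-in, not trl's code)."""
    return (values * mask).sum() / mask.sum()

# Use MPS when available, otherwise fall back to CPU so the sketch runs anywhere.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Made-up value losses and mask, for illustration only.
vf_losses1 = torch.tensor([0.1, 0.5, 0.3, 0.9], device=device)
vf_losses2 = torch.tensor([0.2, 0.4, 0.6, 0.8], device=device)
mask = torch.ones(4, device=device)

clipped = torch.gt(vf_losses2, vf_losses1)  # boolean tensor

# .float() casts to float32, which MPS supports; .double() would request float64,
# which is the conversion the error message says MPS cannot do.
vf_clipfrac = masked_mean(clipped.float(), mask)
print(vf_clipfrac.dtype, vf_clipfrac.item())
```

Here two of the four positions satisfy the comparison, so the clip fraction is 0.5 either way; only the dtype of the result changes.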

Worked for me, thanks!