Environment:
- Transformers version: 4.45.2
- Datasets version: 3.0.1
- Platform: Linux-5.15.0-1070-aws-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.26.1
- Safetensors version: 0.4.5
- Accelerate version: 1.0.1
- Accelerate config: not found
- PyTorch version (GPU?): 2.5.0+cu118 (True)
- Tensorflow version (GPU?): 2.14.1 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: no
- Using GPU in script?: yes
- GPU type: NVIDIA A10G
Bug description:
I am unable to train a model with both bfloat16 mixed precision and torch compile enabled: training fails with
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16
even though all parameters in the model appear to have the torch.bfloat16 dtype (see the script below). With torch compilation disabled, or with float32 instead of bfloat16 (or both), everything works fine.
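For reference, keeping everything else in the script below unchanged, both of these variants train without error (only the flag combination differs):

# Works: torch.compile with full float32 training
training_args = TrainingArguments(
    per_device_train_batch_size=8,
    num_train_epochs=1,
    torch_compile=True,
    bf16=False,
    output_dir="/tmp/tests/test_1",
)

# Also works: bfloat16 mixed precision without torch.compile
training_args = TrainingArguments(
    per_device_train_batch_size=8,
    num_train_epochs=1,
    torch_compile=False,
    bf16=True,
    output_dir="/tmp/tests/test_1",
)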
Minimal reproducible example:
import torch
from transformers import pipeline
from transformers import TrainingArguments, Trainer
from datasets import load_dataset
# Load classification pipeline from pretrained model
pipe = pipeline(
    "text-classification",
    model="Qwen/Qwen2.5-0.5B",
    model_kwargs={
        "num_labels": 5,
    },
    device_map="cuda",
)
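# Sanity check: this prints {torch.bfloat16}, i.e. every parameter is already in bfloat16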
print({p.data.dtype for p in pipe.model.parameters()})
# Load + format dataset
dataset = load_dataset("yelp_review_full")["train"].select(range(100))
def tokenize_function(examples):
    return pipe.tokenizer(
        examples["text"],
        max_length=124,
        padding="max_length",
        truncation=True,
    )
tokenized_datasets = dataset.map(tokenize_function, batched=True)
# Train
training_args = TrainingArguments(
    per_device_train_batch_size=8,
    num_train_epochs=1,
    torch_compile=True,
    bf16=True,  # use bfloat16 mixed precision training
    output_dir="/tmp/tests/test_1",
)
trainer = Trainer(
    model=pipe.model,
    train_dataset=tokenized_datasets,
    eval_dataset=tokenized_datasets,
    args=training_args,
    tokenizer=pipe.tokenizer,
)
trainer.train()
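# Fails with: RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::BFloat16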
End of traceback:
File /tmp/torchinductor_root/sq/csqz5rruxwlzuuvfjpvwprouxopxgytlrulekcxpejp4ojprvao7.py:637, in call(args)
635 buf20 = empty_strided_cuda((992, 896), (896, 1), torch.float32)
636 # Topologically Sorted Source Nodes: [linear_3], Original ATen: [aten.mm]
--> 637 extern_kernels.mm(reinterpret_tensor(buf19, (992, 896), (896, 1), 0), reinterpret_tensor(primals_12, (896, 896), (1, 896), 0), out=buf20)
638 buf21 = reinterpret_tensor(buf20, (8, 124, 896), (111104, 896, 1), 0); del buf20 # reuse
639 buf22 = empty_strided_cuda((8, 124, 1), (124, 1, 992), torch.float32)
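In the generated Inductor code, the output buffer buf20 is allocated as torch.float32 and the error message says mat1 is float while mat2 is c10::BFloat16, so a float32 activation appears to be fed into primals_12, presumably the bfloat16 weight of one of the model's linear projections. To narrow down whether this comes from Trainer's handling of bf16 + torch_compile or from torch.compile itself, a minimal standalone sketch (my own, untested, and it may well not reproduce the error) would combine bf16 parameters, bf16 autocast, and a float32-upcasting norm in front of a linear layer, roughly like Qwen2's RMSNorm:

import torch
import torch.nn as nn

class NormThenLinear(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm_weight = nn.Parameter(torch.ones(dim))
        self.linear = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        # RMSNorm-style block that upcasts to float32 internally and casts
        # back to the input dtype before the linear projection
        input_dtype = x.dtype
        h = x.to(torch.float32)
        h = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + 1e-6)
        h = self.norm_weight * h.to(input_dtype)
        return self.linear(h)

model = NormThenLinear(896).to("cuda", dtype=torch.bfloat16)
compiled_model = torch.compile(model)

x = torch.randn(8, 124, 896, device="cuda", dtype=torch.bfloat16)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    compiled_model(x).sum().backward()
print("no dtype error outside Trainer")

If this sketch runs cleanly, the problem is more likely in how Trainer sets up autocast around the compiled model rather than in torch.compile itself.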
This looks like a bug to me, but I want to be sure before opening an issue.