I’m trying to train a language model using `google/gemma-2-2b` with the Hugging Face Transformers `Trainer`. The same training script works fine for other models like `gpt2` and `meta-llama/Meta-Llama-3-8B`, but with Gemma-2-2B it fails during evaluation, showing:

RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and BFloat16 for the source.

Below is the full console output (and the relevant code excerpt at the end). Note that I already attempted the following:
- Setting `attn_implementation='eager'` for Gemma-2-2B.
- Switching out of `paged_adamw_32bit`.
- Commenting/uncommenting `gradient_checkpointing`.
I still get this dtype mismatch error at eval time. Any ideas on how to resolve or work around this?
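For context, the failing operation is easy to reproduce in isolation: PyTorch refuses an advanced-index assignment when the source and destination dtypes differ. A minimal sketch (standalone, not from my training code):

```python
import torch

cache = torch.zeros(2, 4)                          # float32 destination (like an fp32 KV cache)
new_vals = torch.ones(2, 1, dtype=torch.bfloat16)  # bfloat16 source (like bf16 value_states)

# Raises: RuntimeError: Index put requires the source and destination dtypes
# match, got Float for the destination and BFloat16 for the source.
cache[:, torch.tensor([0])] = new_vals
```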
Full console output:
Kwargs to run:
{'mode': 'dryrun', 'project': 'self-opt-train-uncompiled-py-2-gsm8k', 'num_train_epochs': 1, 'model_name': 'google/gemma-2-2b', 'today': '2025_m02_d07_t07h_20m_14s', 'tmux_sess_num': None, 'hostname': 'skampere1'}
Setting random seed = 42
vLLM not installed or vllm set seed has a bug, skipping vLLM seed setting.
Currently logged in as: brando
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 5.63it/s]
block_size=1024
len(ds_train)=18612
len(ds_train)=2740
/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/training_args.py:1575: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
warnings.warn(
/lfs/skampere1/0/brando9/ZIP-FIT/zip_fit/train/train.py:371: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.
trainer = Trainer(
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
wandb: WARNING The `run_name` is currently set to the same value as `TrainingArguments.output_dir`. ...
0%| | 0/342 [00:00<?, ?it/s]The 'batch_size' argument of HybridCache is deprecated and will be removed in v4.49...
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49...
Traceback (most recent call last):
File "/lfs/skampere1/0/brando9/ZIP-FIT/zip_fit/train/train.py", line 564, in <module>
fire.Fire(_main)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/fire/core.py", line 135, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/fire/core.py", line 468, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/lfs/skampere1/0/brando9/ZIP-FIT/zip_fit/train/train.py", line 554, in _main
main_train(kwargs)
File "/lfs/skampere1/0/brando9/ZIP-FIT/zip_fit/train/train.py", line 383, in main_train
trainer.train()
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/trainer.py", line 2171, in train
return inner_training_loop(
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/trainer.py", line 2440, in _inner_training_loop
self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/trainer.py", line 3025, in _evaluate
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/trainer.py", line 4076, in evaluate
output = eval_loop(
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/trainer.py", line 4270, in evaluation_loop
losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/trainer.py", line 4486, in prediction_step
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/trainer.py", line 3734, in compute_loss
outputs = model(**inputs)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/accelerate/utils/operations.py", line 819, in forward
return model_forward(*args, **kwargs)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/accelerate/utils/operations.py", line 807, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
return func(*args, **kwargs)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 842, in forward
outputs = self.model(
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 629, in forward
layer_outputs = decoder_layer(
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 299, in forward
hidden_states, self_attn_weights = self.self_attn(
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 224, in forward
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/cache_utils.py", line 1717, in update
return update_fn(
File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/cache_utils.py", line 1695, in _static_update
v_out[:, :, cache_position] = value_states
~~~~~^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and BFloat16 for the source.
Key snippet where I try to force eager attention for Gemma-2:
```python
import torch
from transformers import AutoModelForCausalLM

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Prefer bf16 when the GPU supports it, otherwise fall back to fp32.
torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32

if 'gemma-2' not in model_name:
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch_dtype,
    ).to(device)
else:
    # For Gemma-2, force eager attention; note that no torch_dtype is passed
    # here, so the model loads in its default float32.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation='eager',
    ).to(device)
```
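One thing I notice while debugging (not sure if it is the root cause): in the Gemma-2 branch I never pass `torch_dtype`, so the parameters (and presumably the `HybridCache` buffers allocated to match them) stay in float32, while mixed-precision eval produces bfloat16 key/value states. A hypothetical sanity check along those lines:

```python
# Hypothetical check: compare the loaded parameter dtype with the intended
# compute dtype. In the Gemma-2 branch the parameters come out float32,
# while bf16 autocast yields bfloat16 activations -- which would match the
# Float/BFloat16 pair in the error message.
print(next(model.parameters()).dtype)  # torch.float32 in the gemma-2 branch
print(torch_dtype)                     # torch.bfloat16 on bf16-capable GPUs
```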
I also switched from `paged_adamw_32bit` to standard `adamw_torch`, and toggled `gradient_checkpointing`:

```python
# gradient_checkpointing=config.get('gradient_checkpointing', True),  # known to cause issues
# optim=config.get('optim', 'paged_adamw_32bit'),  # switched out of paged optim
```
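For reference, the surrounding `TrainingArguments` look roughly like this (paraphrased; everything beyond the settings discussed above is an assumption):

```python
import torch
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='./out',                   # placeholder path
    num_train_epochs=1,
    eval_strategy='steps',                # replaces the deprecated evaluation_strategy
    bf16=torch.cuda.is_bf16_supported(),  # assumption: mixed precision on (the traceback shows autocast)
    optim='adamw_torch',                  # switched out of paged_adamw_32bit
    # gradient_checkpointing=True,        # toggled with no effect on the error
)
```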
But the error persists. Any suggestions on how to fix or debug this `Index put` dtype mismatch for Gemma-2? Note that the same script works for Llama-3-8B and GPT-2.
Code:
- tfa.py: tfa.py · GitHub
- train.py: train.py · GitHub
- original SO: How to fix `Index put requires the source and destination dtypes match` with `google/gemma-2-2b` in Transformers? (Stack Overflow)
- cross-post: How to fix `Index put requires the source and destination dtypes match` with `google/gemma-2-2b` in Transformers? (PyTorch Forums)