How to fix `Index put requires the source and destination dtypes match` with `google/gemma-2-2b` in Transformers?

I’m trying to train a language model using `google/gemma-2-2b` with the Hugging Face Transformers Trainer. The same training script works fine for other models like `gpt2` and `meta-llama/Meta-Llama-3-8B`, but with Gemma-2-2B it fails during evaluation, showing:

RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and BFloat16 for the source.

Below is the full console output (and the relevant code excerpt at the end). Note that I already attempted the following:

  • Setting attn_implementation='eager' for Gemma-2-2B.
  • Switching out of paged_adamw_32bit.
  • (Un)commenting gradient_checkpointing.

I still get this dtype mismatch error at eval time. Any ideas on how to resolve or work around this?


Full console output:

Kwargs to run:
{'mode': 'dryrun', 'project': 'self-opt-train-uncompiled-py-2-gsm8k', 'num_train_epochs': 1, 'model_name': 'google/gemma-2-2b', 'today': '2025_m02_d07_t07h_20m_14s', 'tmux_sess_num': None, 'hostname': 'skampere1'}
Setting random seed = 42
vLLM not installed or vllm set seed has a bug, skipping vLLM seed setting.
Currently logged in as: brando

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  5.63it/s]
block_size=1024
len(ds_train)=18612
len(ds_train)=2740
/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/training_args.py:1575: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
/lfs/skampere1/0/brando9/ZIP-FIT/zip_fit/train/train.py:371: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.
  trainer = Trainer(
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
wandb: WARNING The `run_name` is currently set to the same value as `TrainingArguments.output_dir`. ...
  0%|                                                                                                    | 0/342 [00:00<?, ?it/s]The 'batch_size' argument of HybridCache is deprecated and will be removed in v4.49...
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49...
Traceback (most recent call last):
  File "/lfs/skampere1/0/brando9/ZIP-FIT/zip_fit/train/train.py", line 564, in <module>
    fire.Fire(_main)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/fire/core.py", line 135, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/fire/core.py", line 468, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/lfs/skampere1/0/brando9/ZIP-FIT/zip_fit/train/train.py", line 554, in _main
    main_train(kwargs)
  File "/lfs/skampere1/0/brando9/ZIP-FIT/zip_fit/train/train.py", line 383, in main_train
    trainer.train()
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/trainer.py", line 2171, in train
    return inner_training_loop(
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/trainer.py", line 2440, in _inner_training_loop
    self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/trainer.py", line 3025, in _evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/trainer.py", line 4076, in evaluate
    output = eval_loop(
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/trainer.py", line 4270, in evaluation_loop
    losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/trainer.py", line 4486, in prediction_step
    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/trainer.py", line 3734, in compute_loss
    outputs = model(**inputs)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 842, in forward
    outputs = self.model(
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 629, in forward
    layer_outputs = decoder_layer(
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 299, in forward
    hidden_states, self_attn_weights = self.self_attn(
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/models/gemma2/modeling_gemma2.py", line 224, in forward
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/cache_utils.py", line 1717, in update
    return update_fn(
  File "/lfs/skampere1/0/brando9/miniconda/envs/zip_fit/lib/python3.11/site-packages/transformers/cache_utils.py", line 1695, in _static_update
    v_out[:, :, cache_position] = value_states
    ~~~~~^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and BFloat16 for the source.

From the last traceback frames, the failure happens in HybridCache's static update: the preallocated cache tensor is Float while the incoming value_states are BFloat16 under autocast. Here is the key snippet where I try to force eager attention for Gemma-2:

import torch
from transformers import AutoModelForCausalLM

torch_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
if 'gemma-2' not in model_name:
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch_dtype,
    ).to(device)
else:
    # Gemma-2 branch: force eager attention, as recommended for this model.
    # Note that no torch_dtype is passed here, so the weights load in float32.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        attn_implementation='eager',
    ).to(device)

I also switched from paged_adamw_32bit to standard adamw_torch, and toggled gradient_checkpointing:

# gradient_checkpointing=config.get('gradient_checkpointing', True),  # known to cause issues
# optim=config.get('optim', 'paged_adamw_32bit'),  # switched out of paged optim
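
For context, here is roughly how the relevant TrainingArguments look after those changes. This is an illustrative sketch, not my exact config (the real values come from my `config` dict); `eval_strategy` is the new name for the deprecated `evaluation_strategy` flagged in the log above:

from transformers import TrainingArguments

# Illustrative values only; the real ones come from my config dict.
training_args = TrainingArguments(
    output_dir='./gemma2-2b-gsm8k',
    num_train_epochs=1,
    eval_strategy='steps',         # replaces the deprecated evaluation_strategy
    optim='adamw_torch',           # switched away from paged_adamw_32bit
    gradient_checkpointing=False,  # toggled off while debugging
    bf16=True,                     # mixed precision; matches the autocast frames in the traceback
)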

But the error persists. Any suggestions on how to fix or debug this `Index put` dtype mismatch for Gemma-2? Note that the same script works for Llama-3-8B and GPT-2.



The quick fix seems to be casting the model to a single dtype (`model.to(torch.float16)`), but if it only happens during eval, there may also be something wrong in the eval function.
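
A minimal sketch of that workaround, assuming the same load call as in the question:

import torch
from transformers import AutoModelForCausalLM

# Cast all parameters and buffers to a single dtype; the idea is that the
# static KV cache Gemma-2 allocates at eval time then matches the
# attention activations instead of mixing Float and BFloat16.
model = AutoModelForCausalLM.from_pretrained(
    'google/gemma-2-2b',
    attn_implementation='eager',
)
model.to(torch.float16)  # or torch.bfloat16 if the GPU supports it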