Can't train Mamba2 with FP16 (Mamba2ForCausalLM)

I can’t train Mamba2 with FP16 using the Trainer; evaluation crashes with a dtype mismatch in the Mamba cache update.

Please post a reproducer, along with the error you’re getting.

Here you go.

To reproduce:

    import torch
    from transformers import (
        AutoConfig,
        MambaForCausalLM,
        Trainer,
        TrainingArguments,
    )

    # tokenizer, tokenized_train, tokenized_eval, args, and device come from
    # my own script; any causal-LM dataset should hit the same error.
    config = AutoConfig.from_pretrained('state-spaces/mamba-130m')
    model = MambaForCausalLM(config)
    model.to(device)  # device = torch.device("cuda")

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        logging_dir='./logs',
        gradient_accumulation_steps=1,
        save_steps=50000,
        max_steps=1000000,
        eval_strategy="steps",
        eval_steps=50000,
        logging_strategy="epoch",
        logging_steps=2000,
        learning_rate=1e-4,
        fp16=True,  # training in fp16 triggers the crash below
        dataloader_num_workers=4,
        per_device_train_batch_size=512,
        per_device_eval_batch_size=512,
        lr_scheduler_type="constant_with_warmup",
        weight_decay=0.1,
        warmup_steps=2000,
    )

    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_eval,
    )

    trainer.train()

Trace:

  File "/users/PAS2581/kanaka/research/GrokkedTransformersarewang2024/trying_different_archs/mamba/main.py", line 575, in <module>
    main()
  File "/users/PAS2581/kanaka/research/GrokkedTransformersarewang2024/trying_different_archs/mamba/main.py", line 545, in main
    trainer.train()
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train
    return inner_training_loop(
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 2356, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 2804, in _maybe_log_save_evaluate
    metrics = self._evaluate(trial, ignore_keys_for_eval)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 2761, in _evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 3666, in evaluate
    output = eval_loop(
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 3857, in evaluation_loop
    losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 4075, in prediction_step
    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 3363, in compute_loss
    outputs = model(**inputs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
    return model_forward(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba/modeling_mamba.py", line 738, in forward
    mamba_outputs = self.backbone(
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba/modeling_mamba.py", line 610, in forward
    hidden_states = mixer_block(hidden_states, cache_params=cache_params, cache_position=cache_position)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba/modeling_mamba.py", line 354, in forward
    hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba/modeling_mamba.py", line 310, in forward
    return self.cuda_kernels_forward(hidden_states, cache_params, cache_position)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba/modeling_mamba.py", line 178, in cuda_kernels_forward
    cache_params.update_conv_state(self.layer_idx, conv_states, cache_position)
  File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/cache_utils.py", line 1644, in update_conv_state
    conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and Half for the source.
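
From the trace, the eval forward runs under autocast (accelerate wraps model_forward in torch.autocast), so the mixer’s conv states come out in fp16, while the cache’s conv_state buffer was allocated in fp32; the in-place index copy then fails. The dtype check is easy to reproduce in isolation (a standalone sketch, not Mamba code):

    import torch

    dst = torch.zeros(1, 2, 4)                      # fp32, like the cached conv_state
    src = torch.ones(1, 2, 1, dtype=torch.float16)  # fp16, like activations under autocast
    dst[:, :, torch.tensor([0])] = src              # RuntimeError: dtypes must match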

Thanks, this looks like an issue that needs to be fixed in the Transformers library. Can you open an issue on GitHub? Also, do you have an up-to-date Transformers version?
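
In the meantime, a couple of things you could try (untested sketches, not official fixes): train with bf16=True instead of fp16=True if your hardware supports it, or cast the incoming states to the cache dtype yourself, e.g. by wrapping MambaCache.update_conv_state. The wrapper below is my own sketch against the 4.44 internals shown in your trace:

    from transformers.cache_utils import MambaCache

    _orig_update_conv_state = MambaCache.update_conv_state

    def _patched_update_conv_state(self, layer_idx, new_conv_state, cache_position):
        # Cast the half-precision tensor produced under autocast to the dtype
        # of the cached conv state before the in-place index copy.
        target_dtype = self.conv_states[layer_idx].dtype
        return _orig_update_conv_state(
            self, layer_idx, new_conv_state.to(target_dtype), cache_position
        )

    MambaCache.update_conv_state = _patched_update_conv_state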

I’ll open an issue on GitHub, thanks. Yes, I have the latest version of Transformers:

    import transformers
    print(transformers.__version__)
    # 4.44.2