Can’t train Mamba2 with FP16 using Trainer.
Please post a reproducer, along with the error you’re getting.
Here it goes.
To reproduce:
from transformers import AutoConfig, MambaForCausalLM, Trainer, TrainingArguments

# tokenizer, tokenized_train, tokenized_eval, args and device are defined
# earlier in my script
config = AutoConfig.from_pretrained('state-spaces/mamba-130m')
model = MambaForCausalLM(config)
model.to(device)

training_args = TrainingArguments(
    output_dir=args.output_dir,
    logging_dir='./logs',
    gradient_accumulation_steps=1,
    save_steps=50000,
    max_steps=1000000,
    eval_strategy="steps",
    eval_steps=50000,
    logging_strategy="epoch",
    logging_steps=2000,
    learning_rate=1e-4,
    fp16=True,
    dataloader_num_workers=4,
    per_device_train_batch_size=512,
    per_device_eval_batch_size=512,
    lr_scheduler_type="constant_with_warmup",
    weight_decay=0.1,
    warmup_steps=2000,
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
)
trainer.train()
Traceback:
File "/users/PAS2581/kanaka/research/GrokkedTransformersarewang2024/trying_different_archs/mamba/main.py", line 575, in <module>
main()
File "/users/PAS2581/kanaka/research/GrokkedTransformersarewang2024/trying_different_archs/mamba/main.py", line 545, in main
trainer.train()
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 1938, in train
return inner_training_loop(
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 2356, in _inner_training_loop
self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 2804, in _maybe_log_save_evaluate
metrics = self._evaluate(trial, ignore_keys_for_eval)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 2761, in _evaluate
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 3666, in evaluate
output = eval_loop(
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 3857, in evaluation_loop
losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 4075, in prediction_step
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/trainer.py", line 3363, in compute_loss
outputs = model(**inputs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/accelerate/utils/operations.py", line 819, in forward
return model_forward(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/accelerate/utils/operations.py", line 807, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
return func(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba/modeling_mamba.py", line 738, in forward
mamba_outputs = self.backbone(
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba/modeling_mamba.py", line 610, in forward
hidden_states = mixer_block(hidden_states, cache_params=cache_params, cache_position=cache_position)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba/modeling_mamba.py", line 354, in forward
hidden_states = self.mixer(hidden_states, cache_params=cache_params, cache_position=cache_position)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba/modeling_mamba.py", line 310, in forward
return self.cuda_kernels_forward(hidden_states, cache_params, cache_position)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/models/mamba/modeling_mamba.py", line 178, in cuda_kernels_forward
cache_params.update_conv_state(self.layer_idx, conv_states, cache_position)
File "/users/PAS2581/kanaka/miniconda3/envs/grokk/lib/python3.10/site-packages/transformers/cache_utils.py", line 1644, in update_conv_state
conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
RuntimeError: Index put requires the source and destination dtypes match, got Float for the destination and Half for the source.
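For reference, the failing assignment in update_conv_state boils down to a plain PyTorch index-put with mismatched dtypes. A minimal sketch that reproduces the same RuntimeError (the shapes here are made up; only the dtypes matter):

import torch

conv_state = torch.zeros(2, 3, 4)                  # float32 destination (the cache)
new_state = torch.ones(2, 3, 1, dtype=torch.half)  # fp16 source (the autocast output)
cache_position = torch.tensor([0])

# RuntimeError: Index put requires the source and destination dtypes match,
# got Float for the destination and Half for the source.
conv_state[:, :, cache_position] = new_state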
Thanks, this looks like an issue that needs to be fixed in the Transformers library. Can you open an issue on GitHub? Also, do you have an up-to-date Transformers version?
I’ll open an issue on GitHub, thanks. Yes, I do have the latest version of Transformers:
print(transformers.__version__)
4.44.2
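In the meantime, a monkey-patch along these lines works around it for me. This is only a sketch, assuming transformers 4.44.2, where MambaCache.update_conv_state lives in transformers/cache_utils.py and the per-layer state is reachable as self.conv_states[layer_idx]; it casts the incoming state to the cache's dtype before the stock index assignment runs:

from transformers.cache_utils import MambaCache

_orig_update_conv_state = MambaCache.update_conv_state

def _patched_update_conv_state(self, layer_idx, new_conv_state, cache_position):
    # Under fp16 autocast the mixer output is half precision while the cache
    # was allocated in float32; cast to the cache's dtype before delegating
    # to the original implementation.
    target = self.conv_states[layer_idx]
    new_conv_state = new_conv_state.to(device=target.device, dtype=target.dtype)
    return _orig_update_conv_state(self, layer_idx, new_conv_state, cache_position)

MambaCache.update_conv_state = _patched_update_conv_state

Apply the patch before calling trainer.train() (or before any forward pass that touches the cache).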