Error when running eval on Mamba LoRA with PEFT

I am running the following code:

from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
train_dataset = load_dataset("spider", split="train").select(range(50))
val_dataset = load_dataset("spider", split="validation").select(range(50))
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=2e-3
)
lora_config = LoraConfig(
    r=8,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none",
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
    dataset_text_field="question",
)
trainer.train()
trainer.evaluate()

This is based on the nice gist from @ArthurZ (gist link) for doing LoRA on Mamba.

When I run trainer.evaluate() I hit the following error:

Traceback (most recent call last):
  File "/home/ppol/MaskMamba/src/test.py", line 33, in <module>
    trainer.evaluate()
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/transformers/trainer.py", line 3964, in evaluate
    output = eval_loop(
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/transformers/trainer.py", line 4158, in evaluation_loop
    losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/transformers/trainer.py", line 4374, in prediction_step
    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/transformers/trainer.py", line 3625, in compute_loss
    outputs = model(**inputs)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 194, in forward
    return self.gather(outputs, self.output_device)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 217, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py", line 135, in gather
    res = gather_map(outputs)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py", line 127, in gather_map
    return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
  File "<string>", line 7, in __init__
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/transformers/utils/generic.py", line 390, in __post_init__
    for idx, element in enumerate(iterator):
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py", line 127, in <genexpr>
    return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py", line 130, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
TypeError: 'MambaCache' object is not iterable

From the stack it looks like the failure happens when torch.nn.DataParallel tries to gather the MambaCache object that the model returns in its outputs during evaluation. If you could point out what I am doing wrong, it would be very helpful.


Maybe a bug: torch.nn.DataParallel does not know how to gather the MambaCache object that ends up in the evaluation outputs.


Nice, thanks!

