I am running the following code:
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

# Small subsets of Spider, just for a quick smoke test
train_dataset = load_dataset("spider", split="train").select(range(50))
val_dataset = load_dataset("spider", split="validation").select(range(50))

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir="./logs",
    logging_steps=10,
    learning_rate=2e-3,
)

lora_config = LoraConfig(
    r=8,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none",
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
    dataset_text_field="question",
)

trainer.train()
trainer.evaluate()
The code is based on this nice gist from @ArthurZ (gist link) for doing LoRA on Mamba.
Training completes, but when I run trainer.evaluate() I stumble upon the following error:
Traceback (most recent call last):
  File "/home/ppol/MaskMamba/src/test.py", line 33, in <module>
    trainer.evaluate()
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/transformers/trainer.py", line 3964, in evaluate
    output = eval_loop(
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/transformers/trainer.py", line 4158, in evaluation_loop
    losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/transformers/trainer.py", line 4374, in prediction_step
    loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/transformers/trainer.py", line 3625, in compute_loss
    outputs = model(**inputs)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 194, in forward
    return self.gather(outputs, self.output_device)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 217, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py", line 135, in gather
    res = gather_map(outputs)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py", line 127, in gather_map
    return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
  File "<string>", line 7, in __init__
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/transformers/utils/generic.py", line 390, in __post_init__
    for idx, element in enumerate(iterator):
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py", line 127, in <genexpr>
    return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
  File "/home/ppol/.conda/envs/mask_mamba/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py", line 130, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
TypeError: 'MambaCache' object is not iterable
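For what it's worth, the data_parallel.py frames in the trace make me suspect the Trainer wrapped the model in torch.nn.DataParallel (I have more than one GPU visible), and that the MambaCache object in the model outputs is not something DataParallel's gather knows how to handle. Here is a minimal sketch of what I was planning to try next, assuming the cache in the outputs is what breaks the gather; both lines are guesses on my part, not a confirmed fix:

import os

# Guess 1: keep the run on a single GPU so the Trainer never wraps the
# model in DataParallel (must be set before CUDA is initialized)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Guess 2: disable the cache so the forward pass does not return a
# MambaCache in its outputs, leaving nothing non-gatherable
model.config.use_cache = False

trainer.evaluate()

Neither feels like a proper fix, though, hence this post.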
If you could point out what I am doing wrong, it would be very helpful.