Code:
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import torch
import torch.nn as nn
import numpy as np
import evaluate  # the seqeval metric is loaded below via evaluate.load("seqeval")
import accelerate  # not used directly, but Trainer needs accelerate installed
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    AutoModel,
    AutoConfig,
    MambaConfig,
    MambaModel,
    MambaPreTrainedModel,
)
from transformers import DataCollatorForTokenClassification
from transformers.utils import ModelOutput
model_name = "state-spaces/mamba-130m-hf"

dataset = load_dataset("conll2003")
tokenizer = AutoTokenizer.from_pretrained(model_name)
seqeval = evaluate.load("seqeval")  # entity-level precision/recall/F1 used in compute_metrics
def tokenize_and_align_labels(ds):
    tokenized_inputs = tokenizer(
        ds["tokens"], truncation=True, is_split_into_words=True, padding=True, return_tensors="pt"
    )
    labels = []
    for i, label in enumerate(ds["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)  # map sub-tokens back to word indices
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                # special tokens get the ignore index
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                # only the first sub-token of a word keeps the word's label
                label_ids.append(label[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs
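# Quick illustration (not part of the original post): for the first CoNLL-2003 sentence,
# "EU rejects German call ...", only the first sub-token of each word keeps its ner_tag;
# every other position becomes -100 so CrossEntropyLoss ignores it.
# example = tokenize_and_align_labels(dataset["train"][:1])
# print(example["labels"][0])   # e.g. [3, 0, 7, 0, ...] with -100 at special/continuation tokens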
label_list = dataset["train"].features["ner_tags"].feature.names
id2label = {i: name for i, name in enumerate(label_list)}
label2id = {name: i for i, name in enumerate(label_list)}
train_ds = dataset["train"].select(range(100)).map(tokenize_and_align_labels, batched=True)
test_ds = dataset["test"].select(range(100)).map(tokenize_and_align_labels, batched=True)
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, return_tensors="pt")
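# The collator pads input_ids per batch and pads the label sequences with -100
# (its default label_pad_token_id), matching the ignore index used above.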
def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # drop positions labelled -100 (special / continuation tokens) before scoring
    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }
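# Note (illustrative, not from the original post): the seqeval metric takes lists of
# tag-string sequences, e.g.
#   seqeval.compute(predictions=[["B-ORG", "O"]], references=[["B-ORG", "O"]])
# and returns a dict containing the "overall_*" keys read above.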
@dataclass
class MambaSequenceClassifierOutput(ModelOutput):
    # ModelOutput subclasses need annotated fields; without the annotations the
    # keyword arguments used below are not registered as dataclass fields
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    cache_params: Optional[Any] = None  # MambaCache returned by the backbone
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
class MambaTokenClassifierOutput(MambaPreTrainedModel):
    # Mamba backbone + dropout + linear token-classification head
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.backbone = MambaModel(config)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.post_init()  # initialise the new classifier head

    def forward(
        self,
        input_ids=None,
        inputs_embeds=None,
        cache_params=None,
        labels=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,  # swallows extra keys (e.g. attention_mask) that are not forwarded to the backbone
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # labels are only used for the loss below; the backbone does not accept them
        outputs = self.backbone(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            cache_params=cache_params,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return MambaSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            cache_params=outputs.cache_params,
            hidden_states=outputs.hidden_states,
        )
model = MambaTokenClassifierOutput.from_pretrained(
    model_name,
    use_cache=False,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)
training_args = TrainingArguments(
    output_dir="test",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)
trainer.train()
And the error I currently get:
AttributeError Traceback (most recent call last)
<ipython-input-11-4acd14b4dcb6> in <cell line: 163>()
161 )
162
--> 163 trainer.train()
9 frames
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1910 hf_hub_utils.enable_progress_bars()
1911 else:
-> 1912 return inner_training_loop(
1913 args=args,
1914 resume_from_checkpoint=resume_from_checkpoint,
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2343
2344 self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-> 2345 self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
2346
2347 if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
2765 metrics = None
2766 if self.control.should_evaluate:
-> 2767 metrics = self._evaluate(trial, ignore_keys_for_eval)
2768
2769 if self.control.should_save:
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler)
2728
2729 def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
-> 2730 metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
2731 self._report_to_hp_search(trial, self.state.global_step, metrics)
2732
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
3613
3614 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 3615 output = eval_loop(
3616 eval_dataloader,
3617 description="Evaluation",
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
3798
3799 # Prediction step
-> 3800 loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
3801 main_input_name = getattr(self.model, "main_input_name", "input_ids")
3802 inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None
/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in prediction_step(self, model, inputs, prediction_loss_only, ignore_keys)
4034 return (loss, None, None)
4035
-> 4036 logits = nested_detach(logits)
4037 if len(logits) == 1:
4038 logits = logits[0]
/usr/local/lib/python3.10/dist-packages/transformers/trainer_pt_utils.py in nested_detach(tensors)
188 "Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."
189 if isinstance(tensors, (list, tuple)):
--> 190 return type(tensors)(nested_detach(t) for t in tensors)
191 elif isinstance(tensors, Mapping):
192 return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
/usr/local/lib/python3.10/dist-packages/transformers/trainer_pt_utils.py in <genexpr>(.0)
188 "Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."
189 if isinstance(tensors, (list, tuple)):
--> 190 return type(tensors)(nested_detach(t) for t in tensors)
191 elif isinstance(tensors, Mapping):
192 return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
/usr/local/lib/python3.10/dist-packages/transformers/trainer_pt_utils.py in nested_detach(tensors)
191 elif isinstance(tensors, Mapping):
192 return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
--> 193 return tensors#.detach()
194
195
AttributeError: 'MambaCache' object has no attribute 'detach'
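For context (my reading of the traceback, not part of the original error): during evaluation, prediction_step collects every field of the model output except the loss and passes the result to nested_detach, so the MambaCache stored in cache_params gets detached alongside the logits, and MambaCache has no .detach(). A minimal, unverified sketch of one way around it, assuming the custom class above and the Trainer version shown in the traceback, is to keep the non-tensor fields out of evaluation, either by not returning cache_params from forward at all or by passing the ignore keys that the train() signature in the traceback accepts:

trainer.train(ignore_keys_for_eval=["cache_params", "hidden_states"])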