Mamba for token classification task

Hey,

My goal is to use the Mamba model for token classification, and I'm a bit lost at this point. I am using a dataset that is similar to the conll2003 dataset.

The attempts I have made so far were all based on these resources:

I will add a bit of code later…

It would be nice to get some guidance to start fresh.
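
For context, the rows of my dataset follow the conll2003 layout, roughly like this (the values below are the first conll2003 training example; the field names are what the code further down assumes):

{
    "tokens": ["EU", "rejects", "German", "call", "to", "boycott", "British", "lamb", "."],
    "ner_tags": [3, 0, 7, 0, 0, 0, 7, 0, 0],  # integer ids into the B-*/I-*/O label list
}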

Code:

from dataclasses import dataclass
from typing import Any, Optional, Tuple

import torch.nn as nn
import torch
import evaluate
import numpy as np
import accelerate

from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification, 
    Trainer, 
    TrainingArguments,
    AutoTokenizer,
    AutoModel,
    AutoConfig,
    MambaConfig,
    MambaModel,
    MambaPreTrainedModel,
)
from transformers import DataCollatorForTokenClassification
from transformers.utils import ModelOutput

model_name = "state-spaces/mamba-130m-hf"


dataset = load_dataset("conll2003")
tokenizer = AutoTokenizer.from_pretrained(model_name)
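# if the tokenizer does not define a pad token, the padding below will fail;
# in that case the usual workaround is: tokenizer.pad_token = tokenizer.eos_token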

def tokenize_and_align_labels(ds):
    tokenized_inputs = tokenizer(ds["tokens"], truncation=True, is_split_into_words=True, padding=True, return_tensors="pt")

    labels = []
    for i, label in enumerate(ds["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx: 
                label_ids.append(label[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs
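
# Example of the alignment: for tokens ["EU", "rejects"] with ner_tags [3, 0],
# if "rejects" were split into two sub-tokens, the aligned labels would be
# [3, 0, -100]: only the first sub-token of each word keeps its label, and
# special/padding positions get -100 so the loss ignores them.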

label_list = dataset["train"].features["ner_tags"].feature.names
id2label = {i: name for i, name in enumerate(label_list)}
label2id = {name: i for i, name in enumerate(label_list)}
train_ds = dataset["train"].select(range(100)).map(tokenize_and_align_labels, batched=True)
test_ds = dataset["test"].select(range(100)).map(tokenize_and_align_labels, batched=True)

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, return_tensors="pt")


# seqeval metric loaded via the evaluate library
seqeval = evaluate.load("seqeval")

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = seqeval.compute(predictions=true_predictions, references=true_labels)
    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }


@dataclass
class MambaSequenceClassifierOutput(ModelOutput):
    # the fields need type annotations, otherwise the dataclass has no fields
    # and ModelOutput cannot be instantiated with keyword arguments
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    cache_params: Optional[Any] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None


class MambaTokenClassifierOutput(MambaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.backbone = MambaModel(config)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids = None,
        inputs_embeds = None,
        cache_params = None,
        labels = None,
        output_hidden_states = None,
        return_dict = None,
        **kwargs,
    ):
        outputs = self.backbone(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            cache_params=cache_params,
            labels=labels,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        ).detach()
        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return MambaSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            cache_params=outputs.cache_params,
            hidden_states=outputs.hidden_states,
        )

model = MambaTokenClassifierOutput.from_pretrained(
    model_name,
    use_cache=False,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

training_args = TrainingArguments(
    output_dir="test",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
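    # note: older transformers versions call the next argument evaluation_strategy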
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

And the error I get at the moment:

AttributeError                            Traceback (most recent call last)

<ipython-input-11-4acd14b4dcb6> in <cell line: 163>()
    161 )
    162 
--> 163 trainer.train()

9 frames

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1910                 hf_hub_utils.enable_progress_bars()
   1911         else:
-> 1912             return inner_training_loop(
   1913                 args=args,
   1914                 resume_from_checkpoint=resume_from_checkpoint,

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2343 
   2344             self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-> 2345             self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
   2346 
   2347             if DebugOption.TPU_METRICS_DEBUG in self.args.debug:

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
   2765         metrics = None
   2766         if self.control.should_evaluate:
-> 2767             metrics = self._evaluate(trial, ignore_keys_for_eval)
   2768 
   2769         if self.control.should_save:

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler)
   2728 
   2729     def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
-> 2730         metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
   2731         self._report_to_hp_search(trial, self.state.global_step, metrics)
   2732 

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   3613 
   3614         eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 3615         output = eval_loop(
   3616             eval_dataloader,
   3617             description="Evaluation",

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   3798 
   3799             # Prediction step
-> 3800             loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
   3801             main_input_name = getattr(self.model, "main_input_name", "input_ids")
   3802             inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None

/usr/local/lib/python3.10/dist-packages/transformers/trainer.py in prediction_step(self, model, inputs, prediction_loss_only, ignore_keys)
   4034             return (loss, None, None)
   4035 
-> 4036         logits = nested_detach(logits)
   4037         if len(logits) == 1:
   4038             logits = logits[0]

/usr/local/lib/python3.10/dist-packages/transformers/trainer_pt_utils.py in nested_detach(tensors)
    188     "Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."
    189     if isinstance(tensors, (list, tuple)):
--> 190         return type(tensors)(nested_detach(t) for t in tensors)
    191     elif isinstance(tensors, Mapping):
    192         return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})

/usr/local/lib/python3.10/dist-packages/transformers/trainer_pt_utils.py in <genexpr>(.0)
    188     "Detach `tensors` (even if it's a nested list/tuple/dict of tensors)."
    189     if isinstance(tensors, (list, tuple)):
--> 190         return type(tensors)(nested_detach(t) for t in tensors)
    191     elif isinstance(tensors, Mapping):
    192         return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})

/usr/local/lib/python3.10/dist-packages/transformers/trainer_pt_utils.py in nested_detach(tensors)
    191     elif isinstance(tensors, Mapping):
    192         return type(tensors)({k: nested_detach(t) for k, t in tensors.items()})
--> 193     return tensors#.detach()
    194 
    195 

AttributeError: 'MambaCache' object has no attribute 'detach'

This works for me. As far as I can tell, the AttributeError comes from the Trainer's evaluation loop: nested_detach is called on everything the model returns, and the MambaCache object in cache_params has no detach method, so the cache must not end up in the model output (loading with use_cache=False takes care of that). The .detach() call on the backbone output also has to go, since a ModelOutput has no detach method and detaching there would cut the gradients to the backbone anyway:

@dataclass
class MambaSequenceClassifierOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    cache_params: Optional[Any] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None

class MambaTokenClassifierOutput(MambaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.backbone = MambaModel(config)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids = None,
        inputs_embeds = None,
        cache_params = None,
        labels = None,
        output_hidden_states = None,
        return_dict = None,
        **kwargs,
    ):
        # the backbone itself does not take labels; the loss is computed below
        outputs = self.backbone(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            cache_params=cache_params,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_states = outputs[0]
        sequence_output = self.dropout(last_hidden_states)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        
        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return MambaSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            cache_params=outputs.cache_params,
            hidden_states=outputs.hidden_states,
        )
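
In case it is useful to someone, here is a rough sketch of running inference with the fine-tuned model (untested; it reuses the tokenizer, model and id2label from above and maps predictions back to word level the same way the labels were aligned):

text = ["EU", "rejects", "German", "call", "to", "boycott", "British", "lamb", "."]
# move the inputs to the same device as the model (the Trainer may have put it on GPU)
inputs = tokenizer(text, is_split_into_words=True, return_tensors="pt").to(model.device)

model.eval()
with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"])
predictions = outputs[0].argmax(dim=-1)[0].tolist()

# keep only the prediction of the first sub-token of each word
word_ids = inputs.word_ids(batch_index=0)
previous_word_idx = None
for word_idx, pred in zip(word_ids, predictions):
    if word_idx is not None and word_idx != previous_word_idx:
        print(text[word_idx], model.config.id2label[pred])
    previous_word_idx = word_idx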