KeyError: 'loss' during Fine Tuning bert-base-italian-cased for QA

I was fine-tuning bert-base-italian-cased on the SQuAD-it dataset with the following arguments:

args = TrainingArguments(
    f"test-squad_it",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    label_names = ["start_positions", "end_positions"]
)
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

Calling trainer.train() throws this error:

KeyError                                  Traceback (most recent call last)
<ipython-input-168-4e078f57a6ea> in <module>()
      1 #We can now finetune our model by just calling the train method:
----> 2 trainer.train()

3 frames
/usr/local/lib/python3.7/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, **kwargs)
   1270                         tr_loss += self.training_step(model, inputs)
   1271                 else:
-> 1272                     tr_loss += self.training_step(model, inputs)
   1273                 self.current_flos += float(self.floating_point_ops(inputs))
   1274 

/usr/local/lib/python3.7/dist-packages/transformers/trainer.py in training_step(self, model, inputs)
   1732                 loss = self.compute_loss(model, inputs)
   1733         else:
-> 1734             loss = self.compute_loss(model, inputs)
   1735 
   1736         if self.args.n_gpu > 1:

/usr/local/lib/python3.7/dist-packages/transformers/trainer.py in compute_loss(self, model, inputs, return_outputs)
   1774         else:
   1775             # We don't use .loss here since the model may return tuples instead of ModelOutput.
-> 1776             loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
   1777 
   1778         return (loss, outputs) if return_outputs else loss

/usr/local/lib/python3.7/dist-packages/transformers/file_utils.py in __getitem__(self, k)
   1736         if isinstance(k, str):
   1737             inner_dict = {k: v for (k, v) in self.items()}
-> 1738             return inner_dict[k]
   1739         else:
   1740             return self.to_tuple()[k]

KeyError: 'loss'

I’ve already looked at similar questions about this issue, but nothing seems to work and I’m really desperate. Any suggestions?

EDIT

My data are taken from SQuAD-it. I’ve created this DatasetDict:

DatasetDict({
    train: Dataset({
        features: ['answer_text', 'answer_start', 'title', 'context', 'question', 'answers', 'id'],
        num_rows: 48328
    })
    validation: Dataset({
        features: ['answer_text', 'answer_start', 'title', 'context', 'question', 'answers', 'id'],
        num_rows: 5831
    })
})

I then preprocessed the data as follows (taken from a tutorial, with minor adjustments):

def prepare_train_features(examples):
    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
    # in one example possibly giving several features when a context is long, each of those features having a
    # context that slightly overlaps the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if answers["answer_start"] == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"]
            end_char = start_char + len(answers["text"])

            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

tokenized_datasets = train_dataset.map(prepare_train_features, batched=True, remove_columns=train_dataset["train"].column_names)

I’m quite new to coding and maybe I’m relying too much on other people’s work without understanding every detail.
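
For readers following along, the labeling loop in prepare_train_features can be traced by hand on a toy input. The sketch below is not the tutorial code itself: label_span is a hypothetical helper that repeats the same steps for a single feature, and the offsets and sequence_ids are hand-written stand-ins for what a fast tokenizer would return (one token per word, question first, context second).

```python
def label_span(offsets, sequence_ids, start_char, end_char, cls_index=0, context_id=1):
    """Map a character span in the context to (start_token, end_token).

    Returns (cls_index, cls_index) when the answer is not fully inside this feature.
    """
    # Find the first and last token that belong to the context.
    token_start = 0
    while sequence_ids[token_start] != context_id:
        token_start += 1
    token_end = len(offsets) - 1
    while sequence_ids[token_end] != context_id:
        token_end -= 1

    # Answer not fully inside this feature: label it with the CLS index.
    if not (offsets[token_start][0] <= start_char and offsets[token_end][1] >= end_char):
        return cls_index, cls_index

    # Move both indices inward until they sit on the answer boundaries.
    while token_start < len(offsets) and offsets[token_start][0] <= start_char:
        token_start += 1
    while offsets[token_end][1] >= end_char:
        token_end -= 1
    return token_start - 1, token_end + 1


# Made-up feature: "[CLS] Who wrote it [SEP] Dante wrote the Divine Comedy [SEP]"
# Offsets are character positions inside each text; (0, 0) marks special tokens.
offsets = [(0, 0), (0, 3), (4, 9), (10, 12), (0, 0),
           (0, 5), (6, 11), (12, 15), (16, 22), (23, 29), (0, 0)]
sequence_ids = [None, 0, 0, 0, None, 1, 1, 1, 1, 1, None]

print(label_span(offsets, sequence_ids, 16, 29))  # "Divine Comedy" -> (8, 9)
print(label_span(offsets, sequence_ids, 0, 5))    # "Dante" -> (5, 5)
print(label_span(offsets, sequence_ids, 40, 45))  # outside the feature -> (0, 0)
```

In this toy, an answer that starts at character 0 ("Dante") is still a real answer. A check such as answers["answer_start"] == 0 would label it as unanswerable; whether that matters here depends on how the SQuAD-it columns encode missing answers, which the post does not show.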

How is your model created? How is your data processed? It’s hard to help debug the root of the error without seeing those.

The model is bert-base-italian-cased from Hugging Face.

Can you show us the line you used to create the model? Then to debug a bit, you can do:

for batch in trainer.get_train_dataloader():
    print({k: v.shape for k, v in batch.items()})

This will make sure you can assemble a batch of data, and it prints the shape of each key. Please execute it and post the result here 🙂
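
When reading that printout, the keys to look for are start_positions and end_positions, since a question-answering head only computes a loss when it receives both. The snippet below is a rough sketch of that check on the printed dictionary. missing_qa_labels is a hypothetical helper, and the shapes are made up (batch size 8 matches the arguments above, while 384 is an assumed max_length that the thread does not show).

```python
# Label keys a question-answering head needs in order to return a loss.
QA_LABEL_KEYS = ("start_positions", "end_positions")


def missing_qa_labels(batch_shapes):
    """Return the QA label keys that are absent from a {key: shape} printout."""
    return [key for key in QA_LABEL_KEYS if key not in batch_shapes]


# Illustrative printouts, not output from the run discussed in this thread.
with_labels = {
    "input_ids": (8, 384),
    "token_type_ids": (8, 384),
    "attention_mask": (8, 384),
    "start_positions": (8,),
    "end_positions": (8,),
}
without_labels = {k: v for k, v in with_labels.items() if k not in QA_LABEL_KEYS}

print(missing_qa_labels(with_labels))     # []
print(missing_qa_labels(without_labels))  # ['start_positions', 'end_positions']
```

If the label keys are missing from the batch, the model never sees them. One possible cause is that the model was not created with a question-answering head, which is why the line that creates the model matters.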

Then we can check that the batch goes through the model:

outputs = trainer.model(**batch)
print(outputs)

and see what the outputs look like. With all of that we should be able to diagnose the problem.
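
To connect this with the traceback: compute_loss indexes the model output with "loss", so an output that has only logits fails exactly as shown in the question. The sketch below imitates that lookup with plain dictionaries. extract_loss is a hypothetical stand-in that copies the expression from the traceback, and the output values are invented for illustration.

```python
def extract_loss(outputs):
    """Same lookup as the line in the traceback: dict-like outputs are indexed by "loss"."""
    return outputs["loss"] if isinstance(outputs, dict) else outputs[0]


# Invented stand-ins for what print(outputs) might show.
output_with_loss = {"loss": 2.31, "start_logits": [0.1, 0.9], "end_logits": [0.7, 0.3]}
output_without_loss = {"start_logits": [0.1, 0.9], "end_logits": [0.7, 0.3]}

print(extract_loss(output_with_loss))  # 2.31

try:
    extract_loss(output_without_loss)
except KeyError as err:
    # The same failure as in the question: the output has no "loss" entry.
    print("KeyError:", err)  # KeyError: 'loss'
```

So if print(outputs) shows start_logits and end_logits but no loss, the labels did not reach the model. If it shows no start or end logits at all, the model probably has no question-answering head.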