Chapter 3 problem

I’m getting an error saying that all my inputs are scalars. It would help to see a complete working file, because I’m unsure about the order of the steps. Here’s my messy code:

#!/usr/bin/env python3

import torch
from torch.optim import AdamW  # transformers.AdamW is deprecated in favor of the PyTorch optimizer
from transformers import AutoModelForSequenceClassification
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import logging
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding
from transformers import TrainingArguments
from transformers import Trainer

logging.set_verbosity_error()
raw_datasets = load_dataset("glue", "mrpc")
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=2,
)


def tokenize_function(example):
    return tokenizer(
        example["sentence1"],
        example["sentence2"],
        truncation=True,
    )


tokenized_datasets = raw_datasets.map(
    tokenize_function,
    num_proc=4,
    batched=True,
)

tokenized_dataset = tokenized_datasets.rename_column("label", "labels")

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


# train_dataloader = DataLoader(
#     tokenized_dataset["train"],
#     batch_size=16,
#     shuffle=True,
#     collate_fn=data_collator,
# )

training_args = TrainingArguments(
    "test-trainer",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    learning_rate=2e-5,
    weight_decay=0.01,
)

trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
)

trainer.train()

# optimizer = AdamW(model.parameters())
# loss = model(**batches).loss
# loss.backward()
# optimizer.step()

It’s hard to know what the problem is if you don’t include the full error message as well. I just tried your code and it runs fine on my side.
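
In case it helps with posting the error, here is a minimal sketch of one way to capture the full traceback as text so it can be pasted into a reply. It uses only the standard library; `run_and_capture` and `failing_step` are hypothetical names made up for illustration, and `failing_step` merely stands in for whatever call fails in your script (for example `trainer.train()`).

```python
import traceback


def run_and_capture(fn):
    """Run fn() and return the full traceback text if it raises, else None."""
    try:
        fn()
    except Exception:
        # format_exc() returns the same text Python would print to stderr
        return traceback.format_exc()
    return None


def failing_step():
    # Hypothetical stand-in for the call that fails in the real script
    raise ValueError("example failure")


report = run_and_capture(failing_step)
print(report)
```

Pasting the whole captured text, not just the last line, usually shows which call raised the error.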

I hadn’t noticed the Colab button at the top of the page. I can get the working code from there and see where mine differs. Thanks.
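
For anyone else comparing their script against a reference version, a rough sketch of doing it with the standard library's `difflib` follows. The function name `script_diff`, the file labels, and the two short strings are placeholders made up for illustration; they are not the actual course code.

```python
import difflib


def script_diff(mine, reference):
    """Return unified-diff lines between two script texts."""
    return list(
        difflib.unified_diff(
            mine.splitlines(),
            reference.splitlines(),
            fromfile="my_script.py",   # hypothetical label
            tofile="reference.py",     # hypothetical label
            lineterm="",
        )
    )


# Placeholder contents; in practice, read both files from disk instead
mine = "a = 1\nb = 2\n"
reference = "a = 1\nb = 3\n"

diff = script_diff(mine, reference)
print("\n".join(diff))
```

Lines starting with `-` appear only in the first script and lines starting with `+` only in the second, which makes small differences such as a misspelled variable name easier to spot.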