### Describe the bug
Following is my code that I am trying to run, but facing a…n error (have attached the whole error below):
My code:
```
from collections import OrderedDict
import warnings
import flwr as fl
import torch
import numpy as np
import random
from torch.utils.data import DataLoader
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification
from transformers import AdamW
#from transformers import tokenized_datasets
warnings.filterwarnings("ignore", category=UserWarning)
# DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE = "cpu"
CHECKPOINT = "distilbert-base-uncased" # transformer model checkpoint
def load_data():
"""Load IMDB data (training and eval)"""
raw_datasets = load_dataset("yhavinga/imdb_dutch")
raw_datasets = raw_datasets.shuffle(seed=42)
# remove unnecessary data split
del raw_datasets["unsupervised"]
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
def tokenize_function(examples):
return tokenizer(examples["text"], truncation=True)
# random 100 samples
population = random.sample(range(len(raw_datasets["train"])), 100)
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
tokenized_datasets["train"] = tokenized_datasets["train"].select(population)
tokenized_datasets["test"] = tokenized_datasets["test"].select(population)
# tokenized_datasets = tokenized_datasets.remove_columns("text")
# tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets = tokenized_datasets.remove_columns("attention_mask")
tokenized_datasets = tokenized_datasets.remove_columns("input_ids")
tokenized_datasets = tokenized_datasets.remove_columns("label")
tokenized_datasets = tokenized_datasets.remove_columns("text_en")
# tokenized_datasets = tokenized_datasets.remove_columns(raw_datasets["train"].column_names)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
trainloader = DataLoader(
tokenized_datasets["train"],
shuffle=True,
batch_size=32,
collate_fn=data_collator,
)
testloader = DataLoader(
tokenized_datasets["test"], batch_size=32, collate_fn=data_collator
)
return trainloader, testloader
def train(net, trainloader, epochs):
optimizer = AdamW(net.parameters(), lr=5e-4)
net.train()
for _ in range(epochs):
for batch in trainloader:
batch = {k: v.to(DEVICE) for k, v in batch.items()}
outputs = net(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
def test(net, testloader):
metric = load_metric("accuracy")
loss = 0
net.eval()
for batch in testloader:
batch = {k: v.to(DEVICE) for k, v in batch.items()}
with torch.no_grad():
outputs = net(**batch)
logits = outputs.logits
loss += outputs.loss.item()
predictions = torch.argmax(logits, dim=-1)
metric.add_batch(predictions=predictions, references=batch["labels"])
loss /= len(testloader.dataset)
accuracy = metric.compute()["accuracy"]
return loss, accuracy
def main():
net = AutoModelForSequenceClassification.from_pretrained(
CHECKPOINT, num_labels=2
).to(DEVICE)
trainloader, testloader = load_data()
# Flower client
class IMDBClient(fl.client.NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]
def set_parameters(self, parameters):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)
def fit(self, parameters, config):
self.set_parameters(parameters)
print("Training Started...")
train(net, trainloader, epochs=1)
print("Training Finished.")
return self.get_parameters(config={}), len(trainloader), {}
def evaluate(self, parameters, config):
self.set_parameters(parameters)
loss, accuracy = test(net, testloader)
return float(loss), len(testloader), {"accuracy": float(accuracy)}
# Start client
fl.client.start_numpy_client(server_address="localhost:8080", client=IMDBClient())
if __name__ == "__main__":
main()
```
Error:
```
Traceback (most recent call last):
File "client_2.py", line 136, in <module>
main()
File "client_2.py", line 132, in main
fl.client.start_numpy_client(server_address="localhost:8080", client=IMDBClient())
File "/home/saurav/.local/lib/python3.8/site-packages/flwr/client/app.py", line 208, in start_numpy_client
start_client(
File "/home/saurav/.local/lib/python3.8/site-packages/flwr/client/app.py", line 142, in start_client
client_message, sleep_duration, keep_going = handle(
File "/home/saurav/.local/lib/python3.8/site-packages/flwr/client/grpc_client/message_handler.py", line 68, in handle
return _fit(client, server_msg.fit_ins), 0, True
File "/home/saurav/.local/lib/python3.8/site-packages/flwr/client/grpc_client/message_handler.py", line 157, in _fit
fit_res = client.fit(fit_ins)
File "/home/saurav/.local/lib/python3.8/site-packages/flwr/client/app.py", line 252, in _fit
results = self.numpy_client.fit(parameters, ins.config) # type: ignore
File "client_2.py", line 122, in fit
train(net, trainloader, epochs=1)
File "client_2.py", line 76, in train
for batch in trainloader:
File "/home/saurav/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 652, in __next__
data = self._next_data()
File "/home/saurav/.local/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 692, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/home/saurav/.local/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
return self.collate_fn(data)
File "/home/saurav/.local/lib/python3.8/site-packages/transformers/data/data_collator.py", line 221, in __call__
batch = self.tokenizer.pad(
File "/home/saurav/.local/lib/python3.8/site-packages/transformers/tokenization_utils_base.py", line 2713, in pad
raise ValueError(
ValueError: You should supply an encoding or a list of encodings to this method that includes input_ids, but you provided ['text']
```
### Steps to reproduce the bug
Run the above code.
### Expected behavior
Don't know, doing it for the first time.
### Environment info
- `datasets` version: 1.12.1
- Platform: Linux-5.4.0-58-generic-x86_64-with-glibc2.29
- Python version: 3.8.10
- PyArrow version: 11.0.0