from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="my_awesome_food_model",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
)
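The config above combines per_device_train_batch_size=16 with gradient_accumulation_steps=4, so each optimizer update sees 16 × 4 = 64 examples per device. The sketch below is a hypothetical helper (training_schedule is not a transformers function) that estimates the effective batch size, the number of optimizer updates, and the warmup length implied by warmup_ratio. It assumes a single device, no dataloader_drop_last, and a made-up dataset of 5,000 examples (the thread gives no size), and it only approximates how Trainer counts steps, which may differ between versions.

```python
import math


def training_schedule(num_examples, per_device_batch_size, grad_accum_steps,
                      num_epochs, warmup_ratio, num_devices=1):
    """Approximate step counts for a Trainer-style run (hypothetical helper)."""
    # Examples consumed per optimizer update, summed over all devices.
    effective_batch = per_device_batch_size * grad_accum_steps * num_devices
    # Batches per epoch; the last partial batch is kept (no drop_last assumed).
    batches_per_epoch = math.ceil(num_examples / (per_device_batch_size * num_devices))
    # One optimizer update every `grad_accum_steps` batches, at least one per epoch.
    updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    total_updates = updates_per_epoch * num_epochs
    # Warmup length derived from warmup_ratio, rounded up.
    warmup_steps = math.ceil(total_updates * warmup_ratio)
    return {
        "effective_batch": effective_batch,
        "updates_per_epoch": updates_per_epoch,
        "total_updates": total_updates,
        "warmup_steps": warmup_steps,
    }


if __name__ == "__main__":
    # 5,000 examples is an assumed dataset size, used only for illustration.
    print(training_schedule(5000, per_device_batch_size=16, grad_accum_steps=4,
                            num_epochs=3, warmup_ratio=0.1))
```

With those hypothetical numbers there are 313 batches per epoch, 78 optimizer updates per epoch, 234 updates in total, and 24 warmup steps.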
Could you share the error message?
ImportError                               Traceback (most recent call last)
in <cell line: 3>()
      1 from transformers import TrainingArguments
      2
----> 3 training_args = TrainingArguments(
      4     output_dir="./cifar",
      5     per_device_train_batch_size=16,

4 frames
/usr/local/lib/python3.10/dist-packages/transformers/training_args.py in _setup_devices(self)
   1785         if not is_sagemaker_mp_enabled():
   1786             if not is_accelerate_available(min_version="0.20.1"):
-> 1787                 raise ImportError(
   1788                     "Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: Please run `pip install transformers[torch]` or `pip install accelerate -U`"
   1789                 )

ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: Please run `pip install transformers[torch]` or `pip install accelerate -U`
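The traceback shows that Trainer refuses to start unless accelerate is at least 0.20.1. The sketch below is a rough stand-in for that check, useful for confirming what is actually installed in the current environment. It is not transformers' own is_accelerate_available; meets_min_version and installed_version are hypothetical helpers that compare only the numeric release parts of a version string, so pre-release tags such as "rc1" are ignored.

```python
import re
from importlib import metadata


def _numeric_parts(version):
    """Turn '0.20.1' into (0, 20, 1); stops at the first non-numeric piece."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_min_version(installed, minimum):
    """True if `installed` >= `minimum`, comparing numeric release parts only.

    Pre-release suffixes are dropped, so '0.20.1rc1' is treated like '0.20.1'.
    """
    a, b = _numeric_parts(installed), _numeric_parts(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so "0.20" compares equal to "0.20.0".
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def installed_version(dist_name):
    """Version string of an installed distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    found = installed_version("accelerate")
    if found is None:
        print("accelerate is not installed in this environment")
    elif meets_min_version(found, "0.20.1"):
        print(f"accelerate {found} satisfies >=0.20.1")
    else:
        print(f"accelerate {found} is older than 0.20.1")
```

If this reports a new enough version but Trainer still raises the error, the running notebook process is probably not seeing the upgraded package.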
Even after trying the commands suggested below, I still get the error:
ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: Please run `pip install transformers[torch]` or `pip install accelerate -U`
What runtime are you using?
You can check by going to:
Runtime > Change runtime type
I got the same error while fine-tuning GPT-2 on a QnA dataset.
# Train
train(
    train_file_path=train_file_path,
    model_name=model_name,
    output_dir=output_dir,
    overwrite_output_dir=overwrite_output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    num_train_epochs=num_train_epochs,
    save_steps=save_steps,
)
When I run this part, it gives the same error. I ran the suggested commands to install accelerate>=0.20.1 (and the others), but I still get the same error.
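One possible explanation for "I installed it but still get the error" (an assumption, not confirmed anywhere in this thread): pip upgrades the package on disk, but modules already imported into the running notebook kernel keep their old state, and transformers appears to record accelerate's availability when it is first imported. Restarting the runtime after running pip install accelerate -U would then be needed. The hypothetical helpers below (loaded_vs_installed and needs_restart are not library functions) compare the version loaded in the current process with the version installed on disk; they assume the import name matches the distribution name, which holds for accelerate, transformers, and torch.

```python
import sys
from importlib import metadata


def loaded_vs_installed(names):
    """Map each name to (version loaded in this process, version installed on disk)."""
    report = {}
    for name in names:
        module = sys.modules.get(name)
        # Version of the copy already imported into this interpreter, if any.
        loaded = getattr(module, "__version__", None) if module is not None else None
        try:
            # Version recorded in the installed distribution's metadata.
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        report[name] = (loaded, installed)
    return report


def needs_restart(report):
    """Names whose already-imported version differs from the installed one."""
    return [
        name
        for name, (loaded, installed) in report.items()
        if loaded is not None and installed is not None and loaded != installed
    ]


if __name__ == "__main__":
    # Run in the same notebook session, after the pip install cell.
    report = loaded_vs_installed(["accelerate", "transformers", "torch"])
    for name, (loaded, installed) in report.items():
        print(f"{name}: loaded={loaded} installed={installed}")
    stale = needs_restart(report)
    if stale:
        print("Stale imports, restart the runtime:", ", ".join(stale))
```

A package that was never imported shows loaded=None and is not flagged, so if accelerate is installed at >=0.20.1 and the error persists anyway, restarting the runtime is still worth trying.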