I am training google/vit-base-patch16-384 for image classification. When I try to pass a custom optimizer through the optimizers argument to Trainer, I run into this error:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
For instance, this setup throws the error as soon as the trainer runs its first backward pass:
import torch
from transformers import (
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
    get_cosine_schedule_with_warmup,
)

training_args = TrainingArguments(
    output_dir="model_output",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    weight_decay=0.0,
    max_grad_norm=1.0,
    learning_rate=args.learning_rate,  # initial learning rate
    per_device_train_batch_size=args.per_device_train_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    per_device_eval_batch_size=args.per_device_eval_batch_size,
    num_train_epochs=args.num_train_epochs,
    warmup_ratio=args.warmup_ratio,
    logging_steps=args.logging_steps,
    save_total_limit=args.save_total_limit,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9)
# Scheduler horizon in optimizer steps: samples seen over all epochs divided by
# the effective batch size (per-device batch size * accumulation steps).
num_training_steps = len(data["train"]) * args.num_train_epochs // (
    args.per_device_train_batch_size * args.gradient_accumulation_steps
)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(args.warmup_ratio * num_training_steps),
    num_training_steps=num_training_steps,
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=data["train"],
    eval_dataset=data["val"],
    tokenizer=image_processor,
    optimizers=(optimizer, scheduler),
    compute_metrics=compute_metrics,
    callbacks=[
        EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0)
    ],
)
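Since the message complains about requires_grad, a quick check like the one below can rule out frozen parameters (a minimal sketch against the same model object as above; note that nothing in the code above freezes any parameters explicitly):

# Sanity check (sketch): list any parameters that would not receive gradients.
# `model` is the same ViT model handed to Trainer above.
frozen = [name for name, p in model.named_parameters() if not p.requires_grad]
print(f"{len(frozen)} frozen parameter tensors:", frozen[:5])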
If I leave out the optimizers argument to Trainer, it falls back to its default AdamW and training runs fine. Likewise, if I set optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate), it also runs fine. In fact, several optimizers work here: AdamW, Adam, Adagrad, Rprop, and RMSprop (with momentum=0). These throw the runtime error: SGD (with and without momentum), ASGD, and RMSprop with momentum != 0. That's all I have tested so far. Any ideas why some optimizers throw the error while others don't? I am trying to reproduce the fine-tuning approach from the ViT paper, where they use SGD.
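For reference, the pass/fail pattern above can be reproduced with a loop along these lines. This is only a sketch: build_trainer is a hypothetical helper that constructs the Trainer exactly as in the snippet above, just with the given optimizer/scheduler pair.

def try_optimizer(name, build):
    # Build the same Trainer as above with each candidate optimizer and report
    # whether .train() hits the RuntimeError at the first backward pass.
    optimizer = build(model.parameters())
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(args.warmup_ratio * num_training_steps),
        num_training_steps=num_training_steps,
    )
    trainer = build_trainer(optimizer, scheduler)  # hypothetical helper, same Trainer(...) call as above
    try:
        trainer.train()
        print(name, "OK")
    except RuntimeError as e:
        print(name, "FAILED:", e)

candidates = {
    "AdamW": lambda p: torch.optim.AdamW(p, lr=args.learning_rate),
    "Adam": lambda p: torch.optim.Adam(p, lr=args.learning_rate),
    "Adagrad": lambda p: torch.optim.Adagrad(p, lr=args.learning_rate),
    "Rprop": lambda p: torch.optim.Rprop(p, lr=args.learning_rate),
    "RMSprop, momentum=0": lambda p: torch.optim.RMSprop(p, lr=args.learning_rate),
    "SGD, momentum=0.9": lambda p: torch.optim.SGD(p, lr=args.learning_rate, momentum=0.9),
    "SGD, momentum=0": lambda p: torch.optim.SGD(p, lr=args.learning_rate),
    "ASGD": lambda p: torch.optim.ASGD(p, lr=args.learning_rate),
    "RMSprop, momentum=0.9": lambda p: torch.optim.RMSprop(p, lr=args.learning_rate, momentum=0.9),
}
for name, build in candidates.items():
    try_optimizer(name, build)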