main.py:
import os
import sys
import json

from transformers import TimesformerForVideoClassification

from train_model import train_model
# load_model, save_finetuned_model, train_dataset, and val_dataset are defined
# in other project modules (not shown here).
def configure_parameters(model):
for param in model.parameters():
param.requires_grad = False
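# Unfreeze only the last six encoder blocks (indices 6-11), plus the final layer norm and the classifier head.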
for i in [6, 7, 8, 9, 10, 11]:
for param in model.timesformer.encoder.layer[i].parameters():
param.requires_grad = True
model.timesformer.layernorm.weight.requires_grad = True
model.timesformer.layernorm.bias.requires_grad = True
model.classifier.weight.requires_grad = True
model.classifier.bias.requires_grad = True
return model
base_model = None
resume_path = None
num_epochs = 5
warmup_epochs = 1
resume_choice = input("Do you want to resume training from a checkpoint? (y/n, default: n): ").strip().lower()
if resume_choice in ["y", "yes"]:
target_epoch = int(input("Enter the epoch number to resume from (e.g., 3): ").strip())
checkpoint_folders = sorted([f for f in os.listdir("./ckpt") if f.startswith("checkpoint-")],
key=lambda x: int(x.split("-")[1]))
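# Checkpoint folders are named checkpoint-<global_step>, so the requested epoch is matched via trainer_state.json rather than the folder name.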
for folder in checkpoint_folders:
trainer_state_path = os.path.join("./ckpt", folder, "trainer_state.json")
if os.path.exists(trainer_state_path):
with open(trainer_state_path, 'r') as f:
trainer_state = json.load(f)
current_epoch = int(trainer_state.get("epoch", 0))
if current_epoch == target_epoch:
resume_path = os.path.join("./ckpt", folder)
print(f"Resuming from epoch {target_epoch} at {resume_path}")
break
if resume_path:
base_model = TimesformerForVideoClassification.from_pretrained(resume_path)
# config = TimesformerConfig.from_pretrained(resume_path)
# base_model = TimesformerForVideoClassification(config)
base_model = configure_parameters(base_model)
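# Ask for the overall target epoch count and train only the epochs that remain after the resumed one.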
while True:
try:
total_epochs = int(input(f"Enter total epochs (including current {target_epoch}): "))
if total_epochs <= target_epoch:
print(f"Total epochs must be greater than {target_epoch}.")
else:
remaining_epochs = total_epochs - target_epoch
num_epochs = remaining_epochs
break
except ValueError:
print("Invalid input. Enter a number.")
else:
print(f"No checkpoint for epoch {target_epoch}")
resume_choice = "n"
if not resume_path:
base_model = load_model("./model")
if not base_model:
print("Failed to load model")
sys.exit(1)
base_model = configure_parameters(base_model)
try:
num_epochs = int(input("Training epochs (default 5): ") or 5)
except ValueError:
print("Invalid input, using default 5")
num_epochs = 5
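# Validate the warmup setting: negative values fall back to 1, and values above num_epochs are capped at num_epochs.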
while True:
try:
warmup_input = input(f"Enter number of warmup epochs (default: 1, max: {int(num_epochs)}): ").strip()
warmup_epochs = int(warmup_input) if warmup_input else 1
if warmup_epochs < 0:
print("Warmup epochs cannot be negative. Using default 1.")
warmup_epochs = 1
break
elif warmup_epochs > num_epochs:
print(f"Warmup epochs cannot exceed total epochs. Setting to {int(num_epochs)}.")
warmup_epochs = int(num_epochs)
break
else:
break
except ValueError:
print("Invalid input. Please enter an integer.")
trained_model = train_model(
base_model,
train_dataset,
val_dataset,
num_epochs=num_epochs,
warmup_epochs=warmup_epochs,
resume_from_checkpoint=resume_path
)
save_finetuned_model(trained_model, "./weights")
print("Training complete. Model saved.")
train_model.py:
import os

import numpy as np
import torch
import torch.nn.functional as F

import evaluate
import wandb
from transformers import (
    Trainer,
    TrainerCallback,
    TrainingArguments,
    default_data_collator,
)
os.environ["WANDB_PROJECT"] = "deepfake-detection"
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")
roc_auc_metric = evaluate.load("roc_auc")
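# Prints a banner at the start of each epoch so progress is easy to spot in the console output.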
class EpochProgressCallback(TrainerCallback):
def on_epoch_begin(self, args, state, control, **kwargs):
current = int(state.epoch) + 1
total = int(args.num_train_epochs)
print(f"\n\n>>> Starting epoch {current}/{total}")
def compute_metrics(eval_pred):
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
accuracy = accuracy_metric.compute(predictions=predictions, references=labels)
precision = precision_metric.compute(predictions=predictions, references=labels, average="binary")
recall = recall_metric.compute(predictions=predictions, references=labels, average="binary", zero_division=0)
f1 = f1_metric.compute(predictions=predictions, references=labels, average="binary")
logits_t = torch.from_numpy(logits)
probs = F.softmax(logits_t.float(), dim=1).cpu().numpy()
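# Column 1 of probs is the score of the positive ("fake") class, which roc_auc expects as prediction_scores.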
auc = roc_auc_metric.compute(prediction_scores=probs[:, 1], references=labels)
try:
wandb.log({"roc": wandb.plot.roc_curve(labels, probs, labels=["real", "fake"])})
wandb.log({"pr": wandb.plot.pr_curve(labels, probs, labels=["real", "fake"])})
except Exception as e:
print(f"Warning: Failed to log to wandb: {e}")
metrics = {
"accuracy": accuracy["accuracy"],
"precision": precision["precision"],
"recall": recall["recall"],
"f1": f1["f1"],
"auc": auc["roc_auc"]
}
return metrics
def train_model(
model,
train_dataset,
val_dataset,
num_epochs,
warmup_epochs,
resume_from_checkpoint=None
):
per_device_batch_size = 8
total_steps = num_epochs * (len(train_dataset) // per_device_batch_size)
warmup_steps = warmup_epochs * (len(train_dataset) // per_device_batch_size)
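# Steps per epoch assume a single device; warmup is converted from epochs to optimizer steps.
# total_steps is kept only for reference and is not passed to TrainingArguments.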
training_args = TrainingArguments(
output_dir="./ckpt",
overwrite_output_dir=True,
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
optim="adamw_torch",
learning_rate=1.5e-5,
weight_decay=0.01,
label_smoothing_factor=0.1,
max_grad_norm=1.0,
gradient_accumulation_steps=1,
lr_scheduler_type="cosine",
num_train_epochs=num_epochs,
warmup_steps=warmup_steps,
evaluation_strategy="epoch",
save_strategy="epoch",
logging_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="accuracy",
dataloader_num_workers=4,
dataloader_pin_memory=True,
dataloader_persistent_workers=True,
dataloader_prefetch_factor=4,
fp16=True,
disable_tqdm=False,
report_to='wandb',
run_name="TALL-TimeSformer-Tesla V100-Dropout(0.2)"
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
data_collator=default_data_collator,
compute_metrics=compute_metrics,
callbacks=[EpochProgressCallback()]
)
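# resume_from_checkpoint is either None (fresh run) or the ./ckpt/checkpoint-<global_step> folder chosen in main.py.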
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
model.save_pretrained('./weights/best_model')
return model
ckpt_dir:
These are the checkpoints for 10 epochs; the Trainer names each checkpoint folder by its global step (checkpoint-<global_step>). When training from scratch, we need to specify the model in the Trainer and pass model.safetensors along with the corresponding config.json, using:
configuration = TimesformerConfig()
model = TimesformerModel(configuration)
My question is: when resuming training from a checkpoint, do I only need to pass resume_from_checkpoint, or do I have to manually load model.safetensors from the checkpoint folder and pass it separately? Does the model specified in the Trainer’s model parameter refer to the initial model, or does the Trainer load the model from the checkpoint folder?
I’m a little confused about this. I hope I’ve explained my query well enough; if anything is unclear, please ask. I need to resume training from the checkpoint.
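To make the two options I am weighing concrete, here is a minimal sketch of how I currently understand them (checkpoint-1234 is only a placeholder folder name, and training_args, train_dataset, and val_dataset are meant to be the same objects as in train_model.py above):

from transformers import Trainer, TimesformerForVideoClassification

checkpoint_dir = "./ckpt/checkpoint-1234"  # placeholder; real folders are named checkpoint-<global_step>

# Option A: keep passing the original base model and point resume_from_checkpoint at the folder,
# expecting the Trainer to restore the weights and training state from there.
model_a = TimesformerForVideoClassification.from_pretrained("./model")
trainer_a = Trainer(model=model_a, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset)
trainer_a.train(resume_from_checkpoint=checkpoint_dir)

# Option B: load model.safetensors from the checkpoint folder manually and start a fresh run
# from those weights, without reusing the optimizer/scheduler state.
model_b = TimesformerForVideoClassification.from_pretrained(checkpoint_dir)
trainer_b = Trainer(model=model_b, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset)
trainer_b.train()

Is Option A enough on its own, or do I also need to do what Option B does?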