Hi everyone,
I’m training an EncoderDecoder model and checkpointing during training, because I can only use a GPU for 48 hours at a time and have to resume training after that limit.
Whenever I load the latest checkpoint to continue training, I see a degradation in the results.
Can you help me identify what the problem is? I’m using gradient accumulation (4 steps) together with gradient clipping. Could the gradient accumulation cause issues when resuming from a checkpoint?
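For context, the relevant arguments are wired up roughly like this. The attribute names match what the training code below uses, but the values are only illustrative; the accumulation step of 4 is the real one, everything else (scheduler type, languages, resume mode, checkpointing interval) is just an example:

```python
from argparse import Namespace

# Illustrative values; attribute names match the training code below.
args = Namespace(
    scheduler_type="linear",           # learning-rate schedule passed to get_scheduler (example)
    accumulation_steps=4,              # gradient accumulation steps
    resume_from_checkpoint="recent",   # or an explicit path such as "checkpoints/step_5000" (example)
    langs=["en", "de"],                # source/target language tags (example)
)
checkpointing_steps = 1000             # save a checkpoint every N batches (example)
```

Here is the relevant part of the training loop:

```python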
lr_scheduler = get_scheduler(
    self.args.scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps // self.args.accumulation_steps,
)
(
    self.model,
    optimizer,
    train_dataloader,
    eval_dataloader,
    test_dataloader,
    lr_scheduler,
) = self.accelerator.prepare(
    self.model,
    optimizer,
    train_dataloader,
    eval_dataloader,
    test_dataloader,
    lr_scheduler,
)
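
# The model, optimizer, and lr_scheduler prepared above are registered with Accelerate,
# so `save_state` / `load_state` below checkpoint and restore them along with the RNG states.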
# We need to keep track of how many total steps we have iterated over
overall_step = 0
# We also need to keep track of the starting epoch so that files are named properly
starting_epoch = 0
# We need to load the checkpoint back in before training here with `load_state`
# The total number of epochs is adjusted based on where the state is being loaded from,
# as we assume continuation of the same training script
if self.args.resume_from_checkpoint:
    if self.args.resume_from_checkpoint != 'recent':
        self.accelerator.print(f"Resumed from checkpoint: {self.args.resume_from_checkpoint}")
        self.accelerator.load_state(self.args.resume_from_checkpoint)
        path = os.path.basename(self.args.resume_from_checkpoint)
    else:
        # Get the most recent checkpoint
        self.accelerator.print('Getting the most recent checkpoint.')
        dirs = [os.path.join('checkpoints', f.name) for f in os.scandir('checkpoints') if f.is_dir()]
        dirs.sort(key=os.path.getctime)
        path = os.path.basename(dirs[-1])  # Sorted by creation time, so the most recent checkpoint is last
    # Extract `epoch_{i}` or `step_{i}`
    training_difference = os.path.splitext(path)[0]
    self.args.logger.debug(f"training_difference: {training_difference}")
    if "epoch" in training_difference:
        starting_epoch = int(training_difference.replace("epoch_", "")) + 1
        resume_step = None
    else:
        resume_step = int(training_difference.replace("step_", ""))
        self.accelerator.print(f"resume_step: {resume_step}")
        starting_epoch = resume_step // len(train_dataloader)
        self.accelerator.print(f"starting_epoch: {starting_epoch}")
        resume_step -= starting_epoch * len(train_dataloader)
        self.accelerator.print(f"resume_step within epoch: {resume_step}")
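
# At this point `starting_epoch` tells us which epoch to resume in, and `resume_step`
# (when resuming from a `step_{i}` checkpoint) is the number of batches already consumed
# inside that epoch.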
progress_bar = tqdm(range(num_epochs * len(train_dataloader)))
self.args.logger.info("Beginning the training loop.")
for epoch in range(starting_epoch, num_epochs):
    # Training
    self.model.train()
    if self.args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
        # We need to skip steps until we reach the resumed step
        self.accelerator.print(f"Skipping {resume_step} steps.")
        progress_bar.update(resume_step)
        active_dataloader = self.accelerator.skip_first_batches(train_dataloader, resume_step)
        overall_step += resume_step
    else:
        # After the first iteration though, we need to go back to the original dataloader
        active_dataloader = train_dataloader
    for step, batch in enumerate(active_dataloader):
        # Pick the decoder start token of the other language (i.e. the translation direction)
        src_lang = self.dataset.tokenizer.convert_ids_to_tokens([batch["lang"][0]])[0]
        if src_lang == self.args.langs[0]:
            target_lang = self.args.langs[1]
        else:
            target_lang = self.args.langs[0]
        decoder_start_token_id = self.dataset.tokenizer.convert_tokens_to_ids(f"<{target_lang.upper()}>")
        decoder_input_ids, decoder_attention_mask = shift_tokens_right(
            input_ids=batch["labels"].to(torch.int64),
            pad_token_id=self.dataset.tokenizer.pad_token_id,
            decoder_start_token_id=decoder_start_token_id,
        )
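        # (`shift_tokens_right` is a small helper in my own code, not the transformers
        #  function of the same name; a rough sketch of it is included after this code block.)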
        with self.accelerator.accumulate(self.model):
            outputs = self.model(
                input_ids=batch["input_ids"].to(torch.int64),
                attention_mask=batch["attention_mask"].to(torch.int64),
                labels=batch["labels"].to(torch.int64),
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
            )
            loss = outputs.loss
            self.accelerator.backward(loss)
            # Gradient clipping: `sync_gradients` is only True on the iteration where the
            # accumulated gradients are actually applied, so we clip right before the real update
            if self.accelerator.sync_gradients:
                self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
        if self.accelerator.is_main_process:
            progress_bar.update(1)
        overall_step += 1
        # We save the model, optimizer, lr_scheduler, and seed states by calling `save_state`
        # These are saved to folders named `step_{overall_step}`
        # Will contain files: "pytorch_model.bin", "optimizer.bin", "scheduler.bin", and "random_states.pkl"
        # If mixed precision was used, will also save a "scaler.pt" file
        if isinstance(checkpointing_steps, int):
            output_dir = f"step_{overall_step}"
            if overall_step % checkpointing_steps == 0:
                output_dir = os.path.join('checkpoints', output_dir)
                self.accelerator.save_state(output_dir)
```
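
For completeness, `shift_tokens_right` above is a small helper in my own code, not the `transformers` function of the same name (which only returns the shifted ids). Roughly, it does the following; treat this as a sketch of the idea rather than the exact implementation:

```python
import torch

def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    # Shift the label ids one position to the right and prepend the decoder start token
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # Labels mark ignored positions with -100; the decoder inputs need real pad tokens there
    shifted.masked_fill_(shifted == -100, pad_token_id)
    # Attend to every position that is not padding
    attention_mask = (shifted != pad_token_id).long()
    return shifted, attention_mask
```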