Degraded results after loading from checkpoint

Hi everyone,
I’m training an EncoderDecoder model. I save checkpoints during training because I can use at most 48 hours of GPU time per run and need to continue training after that limit.
Whenever I load the latest checkpoint to continue training, the results degrade.
Can you help me identify what the problem is? I’m using gradient accumulation (4 steps) and gradient clipping. Could the gradient accumulation cause issues when loading the checkpoint?
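
For context, the Accelerator is created roughly like this (a simplified sketch; the exact argument values are assumptions, not copied verbatim from my script):

    from accelerate import Accelerator
    from torch.optim import AdamW

    # Simplified sketch of my setup; gradient_accumulation_steps matches the
    # `accelerator.accumulate(...)` context used in the training loop below.
    self.accelerator = Accelerator(
        gradient_accumulation_steps=self.args.accumulation_steps,  # 4 in my runs
    )
    optimizer = AdamW(self.model.parameters(), lr=self.args.lr)  # lr value assumed

The scheduler, the `prepare` call, the checkpoint loading, and the training loop look like this: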


        lr_scheduler = get_scheduler(
            self.args.scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps // self.args.accumulation_steps,
        )

        (
            self.model,
            optimizer,
            train_dataloader,
            eval_dataloader,
            test_dataloader,
            lr_scheduler,
        ) = self.accelerator.prepare(
            self.model,
            optimizer,
            train_dataloader,
            eval_dataloader,
            test_dataloader,
            lr_scheduler,
        )

        # We need to keep track of how many total steps we have iterated over
        overall_step = 0
        # We also need to keep track of the starting epoch so that files are named properly
        starting_epoch = 0

        # We need to load the checkpoint back in before training here with `load_state`
        # The total number of epochs is adjusted based on where the state is being loaded from,
        # as we assume continuation of the same training script
        if self.args.resume_from_checkpoint:
            if self.args.resume_from_checkpoint != 'recent':
                self.accelerator.print(f"Resumed from checkpoint: {self.args.resume_from_checkpoint}")
                self.accelerator.load_state(self.args.resume_from_checkpoint)
                path = os.path.basename(self.args.resume_from_checkpoint)
            else:
                # Get the most recent checkpoint
                self.accelerator.print('Getting the most recent checkpoint.')
                dirs = [os.path.join('checkpoints', f.name) for f in os.scandir('checkpoints') if f.is_dir()]
                dirs.sort(key=os.path.getctime)
                path = dirs[-1]  # Sorted by creation time above, so the most recent checkpoint is last
            # Extract `epoch_{i}` or `step_{i}`
            training_difference = os.path.splitext(path)[0]
            self.args.logger.debug(f"training_difference: {training_difference}")

            if "epoch" in training_difference:
                starting_epoch = int(training_difference.replace(os.path.join("checkpoints", "epoch_"), "")) + 1
                resume_step = None
            else:
                resume_step = int(training_difference.replace(os.path.join("checkpoints", "step_"), ""))
                self.accelerator.print(f"resume_step: {resume_step}")
                starting_epoch = resume_step // len(train_dataloader)
                self.accelerator.print(f"starting_epoch: {starting_epoch}")
                resume_step -= starting_epoch * len(train_dataloader)
                self.accelerator.print(f"resume_step within epoch: {resume_step}")
        

        progress_bar = tqdm(range(num_epochs * len(train_dataloader)))

        self.args.logger.info("Begining the training loop.")
        for epoch in range(starting_epoch, num_epochs):

            # Training
            self.model.train()

            if self.args.resume_from_checkpoint and epoch == starting_epoch and resume_step is not None:
                # We need to skip steps until we reach the resumed step
                self.accelerator.print(f"Skipping {resume_step} steps.")
                progress_bar.update(resume_step)
                active_dataloader = self.accelerator.skip_first_batches(train_dataloader, resume_step)
                overall_step += resume_step
            else:
                # After the first iteration though, we need to go back to the original dataloader
                active_dataloader = train_dataloader
            for step, batch in enumerate(active_dataloader):
                src_lan = f'{self.dataset.tokenizer.convert_ids_to_tokens([batch["lang"][0]])[0]}'
                if src_lan == self.args.langs[0]:
                    decoder_start_token_id = (
                        self.dataset.tokenizer.convert_tokens_to_ids(
                            f"<{self.args.langs[1].upper()}>"
                        )
                    )
                else:
                    decoder_start_token_id = (
                        self.dataset.tokenizer.convert_tokens_to_ids(
                            f"<{self.args.langs[0].upper()}>"
                        )
                    )
                decoder_input_ids, decoder_attention_mask = shift_tokens_right(
                    input_ids=batch["labels"].to(torch.int64),
                    pad_token_id=self.dataset.tokenizer.pad_token_id,
                    decoder_start_token_id=decoder_start_token_id,
                )
                with self.accelerator.accumulate(self.model):
                    outputs = self.model(
                        input_ids=batch["input_ids"].to(torch.int64),
                        attention_mask=batch["attention_mask"].to(torch.int64),
                        labels=batch["labels"].to(torch.int64),
                        decoder_input_ids=decoder_input_ids,
                        decoder_attention_mask=decoder_attention_mask,
                    )
                    loss = outputs.loss
                    self.accelerator.backward(loss)

                    # Gradient Clipping
                    if self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(self.model.parameters(), 1.0)

                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                if self.accelerator.is_main_process:
                    progress_bar.update(1)
                overall_step += 1

                # We save the model, optimizer, lr_scheduler, and seed states by calling `save_state`
                # These are saved to folders named `step_{overall_step}`
                # Will contain files: "pytorch_model.bin", "optimizer.bin", "scheduler.bin", and "random_states.pkl"
                # If mixed precision was used, the gradient scaler state is also saved
                if isinstance(checkpointing_steps, int):
                    output_dir = f"step_{overall_step}"
                    if overall_step % checkpointing_steps == 0:
                        output_dir = os.path.join('checkpoints', output_dir)
                        self.accelerator.save_state(output_dir)
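
For completeness, the `epoch_{i}` folders that the resume logic above looks for are saved at the end of each epoch in the same way; a simplified sketch of that part (the condition and naming are assumed to follow the Accelerate checkpointing example, not copied verbatim from my script):

    # Sketch: save an epoch-level checkpoint so that `epoch_{i}` folders exist
    # for the resume logic above (assumed, simplified)
    if checkpointing_steps == "epoch":
        output_dir = os.path.join('checkpoints', f"epoch_{epoch}")
        self.accelerator.save_state(output_dir)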