T5-Small model does not learn when trained with FSDP

Hi everyone! I hope you are all well. I am having quite the issue training models with Fully Sharded Data Parallelism (FSDP). I first trained a T5-Small model for summarization on a single A100 GPU with complete success, without any advanced parallelism. Once I had that proof of concept, I moved on to training the exact same model with FSDP on 2 A100 GPUs; this has not gone smoothly. Despite reading the FSDP documentation, keeping my proof-of-concept code and my FSDP implementation as similar as possible, and studying the working example in the HuggingFace Example Zoo, I am at a loss (pun intended). Here is where I am going to get very specific:

  • I DID get the model training via FSDP
  • The model DID NOT learn during its training (loss stayed pretty constant)
  • I DID test both the proof of concept and the FSDP implementation to confirm the FSDP model has not learned
  • Specifically, I generated predictions for the same input prompt at different epochs, in both the proof-of-concept and FSDP implementations, and I also computed ROUGE scores to confirm the FSDP model’s refusal to learn.
  • Caveat: I did not get the accelerator state to save for resuming sharded training (see the sketch right after this list for what I was attempting), but the model itself saved successfully using unwrapped_model.save_pretrained(checkpoint_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))
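
For reference, here is a minimal sketch of the sharded-state checkpointing I was attempting (illustrative only: the directory name is a placeholder, and it assumes an accelerator created with an FSDP plugin as in the script below; passing the state-dict type as a string mirrors what the script already does for FULL_STATE_DICT):

# Switch the FSDP plugin to sharded state dicts so each rank saves its own shard
accelerator.state.fsdp_plugin.set_state_dict_type("SHARDED_STATE_DICT")

# Save everything needed to resume: model shards, optimizer, LR scheduler, RNG states
accelerator.save_state("accel_state_ckpt_step_100")  # placeholder directory

# Later, after rebuilding the model/optimizer/dataloaders and calling accelerator.prepare(...):
accelerator.load_state("accel_state_ckpt_step_100")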

Any insight or help would be appreciated; I have no idea what I’ve done wrong and just want to get my model learning so I can move on to bigger and better models. I have attached the FSDP training script below. Please let me know if there is anything else I can provide to help you help me! Thank you in advance.

FSDP Implementation Training Script:

import argparse
import gc
import json
import math
import os
import threading

import evaluate
import nltk
import numpy as np
import nvidia_smi
import psutil
import torch
from datasets import load_dataset
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    get_scheduler,
    set_seed,
)
from accelerate import Accelerator, FullyShardedDataParallelPlugin
from accelerate.utils import is_npu_available, is_xpu_available

# Converting Bytes to Megabytes
def b2mb(x):
    return int(x / 2**20)

# This context manager is used to track the peak memory usage of the process
class TorchTracemalloc:
    def __enter__(self):
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
            self.begin = torch.cuda.memory_allocated()
        elif is_xpu_available():
            torch.xpu.empty_cache()
            torch.xpu.reset_max_memory_allocated()  # reset the peak gauge to zero
            self.begin = torch.xpu.memory_allocated()
        elif is_npu_available():
            torch.npu.empty_cache()
            torch.npu.reset_max_memory_allocated()  # reset the peak gauge to zero
            self.begin = torch.npu.memory_allocated()
        self.process = psutil.Process()

        self.cpu_begin = self.cpu_mem_used()
        self.peak_monitoring = True
        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        peak_monitor_thread.daemon = True
        peak_monitor_thread.start()
        return self

    def cpu_mem_used(self):
        """get resident set size memory for the current process"""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        self.cpu_peak = -1

        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)

            # can't sleep or will not catch the peak right (this comment is here on purpose)
            # time.sleep(0.001) # 1msec

            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        self.peak_monitoring = False

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            self.end = torch.cuda.memory_allocated()
            self.peak = torch.cuda.max_memory_allocated()
        elif is_xpu_available():
            torch.xpu.empty_cache()
            self.end = torch.xpu.memory_allocated()
            self.peak = torch.xpu.max_memory_allocated()
        elif is_npu_available():
            torch.npu.empty_cache()
            self.end = torch.npu.memory_allocated()
            self.peak = torch.npu.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)

        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)
        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")


# For testing only
if os.environ.get("TESTING_MOCKED_DATALOADERS", None) == "1":
    from accelerate.test_utils.training import mocked_dataloaders

    get_dataloaders = mocked_dataloaders  # noqa: F811

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]
    # ROUGE expects a newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels


def training_function(config, args):
    # For GPU temp monitoring
    nvidia_smi.nvmlInit()
    handle0 = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
    handle1 = nvidia_smi.nvmlDeviceGetHandleByIndex(1)

    dateAndTime = os.environ['CURRENT_DATE_TIME']
    nltk.download("punkt")

    # For testing only
    if os.environ.get("TESTING_MOCKED_DATALOADERS", None) == "1":
        config["num_epochs"] = 2
    # Output directory
    output_dir = args.outputDir

    # New Code #
    # Pass the advanced FSDP settings not part of the accelerate config by creating fsdp_plugin
    fsdp_plugin = FullyShardedDataParallelPlugin(
        state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=True), # can set these to true to provide more GPU memory at the cost of computation time
        optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=False, rank0_only=True), # can set these to true to provide more GPU memory at the cost of computation time
    )

    # Initialize accelerator
    if args.with_tracking:
        accelerator = Accelerator(
            cpu=args.cpu,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            mixed_precision=args.mixed_precision,
            log_with="tensorboard",
            project_dir=args.logging_dir,
            fsdp_plugin=fsdp_plugin,
        )
    else:
        accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, fsdp_plugin=fsdp_plugin)
    accelerator.print(accelerator.distributed_type)

    @accelerator.on_main_process
    def training_log(epoch, num_epoch, i_iter, epoch_iters, optimizer, loss):
        msg = '\nEpoch: [{}/{}] Iter:[{}/{}], lr: {}, Loss: {:.6f}'.format(
            epoch, num_epoch, i_iter, epoch_iters,
            [x['lr'] for x in optimizer.param_groups], loss)
        print(msg)

    @accelerator.on_main_process
    def print_rouge(epoch, result):
        print(f"Epoch {epoch}:", result)

    # `checkpointing_steps` can be either the string "epoch" or an integer number of steps
    if args.checkpointing_steps is not None:
        if args.checkpointing_steps == "epoch":
            checkpointing_steps = args.checkpointing_steps
        elif args.checkpointing_steps.isdigit():
            checkpointing_steps = int(args.checkpointing_steps)
        else:
            raise ValueError(
                f"Argument `checkpointing_steps` must be either a number or `epoch`. `{args.checkpointing_steps}` passed."
            )
    else:
        checkpointing_steps = None
    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
    lr = config["lr"]
    num_epochs = int(args.train_epochs)
    seed = int(config["seed"])
    batch_size = int(config["batch_size"])

    # We need to initialize the trackers we use, and also store our configuration
    if args.with_tracking:
        experiment_config = vars(args)
        accelerator.init_trackers("fsdp_pubMed_no_trainer", experiment_config)

    tokenizer = AutoTokenizer.from_pretrained(args.modelAndTokenizerName)
    raw_train_dataset = load_dataset("ccdv/pubmed-summarization", "document", split="train")
    raw_val_dataset = load_dataset("ccdv/pubmed-summarization", "document", split="validation")
    column_names = raw_train_dataset.column_names
    metric = evaluate.load("rouge")

    # Define the tokenizer pre-processing function for the dataset
    max_input_length = args.max_input_length  # maximum number of input (article) tokens kept per example
    max_target_length = args.max_target_length  # maximum number of target (abstract) tokens kept per example
    padding = "max_length"
    def tokenize_function(examples):

        model_inputs = tokenizer(examples["article"], max_length=max_input_length, padding=padding, truncation=True)
        labels = tokenizer(text_target=examples["abstract"], max_length=max_target_length, padding=padding, truncation=True)

        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
        if padding == "max_length" and args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    # Apply the method we just defined to all the examples in all the splits of the dataset
    # starting with the main process first:
    accelerator.wait_for_everyone()
    with accelerator.main_process_first():
        train_dataset = raw_train_dataset.map(
            tokenize_function, 
            batched=True,
            num_proc=6,
            remove_columns=column_names,
            load_from_cache_file=True,
            desc="Running tokenizer on raw train split"
            )
        val_dataset = raw_val_dataset.map(
            tokenize_function,
            batched=True,
            num_proc=6,
            remove_columns=column_names,
            load_from_cache_file=True,
            desc="Running tokenizer on raw val split"
            )
        train_dataset.set_format("torch")
        val_dataset.set_format("torch")

    # If the batch size is too big we use gradient accumulation
    gradient_accumulation_steps = args.gradient_accumulation_steps
    train_batch_size = args.train_batch_size
    val_batch_size = args.val_batch_size

    set_seed(seed)

    # Instantiate the model (we build the model here so that the seed also control new weights initialization)
    autoConfig = AutoConfig.from_pretrained(args.modelAndTokenizerName)
    model = AutoModelForSeq2SeqLM.from_pretrained(args.modelAndTokenizerName, config=autoConfig)
    #model.gradient_checkpointing_enable() # reduces memory usage during training

    label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer, 
        model=model,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=8 if accelerator.use_fp16 else None,
    )

    # Instantiate dataloaders.
    train_dataloader = DataLoader(
        train_dataset,
        pin_memory=True, # for speeding up training set = true, enables faster transfers between CPU and GPU memory
        shuffle=True,
        collate_fn=data_collator,
        batch_size=train_batch_size,
        num_workers=4 # for speeding up training, spawns several workers to preload the data faster. If GPU utilization is far from 100%, increase number of workers.
    )

    val_dataloader = DataLoader(
        val_dataset,
        pin_memory=True,
        collate_fn=data_collator,
        batch_size=val_batch_size,
        num_workers=4
    )

    # Exclude biases and LayerNorm weights from weight decay
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Training loop updates
    num_train_epochs = args.train_epochs
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    max_training_steps = num_train_epochs * num_update_steps_per_epoch

    # Instantiate scheduler
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=max_training_steps,
    )

    # Validation loop updates
    num_val_epochs = args.train_epochs
    num_update_steps_per_epoch_val = math.ceil(len(val_dataloader) / args.gradient_accumulation_steps)
    max_validation_steps = num_val_epochs * num_update_steps_per_epoch_val

    # Prepare accelerator
    model, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, val_dataloader, lr_scheduler
    )

    # Recalculate training loop updates because it changes after prepare method sometimes
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    max_training_steps = num_train_epochs * num_update_steps_per_epoch
    num_train_epochs = math.ceil(max_training_steps / num_update_steps_per_epoch)
    progress_bar = tqdm(range(max_training_steps), disable=not accelerator.is_local_main_process)

    # Recalculate validation loop updates because it changes after prepare method
    num_update_steps_per_epoch_val = math.ceil(len(val_dataloader) / args.gradient_accumulation_steps)
    max_validation_steps = num_val_epochs * num_update_steps_per_epoch_val
    num_val_epochs = math.ceil(max_validation_steps / num_update_steps_per_epoch_val)
    val_progress_bar = tqdm(range(max_validation_steps), disable=not accelerator.is_local_main_process)

    completed_steps = 0
    val_completed_steps = 0
    progress_bar.update(completed_steps)
    val_progress_bar.update(val_completed_steps)


    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint is not None or args.resume_from_checkpoint != "":
            accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
            accelerator.load_state(args.resume_from_checkpoint)
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = [f.name for f in os.scandir(os.getcwd()) if f.is_dir()]
            dirs.sort(key=os.path.getctime)
            path = dirs[-1]  # Sorts folders by date modified, most recent checkpoint is the last
        # Extract `epoch_{i}` or `step_{i}`
        training_difference = os.path.splitext(path)[0]

        if "epoch" in training_difference:
            num_epochs -= int(training_difference.replace("epoch_", ""))
            resume_step = None
        else:
            resume_step = int(training_difference.replace("step_", ""))
            num_epochs -= resume_step // len(train_dataloader)
            # If resuming by step, we also need to know exactly how far into the DataLoader we went
            resume_step = (num_epochs * len(train_dataloader)) - resume_step

    # Now we train the model
    for epoch in range(num_train_epochs):
        completed_steps=0
        progress_bar.update(completed_steps)
        # New Code #
        # context manager to track the peak memory usage during the training epoch
        with TorchTracemalloc() as tracemalloc:
            model.train()
            if args.with_tracking:
                total_loss = 0
            for step, batch in enumerate(train_dataloader):
                temp0 = nvidia_smi.nvmlDeviceGetTemperature(handle0, nvidia_smi.NVML_TEMPERATURE_GPU)
                temp1 = nvidia_smi.nvmlDeviceGetTemperature(handle1, nvidia_smi.NVML_TEMPERATURE_GPU)
                if temp0 >= 75 or temp1 >= 75: 
                    print(f"\nGPU 0 Temperature: {temp0}C")
                    print(f"\nGPU 1 Temperature: {temp1}C")
                    raise RuntimeError("GPU Temperature is too high!")
                # We need to skip steps until we reach the resumed step
                if args.resume_from_checkpoint and epoch == 0:
                    if resume_step is not None and step < resume_step:
                        continue  # actually skip already-seen batches (a bare `pass` would fall through and train on them)
                with accelerator.accumulate(model):
                    # We could avoid this line since we set the accelerator with `device_placement=True`.
                    batch.to(accelerator.device)
                    outputs = model(**batch)
                    loss = outputs.loss
                    # We keep track of the loss at each epoch
                    if args.with_tracking:
                        total_loss += loss.detach().float()
                    accelerator.backward(loss)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()
                    
                # Save the accelerator state on specific iterations (needed when StateDictType is SHARDED_STATE_DICT)
                if isinstance(checkpointing_steps, int) and completed_steps % checkpointing_steps == 0:
                    if output_dir is not None:
                        ckpt_step_dir = f"{dateAndTime}/accel_state_ckpt_step_{completed_steps}"
                        accel_state_dir = os.path.join(output_dir, ckpt_step_dir)
                        accelerator.save_state(accel_state_dir)
                
                # Check if the accelerator has performed an optimization step behind the scenes
                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    completed_steps += 1

                accelerator.wait_for_everyone()
                training_log(epoch, num_train_epochs, completed_steps, max_training_steps, optimizer, loss)
                accelerator.wait_for_everyone()

                if completed_steps >= 50:  # max_training_steps: capped at 50 to speed up my evaluation
                    break
                

        # New Code #
        # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
        accelerator.print(f"Memory before entering the train : {b2mb(tracemalloc.begin)}")
        accelerator.print(f"Memory consumed at the end of the train (end-begin): {tracemalloc.used}")
        accelerator.print(f"Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}")
        accelerator.print(
            f"Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}"
        )
        # Logging the peak memory usage of the GPU to the tracker
        if args.with_tracking:
            accelerator.log(
                {
                    "train_total_peak_memory": tracemalloc.peaked + b2mb(tracemalloc.begin),
                },
                step=epoch,
            )

        val_completed_steps=0
        val_progress_bar.update(val_completed_steps)
        # New Code #
        # context manager to track the peak memory usage during the evaluation
        with TorchTracemalloc() as tracemalloc:
            model.eval()
            for step, batch in enumerate(val_dataloader):
                with accelerator.accumulate(model):
                    batch.to(accelerator.device)
                    with torch.no_grad():
                        outputs = model(**batch)
                        loss = outputs.loss
                        predictions = outputs.logits.argmax(dim=-1)
                        accelerator.wait_for_everyone()
                        predictions, targets = accelerator.gather_for_metrics((predictions, batch["labels"]))

                        # Send to cpu for conversion to numpy
                        predictions = predictions.cpu().numpy()
                        targets = targets.cpu().numpy()
                        # Replace -100 in the references (targets) since we can't decode them
                        targets = np.where(targets != -100, targets, tokenizer.pad_token_id)
                        if isinstance(predictions, tuple):
                            predictions = predictions[0]
                        decoded_preds = tokenizer.batch_decode(
                            predictions, skip_special_tokens=True
                        )
                        decoded_targets = tokenizer.batch_decode(targets, skip_special_tokens=True)

                        decoded_preds, decoded_targets = postprocess_text(
                            decoded_preds, decoded_targets
                        )

                        metric.add_batch(predictions=decoded_preds, references=decoded_targets)
                if accelerator.sync_gradients:
                    val_progress_bar.update(1)
                    val_completed_steps += 1 
                training_log(epoch, num_val_epochs, val_completed_steps, max_validation_steps, optimizer, loss)
                # Specify how many iterations contribute towards the validation ROUGE metric
                if val_completed_steps >= 25:  # max_validation_steps: capped at 25 to speed up my evaluation
                    break

            # Compute metrics
            result = metric.compute(use_stemmer=True)

            # Scale the aggregated ROUGE scores to percentages
            result = {k: round(v * 100, 4) for k, v in result.items()}
            accelerator.wait_for_everyone()
            print_rouge(epoch, result)
            accelerator.wait_for_everyone()

            if args.with_tracking:
                result["train_loss"] = total_loss.item() / len(train_dataloader)
                result["epoch"] = epoch
                result["step"] = completed_steps
                accelerator.log(result, step=completed_steps)    
                
        # New Code #
        # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
        accelerator.print(f"Memory before entering the eval : {b2mb(tracemalloc.begin)}")
        accelerator.print(f"Memory consumed at the end of the eval (end-begin): {tracemalloc.used}")
        accelerator.print(f"Peak Memory consumed during the eval (max-begin): {tracemalloc.peaked}")
        accelerator.print(
            f"Total Peak Memory consumed during the eval (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}"
        )
        # Logging the peak memory usage of the GPU to the tracker
        if args.with_tracking:
            accelerator.log(
                {
                    "eval_total_peak_memory": tracemalloc.peaked + b2mb(tracemalloc.begin),
                },
                step=epoch,
            )
        # Save and upload
        if epoch < num_train_epochs:
            accelerator.wait_for_everyone()
            if accelerator.is_local_main_process:
                print(f"Saving checkpoint for epoch {epoch+1}")
            checkpoint_dir = f"{output_dir}/{dateAndTime}/checkpoints/epoch-{epoch+1}"
            os.makedirs(checkpoint_dir, exist_ok=True)
            accelerator.wait_for_everyone() 
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(checkpoint_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))
            # Also save to default directory
            unwrapped_model.save_pretrained(output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))
            if accelerator.is_main_process:
                tokenizer.save_pretrained(checkpoint_dir)
                tokenizer.save_pretrained(output_dir)

    if output_dir is not None:
        accelerator.wait_for_everyone() 
        if accelerator.state.fsdp_plugin is not None:
            accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(checkpoint_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))
        # Also save to default directory
        unwrapped_model.save_pretrained(output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))
        if accelerator.is_main_process:
            tokenizer.save_pretrained(checkpoint_dir)
            tokenizer.save_pretrained(output_dir)

            all_results = {f"eval_{k}": v for k, v in result.items()}
            with open(os.path.join(output_dir, "all_results.json"), "w") as f:
                json.dump(all_results, f)

    nvidia_smi.nvmlShutdown()
    

    if args.with_tracking:
        accelerator.end_training()
    myDir = f"{output_dir}/{dateAndTime}/checkpoints/"
    return myDir


def main():
    parser = argparse.ArgumentParser(description="Simple example of training script.")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16", "fp8"],
        help="Whether to use mixed precision. Choose"
        "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
        "and an Nvidia Ampere GPU.",
    )
    parser.add_argument("--cpu", action="store_true", help="If passed, will train on the CPU.")
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default=None,
        help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a checkpoint folder.",
    )
    parser.add_argument(
        "--with_tracking",
        action="store_true",
        help="Whether to load in all available experiment trackers from the environment and use them for logging.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help="Location on where to store experiment tracking logs`",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        required=False,
    )
    parser.add_argument(
        "--train_batch_size", 
        type=int, 
        required=True,
    )
    parser.add_argument(
        "--val_batch_size", 
        type=int, 
        required=True
    )
    parser.add_argument(
        "--learning_rate", 
        type=float, 
        required=True
    )
    parser.add_argument(
        "--weight_decay", 
        type=float, 
        required=True
    )
    parser.add_argument(
        "--train_epochs",
        type=int,
        required=True
    )
    parser.add_argument(
        "--outputDir",
        type=str,
        default=".",
        help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.",
    )
    parser.add_argument(
        "--repoName", 
        type=str, 
        required=True
    )
    parser.add_argument(
        "--finetune", 
        type=str, 
        required=True
    )
    parser.add_argument(
        "--existsBranch",
        type=str, 
        required=False
    )
    parser.add_argument(
        "--newBranch", 
        type=str, 
        required=False
    )
    parser.add_argument(
        "--modelAndTokenizerName",
        type=str,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
        required=True,
    )
    parser.add_argument(
        "--ignore_pad_token_for_loss",
        type=bool,
        default=True,
        help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.",
    )
    parser.add_argument(
        "--max_input_length",
        type=int,
        help="Maximum input sequence length for inference",
        required=True,
    )
    parser.add_argument(
        "--max_target_length",
        type=int,
        help="Maximum output target sequence length for prediction",
        required=True,
    )

    args = parser.parse_args()
    config = {"lr": args.learning_rate, "num_epochs": args.train_epochs, "seed": 42, "batch_size": args.train_batch_size}
    myDir = training_function(config, args)

    # Evaluate checkpoints by generating a summary for the same prompt with each saved epoch
    input_prompt = "Insert a really long prompt for summarization"
    for folder in os.listdir(myDir):
        ckpt = os.path.join(myDir, folder)
        tokenizer = AutoTokenizer.from_pretrained(ckpt)
        model = AutoModelForSeq2SeqLM.from_pretrained(ckpt)
        input_ids = tokenizer.encode(input_prompt, return_tensors="pt")

        output = model.generate(
            input_ids,
            max_length=512,          # generate at most 512 tokens
            min_length=250,          # ensure the summary is at least 250 tokens
            length_penalty=1.0,      # neutral length penalty
            no_repeat_ngram_size=2,  # prevent repeating any 2-gram
            num_beams=2,
            early_stopping=True,     # stop when the first beam hypothesis reaches EOS
        )
        print(tokenizer.decode(output[0], skip_special_tokens=True))

    print("Training Finished")



if __name__ == "__main__":
    main()

Update: I was able to get the model to learn by setting use_orig_params=True in the FSDP config; however, when moving to a larger model (such as LongT5), the loss stays constant before eventually becoming NaN. I will keep investigating how sharding affects LongT5 training and report back when I have a solution.
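
For anyone hitting the same issue, here is roughly how the flag can be passed (a minimal sketch, assuming a recent accelerate version whose FullyShardedDataParallelPlugin exposes use_orig_params; the equivalent key in the accelerate config YAML should be fsdp_use_orig_params: true):

# Sketch: enable use_orig_params so the optimizer works on the original (unflattened) parameters
fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=True),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=False, rank0_only=True),
    use_orig_params=True,
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)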