Distributed Finetuning with Trainer

Hi,
I am working on a Docker instance with 4 x 40 GB A100 cards. On a single 40 GB card I am unable to fit even one sample through the model for finetuning, so I am trying to finetune with sharding to split the model layers across the cards.
I have my own script using the Trainer, shown below, and I execute it with python -m torch.distributed.launch --nproc_per_node 4.

When running the script I can see that the model is split across the 4 GPUs with the fsdp setting below. However, per_device_train_batch_size loads 4 samples (1 per card) in each step, so I get an OOM. I am wondering whether it is possible to load only one sample in total per step when using 4 GPUs, instead of one sample per card.
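If I understand the data-parallel behaviour correctly, each rank draws its own per-device batch, so the numbers in my config work out roughly like this (just a sketch of the arithmetic, with the world size hard-coded to my 4 GPUs; the variable names are only for illustration):

import os

# WORLD_SIZE is set by the launcher; default to my 4-GPU setup for the sketch.
world_size = int(os.environ.get("WORLD_SIZE", "4"))
per_device_train_batch_size = 1
gradient_accumulation_steps = 256

# Samples resident across the GPUs during a single forward pass: one per card.
samples_per_step = per_device_train_batch_size * world_size          # 4
# Effective (global) batch size per optimizer update.
effective_batch_size = samples_per_step * gradient_accumulation_steps  # 1024

print(samples_per_step, effective_batch_size)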

Let me know if it is better to post this in the PyTorch forums.

Thanks!

P.S. I initially asked this question in a model's Community tab here, but I guess it is better suited to the forums.

import os
import platform

from transformers import TrainingArguments

# `args` holds my command-line arguments (output/log dirs, epochs, learning rate).
training_args = TrainingArguments(
    output_dir=args.outdir,
    overwrite_output_dir=True,
    save_total_limit=1,
    do_train=True,
    do_eval=False,
    do_predict=True,
    num_train_epochs=args.epochs,               # total number of training epochs
    per_device_train_batch_size=1,              # batch size per device during training
    per_device_eval_batch_size=1,               # batch size per device for evaluation
    gradient_accumulation_steps=256,
    warmup_ratio=0.1,                           # fraction of steps used for LR warmup
    weight_decay=0.01,                          # strength of weight decay
    logging_dir=args.logdir,                    # directory for storing logs
    logging_steps=10,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    report_to="none",
    learning_rate=args.learning_rate,
    seed=99,
    local_rank=int(os.environ["LOCAL_RANK"]),   # set by torch.distributed.launch
    dataloader_num_workers=16,
    gradient_checkpointing=True,
    lr_scheduler_type="cosine",
    fsdp="shard_grad_op",
    fp16=platform.system() != "Darwin",         # fp16 everywhere except local Mac testing
)
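For reference, this is roughly how the arguments are passed to the Trainer in my script; model, train_dataset and eval_dataset below are placeholders standing in for my actual model and datasets:

from transformers import Trainer

# Sketch only: model, train_dataset and eval_dataset are placeholders for my
# actual objects; training_args is the TrainingArguments instance defined above.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()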