Huggingface Distributed Training with Accelerate

My Task is to fine-tune a Multi-class text classification model with close to 400 labels. As recommended in various NLP blogs, I decided to fine tune BertForSequenceClassification on custom dataset using Accelerate library from Huggingface. Code is written and executed on AWS SageMaker with 4 GPUs and 24 GB RAM on each GPU. Training hyper-parameters: batch size:64 which on 4 GPUs would mean a batch size of 16 per GPU; learning rate: 1e-5, optimizer: torch.optim.AdamW and linear learning rate scheduler.
I tested the code for 1 step i.e. executed the code end to end, printed training loss and validation loss for sanity check and the execution was successful with all loss/accuracy values printed.

But now when I run the code on full dataset (3500k samples), the notebook cell does not produce any error but doesnot execute also. At times, accelerate doesnot identify all. the 4 GPUs else simply quits without even executing one step of training/validation.

I have these basic issues:
a) Why does accelerate() not recognize all the GPUs ( I can see them printed using nvidia-smi)
b) notebook-launcher quits execution with out any error, are there execution logs being saved that can help identify teh problem
c) I am fine tuning BertForSequenceClassification in Multi-class setting, ideally loss should be computed using CrossEntropyLossFunction, but when I printed the loss for sanity check, it printed the value: tensor(6.06, NllLossBackward0) why doesit show NegativeLogLikelihood in the loss values.

Heres. the relevant code snippet:

def training_loop():
    model_path = "models/category_classifier_bert.pt"
    
    accelerator = Accelerator(mixed_precision="fp16")
    device = accelerator.device
    
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased",num_labels=NUM_LABELS)
    print("Loading model ")
    learning_rate = 1e-5
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    print("Loading optimizer")
    train_dl, val_dl = get_dataloaders()
    total_steps = int((len(train_dl) * EPOCHS)/4*BATCH_SIZE)
    print("total_steps : ",total_steps)
    # Set up the learning rate scheduler
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=100, # Default value
                                                num_training_steps=total_steps)
    print("Loading model on accelerators .... ")
    model, optimizer, train_dl, val_dl, scheduler = accelerator.prepare(model, optimizer, train_dl, val_dl, scheduler)
    #val_dl = accelerator.prepare(val_dl)
    
    metric = Accuracy(task="multiclass", num_classes=NUM_LABELS).to(device)
    # MulticlassAccuracy(num_classes=NUM_LABELS).to(device)
    print("Training ... ")
    for epoch in range(EPOCHS):
        model.train()
        for step,batch in enumerate(train_dl):
            input_ids = batch["input_ids"].to(device)
            targets = batch["label"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            optimizer.zero_grad()

            output = model(input_ids, attention_mask)
            loss = CrossEntropyLoss(output.logits, targets)

            accelerator.backward(loss)
            optimizer.step()
            
            if step%5000 == 0:
                print(step, loss)               
                accelerator.wait_for_everyone()
                unwrapped_model = accelerator.unwrap_model(model)
                accelerator.save(unwrapped_model.state_dict(), model_path)

        model.eval() 
       # accuracy, num_elements = 0, 0
        val_accuracy = []
        print("Model validation .... ")
        for steps,batch in enumerate(val_dl):
            preds = model(batch["input_ids"], batch["attention_mask"])
            targets = batch["label"].to(device)
            
            all_predictions = accelerator.gather(preds.logits)#.detach().cpu()
            all_targets = accelerator.gather(targets)      
            
            loss = CrossEntropyLoss(preds.logits, targets)
#             print("all_predictions : ",type(all_predictions))
#             print("all_targets : ",type(all_targets))
            step_accuracy = metric(all_predictions, all_targets)
            val_accuracy.append(step_accuracy.items())
            if steps%10000 ==0:
                print(f"Loss : {loss}")
#             accuracy_preds = all_predictions == all_targets
#             print("all_predictions : ",type(accuracy_preds))
#             print("preds ",preds.logits.shape)
#             num_elements += BATCH_SIZE
#             accuracy += accuracy_preds.sum()
#             print("accuracy : ",accuracy)
            
        # We can calculate the overall accuracy
        #eval_metric = accuracy.item() / num_elements
        
        # And we use `accelerator.print` to print on only the main process
        accelerator.print(f'epoch {epoch}: accuracy - {100 * mean(accuracy):.2f}')
        
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        torch.save(unwrapped_model.state_dict(), model_path)
 
        break

from accelerate import notebook_launcher
notebook_launcher(training_loop, num_processes=4)

Heres the output generated after executing. the above cell

Already tried Restarting the kernel and shutting down/restarting the server.

Heres another cell execution after re-start where only 1 GPU is processed adn cell execution is complete