Why does a Transformer (LLaMA 3.1-8B) give different logits at inference time for the same sample when predicting on a single GPU versus multiple GPUs?

I have the following code to do batched prediction across multiple GPUs:

import torch 
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import set_seed as tf_set_seed

import os, random, json
import numpy as np
from accelerate import Accelerator,  notebook_launcher
from accelerate.utils import gather_object
from accelerate.utils import set_seed as acc_set_seed


def seed_everything(seed=13):
    """Seed every random number generator used by this script.

    Seeds Python's ``random``, NumPy, PyTorch (CPU and all CUDA devices), and
    the seeding helpers of Accelerate and Transformers, then configures cuDNN
    for deterministic behaviour.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    # NOTE: only affects subprocesses started afterwards; the hash seed of the
    # running interpreter is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current device, so that each rank
    # of a multi-GPU run starts from the same generator state.
    torch.cuda.manual_seed_all(seed)
    acc_set_seed(seed)
    tf_set_seed(seed)
    torch.backends.cudnn.deterministic = True
    # With benchmark enabled cuDNN may autotune and pick different algorithms
    # between runs / input shapes, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False

# Seed all generators once before anything else runs.
seed_everything(13)


# (text, label) pairs that the prediction function shards across processes.
INFERENCE_DATA_TEXT = [
    ("text_1", "label_1"),
    ("text_2", "label_2"),
    ("text_3", "label_3"),
]

def multi_gpu_pred():
    """Run batched classification inference, sharded across processes.

    Each process loads its own copy of the model, receives a slice of
    ``INFERENCE_DATA_TEXT`` via ``accelerator.split_between_processes``,
    predicts class probabilities batch by batch, and the per-process results
    are gathered with ``gather_object``. The main process merges them and
    writes ``./tmp/recent_run_pred.json``.

    Relies on names defined outside this function: ``load_model``,
    ``tokenizer``, ``tqdm``, ``seed_everything`` and ``INFERENCE_DATA_TEXT``.
    """
    accelerator = Accelerator()
    accelerator.wait_for_everyone()
    seed_everything(seed=13)

    # Custom helper (defined elsewhere) that loads the LoRA checkpoint on top
    # of the base model (LLaMA 3.1-8B) onto this process's device.
    model = load_model(
        "./my_lora_checkpoint",
        device={"": accelerator.process_index},
        num_labels=2,  # number of labels in the classification problem
        merge_unload=False,
    )

    # Always bound, including the single-process case (where the only process
    # is the main one); previously the name was unset when num_processes == 1.
    # NOTE(review): ranks use different batch sizes, and the split point changes
    # with the number of processes, so the same sample is evaluated inside
    # differently shaped batches. GPU kernels may pick different reduction
    # orders for different shapes, which plausibly explains small logit
    # differences between 1/2/3/4-GPU runs -- confirm by using one fixed
    # batch size on every rank.
    batch_size = 13 if accelerator.is_main_process else 18

    with accelerator.split_between_processes(INFERENCE_DATA_TEXT) as prompts:
        # Every key written in the loop below must exist up front.
        res = {"pred_probs": [], "pred_labels": [], "auto_answer_uuid": []}

        batches = [
            prompts[start:start + batch_size]
            for start in range(0, len(prompts), batch_size)
        ]
        # A rank can receive an empty shard when there are more processes
        # than samples, so do not index batches[0] unconditionally.
        print(len(batches[0]) if batches else 0)

        for batch in tqdm(batches):
            text_batch = [item[0] for item in batch]
            score_batch = [item[1] for item in batch]

            with torch.no_grad():
                # LLaMA 3.1-8B tokenizer; every sample is padded to 1024 tokens.
                inputs = tokenizer(
                    text_batch,
                    truncation=True,
                    max_length=1024,
                    padding="max_length",
                    return_tensors="pt",
                ).to(model.device)
                logits = model(**inputs).logits.cpu().to(torch.float32)
                probs = torch.softmax(logits, dim=1).numpy()

            res["pred_probs"].extend(probs.tolist())
            res["pred_labels"].extend(probs.argmax(axis=1).tolist())
            res["auto_answer_uuid"].extend(score_batch)

        # gather_object concatenates lists, so wrap the dict to get one entry
        # per process back.
        res = [res]

    res = gather_object(res)
    if accelerator.is_main_process:
        final = {"pred_probs": [], "pred_labels": [], "auto_answer_uuid": []}
        for res_dict in res:
            for key, values in res_dict.items():
                final.setdefault(key, []).extend(values)

        # Make sure the output directory exists before writing.
        os.makedirs("./tmp", exist_ok=True)
        with open("./tmp/recent_run_pred.json", "w") as f:
            json.dump(final, f)


# Launch multi_gpu_pred through Accelerate's notebook launcher with num_processes=8.
notebook_launcher(multi_gpu_pred, num_processes=8)                            

The issue is that when I run with num_processes=1,2,3,4, I see different logit results every time. Here are the results I got for the different GPU settings, even though every seed, the model, and the data are the same.

NOTE: G_i stands for the number of GPUs used for prediction.

Start looking from the 3rd instance onwards:

G_1 = [{'pred_probs': [[[0.19436781108379364, 0.8056321740150452], [0.2782568037509918, 0.7217432260513306]], [[0.953923225402832, 0.046076808124780655], [0.6774740219116211, 0.3225259780883789]], [[0.6740504503250122, 0.325949490070343], [0.7074047327041626, 0.2925952970981598]], [[0.2628418505191803, 0.7371581792831421], [0.08632347732782364, 0.9136765599250793]], [[0.38861799240112305, 0.6113819479942322], [0.5370413661003113, 0.46295857429504395]], [[0.858244001865387, 0.14175599813461304]]], 'pred_labels': [[1, 1], [0, 0], [0, 0], [1, 1], [1, 0], [0]], 'actual_labels': [[0, 1], [0, 0], [0, 0], [1, 1], [1, 1], [0]]}]

G_2 = [{'pred_probs': [[[0.19436781108379364, 0.8056321740150452], [0.2782568037509918, 0.7217432260513306]], [[0.953923225402832, 0.046076808124780655], [0.6774740219116211, 0.3225259780883789]], [[0.6740504503250122, 0.325949490070343], [0.7074047327041626, 0.2925952970981598]]], 'pred_labels': [[1, 1], [0, 0], [0, 0]], 'actual_labels': [[0, 1], [0, 0], [0, 0]]}, {'pred_probs': [[[0.2628418505191803, 0.7371581792831421], [0.08632347732782364, 0.9136765599250793]], [[0.38861799240112305, 0.6113819479942322], [0.5370413661003113, 0.46295857429504395]], [[0.858244001865387, 0.14175599813461304]]], 'pred_labels': [[1, 1], [1, 0], [0]], 'actual_labels': [[1, 1], [1, 1], [0]]}]

G_3 = [{'pred_probs': [[[0.19436781108379364, 0.8056321740150452], [0.2782568037509918, 0.7217432260513306]], [[0.953923225402832, 0.046076808124780655], [0.6774740219116211, 0.3225259780883789]]], 'pred_labels': [[1, 1], [0, 0]], 'actual_labels': [[0, 1], [0, 0]]}, {'pred_probs': [[[0.6740504503250122, 0.325949490070343], [0.7074047327041626, 0.2925952970981598]], [[0.2628418505191803, 0.7371581792831421], [0.08632347732782364, 0.9136765599250793]]], 'pred_labels': [[0, 0], [1, 1]], 'actual_labels': [[0, 0], [1, 1]]}, {'pred_probs': [[[0.38861799240112305, 0.6113819479942322], [0.5370413661003113, 0.46295857429504395]], [[0.858244001865387, 0.14175599813461304]]], 'pred_labels': [[1, 0], [0]], 'actual_labels': [[1, 1], [0]]}]

G_4 = [{'pred_probs': [[[0.19436781108379364, 0.8056321740150452], [0.2782568037509918, 0.7217432260513306]], [[0.9539017081260681, 0.04609827324748039]]], 'pred_labels': [[1, 1], [0]], 'actual_labels': [[0, 1], [0]]}, {'pred_probs': [[[0.6774740219116211, 0.3225259780883789], [0.6740504503250122, 0.325949490070343]], [[0.7057850360870361, 0.29421496391296387]]], 'pred_labels': [[0, 0], [0]], 'actual_labels': [[0, 0], [0]]}, {'pred_probs': [[[0.2628418505191803, 0.7371581792831421], [0.08632347732782364, 0.9136765599250793]], [[0.3942009508609772, 0.6057990193367004]]], 'pred_labels': [[1, 1], [1]], 'actual_labels': [[1, 1], [1]]}, {'pred_probs': [[[0.5370413661003113, 0.46295857429504395], [0.8577680587768555, 0.14223188161849976]]], 'pred_labels': [[0, 0]], 'actual_labels': [[1, 0]]}]

What could be the issue? Am I missing something here?