Issues loading NLLB 54B MoE model for multi-GPU inferencing using accelerate


I am trying to run inference with the NLLB 54B MoE model (facebook/nllb-moe-54b · Hugging Face) on 4 GPUs using accelerate. I am following the same scripts provided in the BLOOM inference repository (transformers-bloom-inference/bloom-inference-scripts at main · huggingface/transformers-bloom-inference · GitHub).

Although I am able to successfully load the sharded model, I am facing a device mismatch error when trying to run the inference.

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0!

System and library versions:

  • python3: 3.8.10
  • torch: 1.13.1+cu116
  • transformers: 4.28.1
  • accelerate: 0.18.0

Here is the script that I am using for the inference:

import argparse
import gc
import math
import os
import time

import torch
import torch.distributed as dist

from accelerate import init_empty_weights, infer_auto_device_map
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer

def get_args(argv=None):
    """Parse the command-line options of the inference script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, so existing ``get_args()`` calls
            behave as before.

    Returns:
        argparse.Namespace with ``local_rank``, ``name``, ``batch_size``,
        ``benchmark``, ``greedy``, ``top_k``, ``top_p`` and ``dtype``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--local_rank", required=False, type=int, help="used by dist launchers"
    )
    parser.add_argument("--name", type=str, help="Name path", required=True)
    parser.add_argument("--batch_size", default=1, type=int, help="batch size")
    parser.add_argument(
        "--benchmark", action="store_true", help="additionally run benchmark"
    )
    # NOTE(review): --greedy / --top-k / --top-p are accepted but the visible
    # part of the script never reads them (generation uses do_sample=False).
    parser.add_argument("--greedy", action="store_true")
    parser.add_argument("--top-k", type=int, default=0)
    parser.add_argument("--top-p", type=float, default=0.0)
    # NOTE(review): the flag name, type and default of this option were
    # reconstructed from the upstream BLOOM script and the `--dtype int8`
    # usage in the launch command — confirm against the original script.
    parser.add_argument(
        "--dtype",
        type=str,
        help="float16 or int8",
        choices=["int8", "float16"],
        default="float16",
    )

    return parser.parse_args(argv)

# Wall-clock reference taken before argument parsing; the benchmark report
# measures "start to ready" from this point.
t_start = time.time()

# Number of new tokens requested from every generate() call.
num_tokens = 100

args = get_args()

# One entry per CUDA device visible to this process.
world_size = torch.cuda.device_count()

# LOCAL_RANK is read from the environment; default to rank 0 when it is unset.
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
rank = local_rank

def print_rank0(*msg):
    """Print ``msg`` only when the module-level ``rank`` is 0.

    Arguments are forwarded unchanged to ``print``; on any other rank the
    call is a no-op.
    """
    if rank != 0:
        return
    print(*msg)

print_rank0(f"Using {world_size} gpus")
# The checkpoint identifier comes from the required --name argument.
model_name = args.name
print_rank0(f"Loading model {model_name}")

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/data/jaygala/hf_cache")
config = AutoConfig.from_pretrained(model_name, cache_dir="/data/jaygala/hf_cache")

# XXX: can't automatically derive dtype via config's `from_pretrained`
dtype = torch.float16

infer_dtype = args.dtype
if infer_dtype == "int8":
    dtype = torch.int8

# Keyword arguments later forwarded to `from_pretrained`.
# NOTE(review): the contents of this dict were lost in the posted script;
# `device_map="auto"` follows the upstream BLOOM accelerate script — confirm.
kwargs = dict(
    device_map="auto",
)

def get_world_size() -> int:
    """Return the torch.distributed world size, or 1 outside a process group.

    Returns:
        ``dist.get_world_size()`` when the distributed package is available
        and a default process group has been initialized; otherwise ``1``.
    """
    # `is_available()` guards builds of torch compiled without distributed
    # support, where the process-group helpers may be missing.
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    # Fallback must be outside the `if`: otherwise the function returns None
    # and the caller's `get_world_size() > 1` comparison raises TypeError.
    return 1

# When torch.distributed reports more than one process, request the "auto"
# device map. (The upstream BLOOM script used "balanced_low_0" here because it
# allows a larger batch size with multiple GPUs; this script uses "auto".)
if get_world_size() > 1:
    kwargs["device_map"] = "auto"

if infer_dtype == "int8":
    print_rank0("Using `load_in_8bit=True` to use quanitized model")
    kwargs["load_in_8bit"] = True
else:
    # Only pass an explicit torch_dtype for the non-quantized path; in the
    # int8 path `dtype` is torch.int8, which must not be sent as torch_dtype.
    kwargs["torch_dtype"] = dtype

# NOTE(review): the arguments of this call were lost in the posted script;
# `cache_dir` mirrors the tokenizer/config calls — confirm.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name, cache_dir="/data/jaygala/hf_cache", **kwargs
)

if args.benchmark:
    # Timestamp marking "model loaded, ready to generate" for the stats report.
    t_ready = time.time()

### Generate

print_rank0(f"*** Starting to generate {num_tokens} tokens with bs={args.batch_size}")

# One sentence per list element. Every element needs a trailing comma:
# without it Python concatenates adjacent string literals into one element.
batched_input = [
    'We now have 4-month-old mice that are non-diabetic that used to be diabetic," he added.',
    "Dr. Ehud Ur, professor of medicine at Dalhousie University in Halifax, Nova Scotia and chair of the clinical and scientific division of the Canadian Diabetes Association cautioned that the research is still in its early days.",
    "Like some other experts, he is skeptical about whether diabetes can be cured, noting that these findings have no relevance to people who already have Type 1 diabetes.",
    "On Monday, Sara Danius, permanent secretary of the Nobel Committee for Literature at the Swedish Academy, publicly announced during a radio program on Sveriges Radio in Sweden the committee, unable to reach Bob Dylan directly about winning the 2016 Nobel Prize in Literature, had abandoned its efforts to reach him.",
    'Danius said, "Right now we are doing nothing. I have called and sent emails to his closest collaborator and received very friendly replies. For now, that is certainly enough."',
    "Previously, Ring's CEO, Jamie Siminoff, remarked the company started when his doorbell wasn't audible from his shop in his garage.",
]

if args.batch_size > len(batched_input):
    # dynamically extend to support larger bs by repetition
    batched_input *= math.ceil(args.batch_size / len(batched_input))

# Deterministic decoding with a fixed budget of new tokens.
generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=False)

print_rank0(f"Generate args {generate_kwargs}")
# Exactly `batch_size` sentences are sent to the model per call.
inputs = batched_input[: args.batch_size]

def generate():
    """Tokenize the module-level ``inputs``, run generation and decode.

    Returns:
        A single-use ``zip`` iterator of ``(input_text, output_text,
        new_token_count)`` tuples, one per input sentence.
    """

    input_tokens = tokenizer.batch_encode_plus(
        inputs, return_tensors="pt", padding=True
    )
    # Move every tensor in the encoding to the GPU; other entries are left
    # untouched.
    for t in input_tokens:
        if torch.is_tensor(input_tokens[t]):
            input_tokens[t] = input_tokens[t].to("cuda")

    # NOTE(review): the arguments of this call were lost in the posted script;
    # they are reconstructed from the upstream BLOOM script — confirm (an NLLB
    # run may also need a target-language argument).
    outputs = model.generate(**input_tokens, **generate_kwargs)

    input_tokens_lengths = [x.shape[0] for x in input_tokens.input_ids]
    output_tokens_lengths = [x.shape[0] for x in outputs]

    # NOTE(review): output length minus input length presumably assumes the
    # output contains the prompt; verify this count for a seq2seq model.
    total_new_tokens = [
        o - i for i, o in zip(input_tokens_lengths, output_tokens_lengths)
    ]
    outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    return zip(inputs, outputs, total_new_tokens)

print_rank0("*** Running generate")
# Time one full tokenize + generate + decode pass.
t_generate_start = time.time()
generated = generate()
t_generate_span = time.time() - t_generate_start
# NOTE(review): the loop body was lost in the posted script; this line is
# reconstructed from the upstream BLOOM script — confirm the exact format.
for i, o, _ in generated:
    print_rank0(f"{'-'*60}\nin={i}\nout={o}\n")

### Benchmark

if args.benchmark:
    # clear cache / free memory
    # NOTE(review): the cleanup and synchronize calls in this block were lost
    # in the posted script and are reconstructed from the upstream BLOOM
    # script — confirm.
    torch.cuda.empty_cache()
    gc.collect()

    print_rank0("*** Running benchmark")
    # warm up
    for i in range(1):
        _ = generate()
    # Wait for queued GPU work so the warm-up does not leak into the timing.
    torch.cuda.synchronize()

    # benchmark
    t0 = time.time()
    cycles = 5
    total_new_tokens_generated = 0
    for i in range(cycles):
        generated = generate()
        total_new_tokens_generated += sum(new_tokens for _, _, new_tokens in generated)
    torch.cuda.synchronize()
    # Seconds per generated token, averaged over all benchmark cycles.
    throughput = (time.time() - t0) / (total_new_tokens_generated)
    print_rank0(
        f"""
*** Performance stats:
Throughput per token including tokenize: {throughput*1000:.2f} msecs
Start to ready to generate: {t_ready - t_start:.3f} secs
Tokenize and generate {total_new_tokens_generated} (bs={args.batch_size}) tokens: {t_generate_span:.3f} secs
Start to finish: {t_ready - t_start + t_generate_span:.3f} secs
"""
    )


CUDA_VISIBLE_DEVICES=0,1,2,3 python3 nllb-moe-accelerate-inference.py --name facebook/nllb-moe-54b --batch_size 1 --dtype int8 --benchmark 2>&1 | tee nllb-moe-accelerate-inference_bs=1.txt