I am able to run inference with starcoder on multiple GPUs when using V100s or A100s. To my surprise, when I fired up a g5.12xlarge
instance on AWS, the model gave garbled output on its 4x A10G.
At first I thought it was an SDP kernel problem, because the mem_efficient
kernel does not work on V100s and the flash_sdp
kernel is fragile. I disabled all backends except math
and the output was still garbled. I did, however, get it to work in two cases (sketched below):
- use a single GPU for inference
- use device_map = 'sequential'
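For reference, here is a minimal sketch of the two working configurations (same checkpoint as the full repro below; device_map={"": 0} is accelerate's way of pinning the entire model to one device):

import torch
from transformers import AutoModelForCausalLM

checkpoint = "bigcode/gpt_bigcode-santacoder"

# Case 1: single GPU. Pin the whole model to device 0
# (or simply launch the script with CUDA_VISIBLE_DEVICES=0).
model = AutoModelForCausalLM.from_pretrained(checkpoint,
                                             device_map={"": 0},
                                             torch_dtype=torch.float16)

# Case 2: 'sequential' fills GPU 0 before spilling onto GPU 1, etc.
# model = AutoModelForCausalLM.from_pretrained(checkpoint,
#                                              device_map='sequential',
#                                              torch_dtype=torch.float16)

My guess is that 'sequential' places this small checkpoint entirely on GPU 0, which would make it equivalent to the single-GPU case here; I have not verified that.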
Using device_map = 'auto' or 'balanced' does not work. I would appreciate any help. I am using the latest drivers and packages on Ubuntu 22.04:
accelerate==0.33.0
torch==2.4.0
transformers==4.41.2
To reproduce this, you can execute this code on a g5.12xlarge instance:
import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

LOG = logging.getLogger(__name__)
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)

# Disable all backends except math to prove SDP is not the problem.
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

device_map = 'balanced'    # not working
# device_map = 'sequential'  # works

# This also happens with bigcode/starcoder.
checkpoint = "bigcode/gpt_bigcode-santacoder"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint,
                                             device_map=device_map,
                                             torch_dtype=torch.float16)

inputs = tokenizer.encode("def print_hello_world():",
                          return_tensors="pt").to("cuda")
outputs = model.generate(inputs,
                         max_new_tokens=16,
                         pad_token_id=tokenizer.eos_token_id)

LOG.info("output: %s", outputs[0])
decoded_output = tokenizer.decode(outputs[0])
LOG.info("decoded output: %s", decoded_output)
With device_map = 'balanced' we get this:
INFO: output: tensor([ 563, 942, 62, 7196, 62, 3881, 1241, 1241, 26430, 46339,
48209, 39986, 35903, 207, 207, 207, 22006, 22006, 32117, 32117,
47725, 6693, 38324], device='cuda:0')
INFO: decoded output: def print_hello_world():():seudactionTypesCountEqualfortundeepStrictEqual hmctshmctsglyphiconglyphiconApplicationTestsuplicate,"",
With device_map = 'sequential' we get this:
INFO: output: tensor([ 563, 942, 62, 7196, 62, 3881, 1241, 258, 942, 372, 7371, 9956,
8657, 479, 185, 563, 942, 62, 7196, 62, 3881, 62, 1379],
device='cuda:0')
INFO: decoded output: def print_hello_world():
print("Hello World!")
def print_hello_world_with
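If it helps anyone diagnose this, the layer placement for each setting can be inspected via model.hf_device_map, which accelerate populates whenever a device_map is passed to from_pretrained. For example, appended to the repro above:

# Dump how accelerate sharded the model across the GPUs.
# Keys are module names (e.g. 'transformer.h.0'), values are device ids.
for name, device in sorted(model.hf_device_map.items()):
    LOG.info("%s -> %s", name, device)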
Output of nvidia-smi:
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.06              Driver Version: 555.42.06      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A10G                    Off |   00000000:00:1B.0 Off |                    0 |
|  0%   30C    P8             15W /  300W |       1MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A10G                    Off |   00000000:00:1C.0 Off |                    0 |
|  0%   30C    P8             16W /  300W |       1MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA A10G                    Off |   00000000:00:1D.0 Off |                    0 |
|  0%   29C    P8             15W /  300W |       1MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA A10G                    Off |   00000000:00:1E.0 Off |                    0 |
|  0%   29C    P8             15W /  300W |       1MiB /  23028MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
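One more thing I still want to rule out, purely a guess on my part and not a confirmed cause: whether peer-to-peer access between the A10Gs is working, since 'balanced' has to move activations between devices. A quick check:

import torch

# Report whether each pair of GPUs claims peer-to-peer access.
# Garbled multi-GPU output can be related to broken P2P, but I
# have not confirmed that this is what is happening here.
n = torch.cuda.device_count()
for i in range(n):
    for j in range(n):
        if i != j:
            print(f"GPU {i} -> GPU {j}: "
                  f"{torch.cuda.can_device_access_peer(i, j)}")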