Impossible to use flan-t5-xxl in a batch-transform job

Hello,
I have been trying to use flan-t5-xxl for batch transform. Since the recent updates that fixed int8 inference, the model unfortunately no longer fits on a 16GB GPU. I have been able to run it on a SageMaker real-time inference endpoint using int8 quantization on an ml.g5.xlarge instance, but it seems like that instance type is not yet available for batch transform jobs.

Ultimately, I have been loading the model with the following code:
AutoModelForSeq2SeqLM.from_pretrained(model_path, device_map='auto', load_in_8bit=True)

I haven’t been able to load the model for inference on a different instance type such as ml.g4dn.12xlarge (multi-GPU), even when I manually specify the device_map or the max_memory for each GPU. I always get CUDA OOM errors, as the endpoint seems to attempt to load 4 copies of the model at the same time.

Would anyone have a trick by any chance?

Hello @imiraoui,

Could you maybe share your inference.py completely? Have you tried it on an EC2 instance rather than SageMaker? Also, can you share how you deployed your endpoint? A potential issue could be that your container starts 1 HTTP worker per GPU, which would then try to load the model 4x.

Thanks @philschmid! I deployed the endpoint with the HF GPU inference container (upgrading transformers to 4.26 and accelerate + bitsandbytes to the latest versions when building the Docker image) on an ml.g4dn.12xlarge.

The inference.py file looks like this:

import json
import logging

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

logger = logging.getLogger(__name__)

def model_fn(model_dir):
    logger.info(model_dir)
    model_path = f'{model_dir}/'
    logger.info(model_path)
    # Cap the memory budget per GPU so accelerate shards the weights across all of them
    max_memory = '12GB'
    n_gpus = torch.cuda.device_count()
    max_memory = {i: max_memory for i in range(n_gpus)}
    logger.info(f"n_gpus: {n_gpus}")
    logger.info(str(max_memory))
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_path, device_map='auto', load_in_8bit=True, max_memory=max_memory
    )
    global tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    logger.info('model and tokenizer loaded')
    return model

def input_fn(json_request_data, content_type='application/json'):
    # Expect a JSON body of the form {"input_text": ...}
    input_data = json.loads(json_request_data)
    input_text = input_data['input_text']
    return input_text

def predict_fn(inputs, model):
    logger.info(model.hf_device_map)
    input_ids = tokenizer(
        inputs, max_length=724, padding='max_length', truncation=True, return_tensors='pt'
    ).input_ids.to('cuda')
    outputs = model.generate(input_ids, output_scores=True, return_dict_in_generate=True)
    predictions = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
    # Per-step token probabilities; the product of the max probability at each step
    # is used as a rough confidence score for the generated sequence
    probs = torch.stack(outputs.scores, dim=1).softmax(-1)
    max_probs = torch.max(probs, dim=-1)
    probas = max_probs.values.prod(-1)
    probas = probas.cpu()

    return [{"confidence_score": proba.item(), "predicted_text": predicted_text} for proba, predicted_text in zip(probas, predictions)]
    
def output_fn(output, accept='application/json'):
    return json.dumps(output), accept
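
For reference, this handler expects the request body to be a JSON object with an input_text field. A minimal invocation sketch (the HuggingFacePredictor usage and the endpoint name are just my own example, not part of the deployment above):

from sagemaker.huggingface import HuggingFacePredictor

# Hypothetical endpoint name; use the one returned by your deploy() call
predictor = HuggingFacePredictor(endpoint_name="flan-t5-xxl-int8")

# input_fn expects {"input_text": ...}; a list of strings is decoded batch-wise in predict_fn
response = predictor.predict({"input_text": ["Translate to German: How are you?"]})
print(response)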

A potential issue could be that your container starts 1 HTTP worker per GPU, which would then try to load the model 4x.

I believe that’s right! Is there a way to prevent this behavior? I can see in the CloudWatch logs that it loads 2-3 out of 14 shards 4 times before crashing…

For real-time endpoints you can limit the number of workers to 1 by setting model_server_workers=1, e.g.

huggingface_model = HuggingFaceModel(
   model_data=huggingface_estimator.model_data,
   # model_data="s3://sagemaker-us-east-1-558105141721/huggingface-donut-2023-05-18-15-15-20-2023-05-18-15-15-20-285/output/model.tar.gz"
   role=role, 
   transformers_version="4.26", 
   pytorch_version="1.13", 
   py_version="py39",
   model_server_workers=1
)
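
For the original batch-transform use case, the same model object can be reused; as far as I know, model_server_workers just sets the SAGEMAKER_MODEL_SERVER_WORKERS environment variable on the container, so it should apply to the batch-transform container as well. A rough sketch (instance type, S3 paths and split settings below are placeholders, not tested values):

# Reuse the model above for a batch transform job (S3 paths are placeholders)
batch_job = huggingface_model.transformer(
    instance_count=1,
    instance_type="ml.g4dn.12xlarge",
    strategy="SingleRecord",
    output_path="s3://my-bucket/flan-t5-xxl-batch-output",
)

batch_job.transform(
    data="s3://my-bucket/flan-t5-xxl-batch-input",  # JSON Lines matching what input_fn expects
    content_type="application/json",
    split_type="Line",
)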