Compile ML model for AWS Inferentia with flexible input size

I have an ML model from Hugging Face, which I use essentially as follows:

import torch
from transformers import BloomTokenizerFast, BloomForCausalLM

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m").to(device)

# `seed` is the prompt string (defined elsewhere)
text = tokenizer.encode(seed)
inputs, past_key_values = torch.tensor([text[0]]), None

with torch.no_grad():
    while True:  # placeholder: loop until the stopping condition is met
        model_out = model(input_ids=inputs.to(device), past_key_values=past_key_values)
        ...
        # Generate new inputs and go back to the start

Now I would like to deploy this model to an Inf1 instance on AWS SageMaker, following this example:

from sagemaker.pytorch.model import PyTorchModel

pytorch_model = PyTorchModel(
    model_data=model_path,
    role=role,
    entry_point="my_entry_point_file.py",
    framework_version="1.5.1",
    py_version="py3",
)

neo_model = pytorch_model.compile(
    target_instance_family="ml_inf1",
    input_shape={"input0": [1, 3, 224, 224]},
    output_path=compiled_model_path,
    framework="pytorch",
    framework_version="1.5.1",
    role=role,
    job_name=compilation_job_name,
)

However, in my case the compilation job fails with:

UnexpectedStatusException: Error for Compilation job bloom-compiled-inf-inf1-202304-1921-4203: Failed. Reason: ClientError: CompilationError: Unable to compile model for ml_inf1:', 'No operations were successfully partitioned and compiled to neuron for this model - aborting trace!')  For further troubleshooting common failures please visit: https://docs.aws.amazon.com/sagemaker/latest/dg/neo-troubleshooting-compilation.html

I believe the main problem is the following: whereas inputs can be treated as a tensor of shape [1, 1], the variable past_key_values is much more complex. In this case, it is

  • a tuple of length 24
  • each entry is a tuple itself of length 2
  • the two entries are torch tensors of size [16, 64, 6] and [16, 6, 64]
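
To make that structure concrete, here is a minimal sketch that builds a dummy object with the same nesting and shapes (random values, not the real cache) and flattens it into a plain list of tensors. The helper name `flatten_past` is made up for illustration and is not part of transformers or SageMaker.

```python
import torch

# Dummy stand-in for past_key_values with the shapes reported above:
# 24 layers, each a (key, value) pair. Values are random placeholders.
num_layers, seq_len = 24, 6
past_key_values = tuple(
    (torch.randn(16, 64, seq_len), torch.randn(16, seq_len, 64))
    for _ in range(num_layers)
)

def flatten_past(past):
    """Hypothetical helper: nested tuple of (key, value) pairs -> flat list."""
    return [tensor for layer in past for tensor in layer]

flat = flatten_past(past_key_values)
print(len(flat))             # 48 tensors in total
print(tuple(flat[0].shape))  # (16, 64, 6)
print(tuple(flat[1].shape))  # (16, 6, 64)
```

Note that the dimension of size 6 is the number of tokens processed so far, so it grows by one with every generated token.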

My question is: what can I do to make this compile for Inf1?

I could imagine that either

  • there is a way to specify the right input_shape, something like {'var1': [1,1,28,28], 'var2': [1,1,28,28]} (I do not know how to express the more complex tuple-of-tensors structure outlined above)
  • or past_key_values can be split up so that input_shape is easy to build.
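
The second idea could look roughly like the sketch below. It is untested against Neo and rests on several assumptions: the wrapper class `FlatPastWrapper`, the helper `build_input_shape`, and the stand-in `DummyModel` are all hypothetical, and the "input0", "input1", ... naming simply extends the single-input example above in positional order. Whether the compiler accepts 49 inputs for this model is not verified here, and since compiled shapes are fixed, the growing cache length would probably have to be padded to a constant size.

```python
import torch

NUM_LAYERS = 24  # length of the past_key_values tuple described above

class DummyModel(torch.nn.Module):
    """Stand-in for the real model, used only so this sketch runs offline."""
    def forward(self, input_ids, past_key_values):
        # Report how many (key, value) pairs were received.
        return len(past_key_values)

class FlatPastWrapper(torch.nn.Module):
    """Hypothetical wrapper: takes flat tensors, regroups them into pairs."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, *flat_past):
        past = tuple(
            (flat_past[2 * i], flat_past[2 * i + 1]) for i in range(NUM_LAYERS)
        )
        return self.model(input_ids=input_ids, past_key_values=past)

def build_input_shape(input_ids, flat_past):
    """Name the inputs input0, input1, ... in positional order (assumed)."""
    tensors = [input_ids, *flat_past]
    return {f"input{i}": list(t.shape) for i, t in enumerate(tensors)}

# Example tensors with the shapes from the question (random placeholders).
input_ids = torch.zeros(1, 1, dtype=torch.long)
flat_past = []
for _ in range(NUM_LAYERS):
    flat_past += [torch.randn(16, 64, 6), torch.randn(16, 6, 64)]

wrapped = FlatPastWrapper(DummyModel())
result = wrapped(input_ids, *flat_past)
input_shape = build_input_shape(input_ids, flat_past)
print(result, len(input_shape), input_shape["input0"], input_shape["input1"])
```

The resulting `input_shape` dictionary has one entry per flat tensor, which is the form the `compile` call above expects for a single input.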

Any suggestions would be much appreciated.