### System Info
transformers.__version__ = 4.41.0.dev0
Python 3.10.12
unix
…
### Who can help?
@gante
### Information
- [X] The official example scripts
- [ ] My own modified scripts
### Tasks
- [x] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [ ] My own task or dataset (give details below)
### Reproduction
When using the code documented [here](https://huggingface.co/docs/transformers/main/en/llm_optims?static-kv=Static+Cache#static-kv-cache-and-torchcompile), we get different generations for the same prompt depending on which other prompts are in the batch.
For example, when running the code below, the first generation for the prompt `"hey"` (batched with `"yo"`) differs from the second one (batched with `"yo code math"`), even though the prompt itself is unchanged.
```python
import os
from transformers import LlamaTokenizer, LlamaForCausalLM
from transformers import StaticCache
import torch


def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values):
    logits = model(
        cur_token,
        position_ids=input_pos,
        cache_position=cache_position,
        past_key_values=past_key_values,
        return_dict=False,
        use_cache=True
    )[0]
    new_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    return new_token


def generate(prompts: list[str], model: LlamaForCausalLM, tokenizer: LlamaTokenizer, num_tokens_to_generate: int = 40) -> list[str]:
    global decode_one_tokens
    inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    batch_size, seq_length = inputs["input_ids"].shape
    with torch.no_grad():
        past_key_values = StaticCache(
            config=model.config, max_batch_size=batch_size, max_cache_len=4096, device=torch_device, dtype=model.dtype
        )
        cache_position = torch.arange(seq_length, device=torch_device)
        generated_ids = torch.zeros(
            batch_size, seq_length + num_tokens_to_generate + 1, dtype=torch.int, device=torch_device
        )
        generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int)

        # Prefill: run the padded prompts through the model once to populate the static cache
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            logits = model(
                **inputs, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
            )[0]
        next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
        generated_ids[:, seq_length] = next_token[:, 0]

        # Not using torch.compile to simplify debugging
        # decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)
        cache_position = torch.tensor([seq_length + 1], device=torch_device)
        for _ in range(1, num_tokens_to_generate):
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
                next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values)
                generated_ids[:, cache_position] = next_token.int()
            cache_position += 1
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)


if __name__ == "__main__":
    torch_device = "cuda"
    tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="right")
    model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to(torch_device)
    model.eval()

    print(generate(["hey", "yo"], model, tokenizer)[0])
    # prints: hey, i'm alex.\ni'm a 20-something year old photographer and filmmaker based in los angeles. i'm a lover of all things cre
    print('-' * 50 + '\n' * 2)

    print(generate(["hey", "yo code math"], model, tokenizer)[0])
    # prints: hey The 2018-19 season is the 10th season of the National Women's Soccer League (NWSL), the top division of women's soccer
    print('-' * 50 + '\n' * 2)
```
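To make the discrepancy easier to check programmatically, something like the following could be appended to the `__main__` block above (a minimal sketch that reuses the `generate`, `model`, and `tokenizer` already defined there; on current `main` the assertion fails):

```python
# Sketch: the same prompt should yield the same greedy continuation regardless of
# which other prompts share the batch. Assumes `generate`, `model`, and `tokenizer`
# from the script above are in scope (e.g. this is appended to its __main__ block).
hey_with_yo = generate(["hey", "yo"], model, tokenizer)[0]
hey_with_other = generate(["hey", "yo code math"], model, tokenizer)[0]
assert hey_with_yo == hey_with_other, (
    "Generation for 'hey' changed when the other prompt in the batch changed"
)
```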
### Expected behavior
The generation for a given prompt should not depend on the other prompts in the same batch.
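For reference, a rough cross-check one could run with the high-level `generate()` API (my own verification idea, not part of the documented static-cache example) is sketched below; with left padding and the attention mask produced by the tokenizer, the greedy continuation of `"hey"` is expected to be the same in both batches, up to minor floating-point effects.

```python
# Hypothetical cross-check (an assumption about how to verify, not the reported recipe):
# batched greedy decoding via model.generate() with left padding and a proper attention
# mask should give batch-independent continuations for the same prompt.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="left")
mdl = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16).to("cuda")

for batch in (["hey", "yo"], ["hey", "yo code math"]):
    enc = tok(batch, return_tensors="pt", padding=True).to("cuda")
    out = mdl.generate(**enc, max_new_tokens=40, do_sample=False)
    print(tok.decode(out[0], skip_special_tokens=True))  # the "hey" row in each batch
```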