import os

# Restrict the run to a single GPU. Set this before importing torch so the
# CUDA runtime only ever sees that device.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
prompt = "i see a big"

# Target and draft checkpoints for assisted generation. Use valid Hub repo IDs
# you have access to (e.g. with the meta-llama/ org prefix); note that Llama 3.2
# ships 1B/3B text models, while the 8B/70B Instruct variants come from Llama 3.1.
checkpoint = "Llama-3.2-70B-Instruct"
assistant_checkpoint = "Llama-3.2-7B-Instruct"
# Tokenize the prompt and move it onto the GPU.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
# Both models are placed with device_map="auto"; with only one visible GPU,
# everything lands on that single device.
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint, device_map="auto")
# Assisted (speculative) decoding: the small model drafts tokens that the
# large model verifies in parallel.
outputs = model.generate(**inputs, assistant_model=assistant_model)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
This may cause out-of-memory (OOM) errors: CUDA_VISIBLE_DEVICES pins the run to a single GPU, so device_map="auto" has to fit both models on that one device. The 70B target alone needs roughly 140 GB of weights in bf16, and about double that in the float32 default, far beyond a single GPU's memory.
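
One mitigation, as a minimal sketch: quantize the 70B target model to 4-bit and keep the small draft model in bf16. This assumes the bitsandbytes package is installed; the exact settings below are illustrative, not the only workable ones.

from transformers import BitsAndBytesConfig

# 4-bit quantization for the large target model (assumes bitsandbytes is
# installed); cuts its weight footprint to roughly 35 GB instead of ~140 GB in bf16.
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    quantization_config=bnb_config,
    device_map="auto",
)

# The draft model is small relative to the target; plain bf16 is usually enough.
assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

Alternatively, expose more GPUs (drop or widen CUDA_VISIBLE_DEVICES) so that device_map="auto" can shard the target model across several devices.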