### TL;DR
When loading a pretrained base model and then loading a trained PeftModel adapter on top of `mistralai/Mistral-7B-Instruct-v0.2`, `model.generate()` behaves as expected. However, if the `PeftModel` is loaded with `is_trainable=True`, the output is garbage.
Example:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.bfloat16)
peft_dir = "./foo"
device = "cuda"
model = PeftModel.from_pretrained(model, peft_dir, is_trainable=True).to(device)
inputs = tokenizer.encode("[INST]Sing a nice song.[/INST]", return_tensors="pt").to(device)
single_output = model.generate(inputs, max_length=5500, do_sample=True)
print(tokenizer.decode(single_output[0]))
```
```console
ID' is isnstyle ;ement L, all Emb
be, ",
encoded....ified, :,get " A,mall >ves Gget,itia +x ' Revals.,! -- and,fc.' 2; A contract The{ thethe.PL I .; type [ I it are; =>;as........ Ch * <'. , yourli ' my a: all Consult:,, N;ly.olid, Al'ep|;s
aest is&,x
' and2 =id-- Itly{: Agreement''Contract ,, nt All; fontia I N'ies Cap FAs many{ the new ,lyown A friend The face, I., I a,w=the Not;:um Ald --,
&
=xt-- ;
' a and;.0foria {'ia_msI forS .,, spliting want Develop work (, - (c theie1fssart --ics This -sF:: ". _--, $,(1;,icper- " " I
1, I-
```
I expected to be able to call `generate()` (or, at a minimum, something generate-like) during a training procedure. See below.
### More Context
I've implemented (I think) a simplified greedy search as part of a training procedure by following the `_greedy_search` implementation in the `GenerationMixin` class. I'm not completely surprised that there is an issue with calling `generate` during training, since `generate` is decorated with `@torch.no_grad()`. However, I would be surprised if there were a fundamental reason a greedy search could not be performed during training. My implementation, like `generate()`, nevertheless produces garbage during training, and I need to know why. My implementation follows:
```python
def simplified_greedy_search(model,
                             tokenizer,
                             input_ids,
                             max_length,
                             debug: bool = True):
    # Initialize variables
    generated_ids = []
    logits_sequence = []
    eos_token_id = tokenizer.eos_token_id
    model_kwargs = {}
    if debug:
        # Debugging logs show that .generate produces the same type of garbage as this function.
        test_input = tokenizer.decode(input_ids[0])
        with open("simplified_greedy_search_input.log", "a") as f:
            f.write(test_input + "\n")
        # Use model.generate to generate the sequence
        test = model.generate(input_ids, max_length=max_length, do_sample=True)
        test_tokens = tokenizer.decode(test[0])
        # Append this to a log file
        with open("simplified_greedy_search.log", "a") as f:
            f.write(extract_content(test_tokens) + "\n")
    # Generate tokens until max_length or EOS token is reached
    while len(generated_ids) < max_length:
        # Prepare inputs
        model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
        # Forward pass
        outputs = model(**model_inputs, return_dict=True)
        # Get the last token logits
        next_token_logits = outputs.logits[:, -1, :]
        # Store the logits
        logits_sequence.append(next_token_logits)
        # Get the most probable token
        next_token_id = torch.argmax(next_token_logits, dim=-1)
        # Check if EOS token is generated
        if next_token_id.item() == eos_token_id:
            break
        # Add the generated token to the sequence
        generated_ids.append(next_token_id.item())
        # Update the input_ids
        input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim=-1)
        # Update the model_kwargs for the next iteration
        model_kwargs = model._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=model.config.is_encoder_decoder
        )
    # Concatenate the generated ids into a tensor
    generated_ids = torch.tensor(generated_ids).to(input_ids.device).reshape((1, -1))
    # Concatenate the logits into a tensor
    logits_sequence = torch.stack(logits_sequence, dim=1).to(input_ids.device)
    return generated_ids, logits_sequence
```
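For reference, a minimal sketch of how I call this function during training; the prompt and `max_length` here are illustrative, and the surrounding training loop is omitted:
```python
# Illustrative call (the surrounding training loop is omitted)
prompt_ids = tokenizer.encode("[INST]Sing a nice song.[/INST]", return_tensors="pt").to(device)
generated_ids, logits_sequence = simplified_greedy_search(model, tokenizer, prompt_ids, max_length=128)

print(tokenizer.decode(generated_ids[0]))  # currently garbage when the adapter is loaded with is_trainable=True
print(logits_sequence.shape)               # (1, num_generated_tokens, vocab_size)
```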
### Version info
- trl: 0.7.11
- peft: 0.9.0
- transformers: 4.38.2
- torch: 2.2.1
### Who can help?
@pacman100 @younesbelkada @sayakpaul
### Information
- [X] The official example scripts
- [X] My own modified scripts
### Tasks
- [ ] An officially supported task in the `examples` folder
- [X] My own task or dataset (give details below)
### Reproduction
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_model = "mistralai/Mistral-7B-Instruct-v0.2"
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.bfloat16)
peft_dir = "./foo"
device = "cuda"
model = PeftModel.from_pretrained(model, peft_dir, is_trainable=True).to(device)
inputs = tokenizer.encode("[INST]Sing a nice song.[/INST]", return_tensors="pt").to(device)
single_output = model.generate(inputs, max_length=5500, do_sample=True)
print(tokenizer.decode(single_output[0]))
```
### Expected behavior
I expect `generate()` to produce coherent output on a model whose adapter was loaded with `is_trainable=True`, just as it does when the adapter is loaded for inference only.
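For comparison, the same adapter checkpoint loaded without `is_trainable` (inference-only, the default) produces readable output on my setup:
```python
# Same checkpoint, loaded for inference only (is_trainable defaults to False)
model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(model, peft_dir).to(device)

single_output = model.generate(inputs, max_length=5500, do_sample=True)
print(tokenizer.decode(single_output[0]))  # coherent text, as expected
```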