Hello everybody, I am trying to reproduce the generate function of the GenerationMixin class so that I can pass manual decoder input. I am using transformers v4.1.1. While I get nice results using the greedy_search function, I am not managing to reproduce beam_search, since my RAM overflows. I do not have memory problems when using generate. The code is below. I am not passing any special decoder input for now, only model.config.bos_token_id; my plan was to check that everything works and change it afterwards. I have tested this with both bart-base and pegasus-large, with the same results.
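For context, the plain generate call that works for me without memory issues looks roughly like this (a sketch using the same parameters as in the code below, with num_beams=4 to get beam search):

reference_outputs = model.generate(
    tokenizer.encode(text, return_tensors='pt'),
    min_length=min_length,
    max_length=max_length,
    no_repeat_ngram_size=no_repeat_ngram_size,
    num_beams=4,
    early_stopping=True,
    num_return_sequences=1)
print(tokenizer.decode(reference_outputs[0], skip_special_tokens=True))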
from transformers.generation_utils import GenerationMixin

gm = GenerationMixin

# generation parameters
min_length = config.BULLETS_MIN_LEN
max_length = config.BULLETS_MAX_LEN
num_beams = 4
early_stopping = True
no_repeat_ngram_size = 5
num_return_sequences = 1
model_kwargs = {}

# special token ids from the model config
pad_token_id = model.config.pad_token_id
eos_token_id = model.config.eos_token_id
decoder_start_token_id = model.config.decoder_start_token_id
bos_token_id = model.config.bos_token_id

# encode the text
input_ids = tokenizer.encode(text, return_tensors='pt')

# prepare attention mask, encoder outputs and decoder input ids
model_kwargs["attention_mask"] = gm._prepare_attention_mask_for_generation(
    model, input_ids, pad_token_id, eos_token_id)
if model.config.is_encoder_decoder:
    model_kwargs = gm._prepare_encoder_decoder_kwargs_for_generation(model, input_ids, model_kwargs)
    input_ids = gm._prepare_decoder_input_ids_for_generation(
        model, input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id, **model_kwargs)
model_kwargs["use_cache"] = None

# logits processors (min_length, no_repeat_ngram_size, ...)
logits_processor = gm._get_logits_processor(
    model,
    repetition_penalty=None,
    no_repeat_ngram_size=no_repeat_ngram_size,
    bad_words_ids=None,
    min_length=min_length,
    eos_token_id=None,
    prefix_allowed_tokens_fn=None,
    num_beams=num_beams,
    num_beam_groups=None,
    diversity_penalty=None)
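(Eventually the idea is to replace the decoder input_ids built above with a manual prefix; a rough sketch of what I have in mind, where the prefix string is just a placeholder:)

import torch

decoder_prefix = "some manual prefix"  # placeholder text, not my real input
prefix_ids = tokenizer(decoder_prefix, add_special_tokens=False, return_tensors='pt').input_ids
# prepend the decoder start token, like _prepare_decoder_input_ids_for_generation does
start_ids = torch.full((prefix_ids.shape[0], 1), decoder_start_token_id, dtype=torch.long)
manual_decoder_input_ids = torch.cat([start_ids, prefix_ids], dim=-1)  # would replace input_ids above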
Using greedy_search:
outputs = gm.greedy_search(
    model,
    input_ids,
    logits_processor=logits_processor,
    max_length=max_length,
    pad_token_id=pad_token_id,
    eos_token_id=eos_token_id,
    **model_kwargs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
And the result is identical to what generate produces with num_beams=1:
real_outputs = model.generate(
    tokenizer.encode(df_cc_group.iloc[0].text[0], return_tensors='pt'),
    min_length=min_length,
    max_length=max_length,
    no_repeat_ngram_size=no_repeat_ngram_size,
    num_beams=1)
print(tokenizer.decode(real_outputs[0], skip_special_tokens=True))
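(Concretely, the two decoded strings match:)

assert tokenizer.decode(outputs[0], skip_special_tokens=True) == \
    tokenizer.decode(real_outputs[0], skip_special_tokens=True)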
However, using beam_search my RAM overflows:
from transformers import BeamSearchScorer

batch_size = input_ids.shape[0]
length_penalty = model.config.length_penalty

beam_scorer = BeamSearchScorer(
    batch_size=batch_size,
    max_length=max_length,
    num_beams=num_beams,
    device=model.device,
    length_penalty=length_penalty,
    do_early_stopping=early_stopping,
    num_beam_hyps_to_keep=num_return_sequences)

# duplicate the inputs num_beams times, one copy per beam
input_ids, model_kwargs = gm._expand_inputs_for_generation(
    input_ids,
    expand_size=num_beams,
    is_encoder_decoder=model.config.is_encoder_decoder,
    **model_kwargs)

outputs = gm.beam_search(
    model,
    input_ids,
    beam_scorer,
    logits_processor=logits_processor,
    max_length=max_length,
    pad_token_id=pad_token_id,
    eos_token_id=eos_token_id,
    **model_kwargs)
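(If it ran to completion, I would decode the output the same way as in the greedy case:)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))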
Thank you in advance for the help!