BERT for Generative Chatbot

I'm creating a generative chatbot with BertGenerationEncoder and BertGenerationDecoder like this:

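import torch
from transformers import BertGenerationDecoder, BertGenerationEncoder, BertTokenizer, EncoderDecoderModel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = BertGenerationEncoder.from_pretrained("bert-large-uncased", bos_token_id=101, eos_token_id=102)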
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
decoder = BertGenerationDecoder.from_pretrained("bert-large-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)
bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)
bert2bert.to(device)
# create tokenizer...
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")

and training the model like this:

from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))

bert2bert.train()

for epoch in range(num_epochs):
  for i in range(len(query)):
    input_ids = tokenizer(query[i], add_special_tokens=False, return_tensors="pt").input_ids
    labels = tokenizer(response[i], add_special_tokens = False, return_tensors="pt").input_ids
    loss = bert2bert(input_ids=input_ids.to(device), decoder_input_ids=labels.to(device), labels=labels.to(device)).loss
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    progress_bar.update(1)
    progress_bar.set_postfix_str(f'Loss: {loss.item():.5f}')

Here, the training data is:
query = ["Hi", "How are you"]
response = ["Hello", "I'm good"]
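One detail of the loop above that may be worth checking, hedged because the behaviour depends on the transformers version: in recent versions, EncoderDecoderModel computes the loss itself and, when only labels are passed, builds decoder_input_ids by shifting the labels one position to the right and putting the decoder start token first (via its shift_tokens_right helper). Passing the same tensor as both decoder_input_ids and labels skips that shift, so at every position the decoder is fed the very token it is trained to predict and can learn to copy instead of continuing. Separately, because the responses are tokenized with add_special_tokens=False, the targets contain no [SEP] (the EOS token configured above), so the model never gets a signal for when to stop. The pure-Python sketch below shows the alignment on a plain list; shift_right and alignment are hypothetical names, and the word pieces are illustrative.

```python
from typing import List, Tuple


def shift_right(labels: List[str], start_token: str) -> List[str]:
    # Decoder input at position t is the target from position t - 1;
    # position 0 receives the decoder start token instead.
    return [start_token] + labels[:-1]


def alignment(decoder_inputs: List[str], labels: List[str]) -> List[Tuple[str, str]]:
    # Pairs of (token the decoder is fed at position t, token it must predict at t).
    return list(zip(decoder_inputs, labels))


# Illustrative word pieces for the target "I'm good", ending in BERT's [SEP] as EOS.
labels = ["i", "'", "m", "good", "[SEP]"]

shifted = alignment(shift_right(labels, "[CLS]"), labels)  # only labels passed
unshifted = alignment(labels, labels)  # decoder_input_ids=labels and labels=labels

for seen, target in shifted:
    print(f"fed {seen!r:8} -> predict {target!r}")

# Without the shift, every position is fed exactly the token it must predict.
print(all(seen == target for seen, target in unshifted))
```

Under that reading, the training call would pass only labels (bert2bert(input_ids=..., labels=...)), with the labels tokenized so that they end in [SEP], and with bert2bert.config.decoder_start_token_id and bert2bert.config.pad_token_id set (for example to tokenizer.cls_token_id and tokenizer.pad_token_id), since the shift needs both values. The last shifted pair, ("good", "[SEP]"), is what teaches the model to stop.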

After training, I get strange generations like this:

input_ids = tokenizer('Hi', add_special_tokens=False, return_tensors="pt").input_ids

#bert2bert.eval()
outputs = bert2bert.generate(input_ids.to(device))

print(tokenizer.decode(outputs[0]))

[CLS] is is is is is is is is is is is is is is is is is is is

Kindly suggest where the flaw is.

Thanks

I'm not particularly familiar with BERT-based generation models, but in general, .generate(input_ids, do_sample=True) is an easy way to diversify the model's output. Bear in mind that the generated text can still lack coherence, but it will be less likely to be so repetitive. You can also experiment with the repetition_penalty argument of the generate method, as explained here.
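To illustrate what those two generate options change, here is a pure-Python sketch of a single decoding step. The penalty rule follows the one applied by transformers' RepetitionPenaltyLogitsProcessor (the score of a token that already appears in the output is divided by the penalty if positive and multiplied by it if negative); the function names, toy vocabulary and scores are made up for illustration, not taken from the model above.

```python
import math
import random
from typing import Dict, List


def apply_repetition_penalty(logits: Dict[str, float], generated: List[str], penalty: float) -> Dict[str, float]:
    # Lower the score of every distinct token that already appears in the output.
    out = dict(logits)
    for tok in set(generated):
        if tok in out:
            score = out[tok]
            out[tok] = score / penalty if score > 0 else score * penalty
    return out


def greedy(logits: Dict[str, float]) -> str:
    # Default decoding (do_sample=False, one beam): always take the top-scoring token.
    return max(logits, key=logits.get)


def sample(logits: Dict[str, float], rng: random.Random) -> str:
    # do_sample=True: draw a token from the softmax distribution over the scores.
    top = max(logits.values())
    tokens = list(logits)
    weights = [math.exp(logits[t] - top) for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]


# Toy next-token scores in which "is" dominates, as in the degenerate output above.
logits = {"is": 3.0, "hello": 2.5, "there": 1.0, "[SEP]": 0.5}
generated = ["[CLS]", "is"]

print(greedy(logits))  # -> is (the same token wins again, so greedy decoding loops)
penalized = apply_repetition_penalty(logits, generated, penalty=1.3)
print(greedy(penalized))  # -> hello (3.0 / 1.3 is now below 2.5)
print(sample(penalized, random.Random(0)))
```

In the question's setup, the corresponding call would be something like bert2bert.generate(input_ids.to(device), do_sample=True, repetition_penalty=1.3, max_length=20); the values are arbitrary starting points, and a repetition_penalty of 1.0 means no penalty.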