BERT for Generative Chatbot

I am creating a generative chatbot with BertGenerationEncoder and BertGenerationDecoder like this:

import torch
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel, BertTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# use BERT's [CLS] token (id 101) as the BOS token and [SEP] (id 102) as the EOS token
encoder = BertGenerationEncoder.from_pretrained("bert-large-uncased", bos_token_id=101, eos_token_id=102)
# add cross-attention layers so the decoder can attend to the encoder outputs
decoder = BertGenerationDecoder.from_pretrained("bert-large-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)
bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)
bert2bert.to(device)
# matching tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
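
The optimizer, learning-rate scheduler, and step counts referenced in the training loop below are created along these lines (a minimal sketch; AdamW with a linear warmup schedule is an assumption on my part, and the hyperparameter values are placeholders):

from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

num_epochs = 100  # placeholder
num_training_steps = num_epochs * len(query)

optimizer = AdamW(bert2bert.parameters(), lr=5e-5)  # placeholder learning rate
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)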

and I am training the model like this:

from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))

bert2bert.train()

for epoch in range(num_epochs):
  for i in range(len(query)):
    # tokenize one query/response pair, without adding special tokens
    input_ids = tokenizer(query[i], add_special_tokens=False, return_tensors="pt").input_ids
    labels = tokenizer(response[i], add_special_tokens=False, return_tensors="pt").input_ids
    # forward pass: the response ids serve as both decoder inputs and labels
    loss = bert2bert(input_ids=input_ids.to(device), decoder_input_ids=labels.to(device), labels=labels.to(device)).loss
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    progress_bar.update(1)
    progress_bar.set_postfix_str(f'Loss: {loss.item():.5f}')

Here query = ['Hi', 'How are you']
response = ['Hello', "I'm good"]

After training the model, I am getting strange generations like:

input_ids = tokenizer('Hi', add_special_tokens=False, return_tensors="pt").input_ids

#bert2bert.eval()
outputs = bert2bert.generate(input_ids.to(device))

print(tokenizer.decode(outputs[0]))

[CLS] is is is is is is is is is is is is is is is is is is is
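
For reference, I believe the generate call above amounts to something like the following with the default decoding arguments spelled out (the exact defaults may vary across transformers versions; the values here are what I think apply, not something I tuned):

outputs = bert2bert.generate(
    input_ids.to(device),
    max_length=20,               # default generation length in older transformers versions
    num_beams=1,                 # greedy decoding
    decoder_start_token_id=101,  # BERT's [CLS] as the decoder start token
    eos_token_id=102,            # BERT's [SEP] as EOS
)
print(tokenizer.decode(outputs[0]))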

Kindly suggest where the flaw is.

Thanks