Creating generative chatbot with BertGenerationEncoder and BertGenerationDecoder like:
encoder = BertGenerationEncoder.from_pretrained("bert-large-uncased", bos_token_id=101, eos_token_id=102)
# add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
decoder = BertGenerationDecoder.from_pretrained("bert-large-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102)
bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)
bert2bert.to(device)
# create tokenizer...
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
and training the model like:
progress_bar = tqdm(range(num_training_steps))
bert2bert.train()
for epch in range(num_epochs):
for i in range(len(query)):
input_ids = tokenizer(query[i], add_special_tokens=False, return_tensors="pt").input_ids
labels = tokenizer(response[i], add_special_tokens = False, return_tensors="pt").input_ids
loss = bert2bert(input_ids=input_ids.to(device), decoder_input_ids=labels.to(device), labels=labels.to(device)).loss
loss.backward()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)
progress_bar.set_postfix_str(f'Loss: {loss.item():.5f}')
Here query = [‘Hi’, ‘How are you’]
response = [‘Hello’, ‘I’m good’]
After training the model, getting strange generations like:
input_ids = tokenizer('Hi', add_special_tokens=False, return_tensors="pt").input_ids
#bert2bert.eval()
outputs = bert2bert.generate(input_ids.to(device))
print(tokenizer.decode(outputs[0]))
[CLS] is is is is is is is is is is is is is is is is is is is
Kindly suggest where is the flaw.
Thanks