So I had a look at “examples/seq2seq”, but unfortunately I couldn’t extract much from the code. I think I need to practice my code-reading skills a bit more, hehe.
I recently tried to use an EncoderDecoderModel to reconstruct the encoder inputs at the decoder outputs, but I couldn’t get it to work properly. To check that the model was training at all, I tried to overfit it on a set of 8 sentences, but it just kept repeating the same token over and over.
Here is the code:
import torch
import transformers

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Two randomly initialised BERT configs: one for the encoder, one for the decoder
encoder_config = transformers.BertConfig()
decoder_config = transformers.BertConfig()
config = transformers.EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
model = transformers.EncoderDecoderModel(config)
model = model.to(device)
model.train()

inputs = data['input_ids']
outputs = data['target']
# The target is overwritten with the inputs themselves, since the goal is to reconstruct the encoder input
outputs = inputs
inputs = inputs.to(device)
outputs = outputs.to(device)

optimizer = transformers.AdamW(model.parameters())

epochs = 50
for e in range(epochs):
    optimizer.zero_grad()
    # The same token ids are used as encoder input, decoder input, and labels
    loss, output = model(input_ids=inputs,
                         decoder_input_ids=outputs,
                         labels=outputs)[:2]
    l_numerical = loss.item()
    print(f"Epoch: {e}, Loss: {l_numerical}")
    loss.backward()
    optimizer.step()
I think it may be something related to the way I am feeding inputs to the decoder, but I am not sure.
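To make that concrete: I am not sure whether I should be shifting the decoder inputs one position to the right myself (with a start token in front), or whether the model already does this internally when labels are passed. Something like the sketch below is what I have in mind; the start and pad token ids (101 for [CLS], 0 for [PAD]) are my assumptions from the BERT vocab, and I have not verified this:

def shift_right(labels, start_token_id=101, pad_token_id=0):
    # Shift the labels one position to the right so that, at step t,
    # the decoder only conditions on tokens before t (teacher forcing).
    shifted = labels.new_full(labels.shape, pad_token_id)
    shifted[:, 0] = start_token_id
    shifted[:, 1:] = labels[:, :-1]
    return shifted

# Hypothetical usage instead of decoder_input_ids=outputs in the loop above:
# loss, output = model(input_ids=inputs,
#                      decoder_input_ids=shift_right(outputs),
#                      labels=outputs)[:2]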
Here are my loss values and the decoder outputs. It looks like the loss gets stuck:
Epoch: 0, Loss: 10.43316650390625
Epoch: 1, Loss: 9.597984313964844
Epoch: 2, Loss: 8.604254722595215
Epoch: 3, Loss: 7.791357040405273
Epoch: 4, Loss: 7.0663299560546875
Epoch: 5, Loss: 6.433999538421631
Epoch: 6, Loss: 5.967499256134033
Epoch: 7, Loss: 5.682089328765869
Epoch: 8, Loss: 5.520726203918457
Epoch: 9, Loss: 5.435981273651123
Epoch: 10, Loss: 5.400460720062256
Epoch: 11, Loss: 5.394018650054932
Epoch: 12, Loss: 5.399378776550293
Epoch: 13, Loss: 5.400329113006592
Epoch: 14, Loss: 5.3928022384643555
Epoch: 15, Loss: 5.382742404937744
Epoch: 16, Loss: 5.369752883911133
Epoch: 17, Loss: 5.36002254486084
Epoch: 18, Loss: 5.361392021179199
Epoch: 19, Loss: 5.364934921264648
Epoch: 20, Loss: 5.367682456970215
Epoch: 21, Loss: 5.367767810821533
Epoch: 22, Loss: 5.3658857345581055
Epoch: 23, Loss: 5.3620100021362305
Epoch: 24, Loss: 5.358331680297852
Epoch: 25, Loss: 5.357394218444824
Epoch: 26, Loss: 5.359387397766113
Epoch: 27, Loss: 5.359804153442383
Epoch: 28, Loss: 5.359673023223877
Epoch: 29, Loss: 5.359203338623047
Epoch: 30, Loss: 5.357675075531006
Epoch: 31, Loss: 5.356016635894775
Epoch: 32, Loss: 5.356872081756592
Epoch: 33, Loss: 5.357171535491943
Epoch: 34, Loss: 5.356071949005127
Epoch: 35, Loss: 5.355656623840332
Epoch: 36, Loss: 5.355906963348389
Epoch: 37, Loss: 5.356597900390625
Epoch: 38, Loss: 5.354917526245117
Epoch: 39, Loss: 5.355926990509033
Epoch: 40, Loss: 5.353925704956055
Epoch: 41, Loss: 5.35430908203125
Epoch: 42, Loss: 5.355685234069824
Epoch: 43, Loss: 5.355371475219727
Epoch: 44, Loss: 5.3547282218933105
Epoch: 45, Loss: 5.3535051345825195
Epoch: 46, Loss: 5.3542351722717285
Epoch: 47, Loss: 5.355259895324707
Epoch: 48, Loss: 5.3540568351745605
Epoch: 49, Loss: 5.353936672210693
tensor([[ 101, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996,
1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996],
[ 101, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996,
1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996],
[ 101, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996,
1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996],
[ 101, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996,
1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996],
[ 101, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996,
1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996],
[ 101, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996,
1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996],
[ 101, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996,
1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996],
[ 101, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996,
1996, 1996, 1996, 1996, 1996, 1996, 1996, 1996]], device='cuda:0')
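For what it’s worth, decoding the repeated id suggests it is just a very frequent token; I believe 1996 is “the” in the bert-base-uncased vocab, which could be checked roughly like this:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
print(tokenizer.convert_ids_to_tokens([101, 1996]))  # expecting something like ['[CLS]', 'the']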
I am not using attention masks for either the encoder or the decoder, because all the sentences are complete: every one of them fills the entire input array, so there is no padding. I also tokenized them with BertTokenizer, so I think the input transformation is correct.
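For completeness, this is roughly what my tokenization step looks like; the model name bert-base-uncased, the variable name sentences, and the max_length of 20 (matching the tensor shape above) are assumptions on my side for illustration:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# "sentences" stands for the list of 8 raw strings (not shown here)
data = tokenizer(sentences, padding="max_length", max_length=20,
                 truncation=True, return_tensors="pt")
inputs = data["input_ids"]       # shape (8, 20), as in the tensor above
mask = data["attention_mask"]    # all ones here, since nothing is padded
data["target"] = data["input_ids"]  # target is just the input ids again, as in the training code above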
I hope this gives you as much context as possible.
Thank you a lot!