I have been looking through articles on the web, but unfortunately I cannot find a clear answer. I guess one of the following is the correct decoder_input_ids (and should the label be decoder_input_ids[1:]?):
1) <s>...</s><pad>...<pad>
2) </s><s>...</s><pad>...<pad>
3) <pad><pad>...</s><s>...</s>
Thanks in advance.
+) I am going to fine-tune this model for free-form QA.
I guess decoder_input_ids should be </s><s>... (without the trailing </s>), given the label <s>...</s>, according to the code below, which is used for generating decoder_input_ids:
def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    assert pad_token_id is not None, "config.pad_token_id has to be defined."
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids
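
To check my guess, here is a minimal sketch that runs the same shifting logic on two toy label rows. The token ids are assumptions on my part: I am taking <s> = 0, <pad> = 1, </s> = 2 and decoder_start_token_id = 2 (the </s> id), and the ids 10, 11, 12, 20 are made-up content tokens. The function is a compact copy of the one quoted above, not a new API.

```python
import torch

# Assumed special-token ids (not taken from a real config here):
# <s> = 0, <pad> = 1, </s> = 2, with the decoder starting from </s>.
BOS, PAD, EOS = 0, 1, 2
DECODER_START = EOS


def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    """Compact copy of the function quoted above."""
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # -100 marks ignored label positions; turn them back into <pad>
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted


# Row 0: <s> 10 11 12 </s>          (fills the whole length)
# Row 1: <s> 20 </s> -100 -100      (padding positions ignored in the loss)
labels = torch.tensor([
    [BOS, 10, 11, 12, EOS],
    [BOS, 20, EOS, -100, -100],
])

decoder_input_ids = shift_tokens_right(labels, PAD, DECODER_START)
print(decoder_input_ids.tolist())
# [[2, 0, 10, 11, 12], [2, 0, 20, 2, 1]]
```

Under these assumptions, the full-length row becomes </s><s>... with its final </s> dropped, while the shorter row keeps its </s> and is followed by <pad>. So the result looks like option 2, and the label at each position is the decoder input at the next position.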