This seems correct. One more thing to add: you can calculate the loss only on the `question: ...` part. To do this, set the labels to `-100` for the tokens before the `question:` part, so that cross entropy ignores them.
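Here is a minimal pure-Python sketch of that masking idea. The helpers `mask_labels_before` and `masked_cross_entropy` and all token IDs are hypothetical. In practice you would get the marker IDs from your tokenizer. How `question:` is split into tokens can depend on the surrounding text, so that part is an assumption to check. The loss function only illustrates what PyTorch's `cross_entropy` does with its default `ignore_index=-100` and `reduction="mean"`: ignored positions add nothing, and the mean is taken over the remaining targets. It does not shift labels. Hugging Face causal LM heads do that shift internally, so you pass labels the same length as `input_ids`.

```python
import math
from typing import List

IGNORE_INDEX = -100  # label value that PyTorch's cross entropy skips by default


def mask_labels_before(input_ids: List[int], marker_ids: List[int]) -> List[int]:
    """Hypothetical helper: copy input_ids into labels, replacing every token
    before the first occurrence of marker_ids with IGNORE_INDEX."""
    n, m = len(input_ids), len(marker_ids)
    for start in range(n - m + 1):
        if input_ids[start:start + m] == marker_ids:
            # The marker itself and everything after it stay in the loss.
            return [IGNORE_INDEX] * start + input_ids[start:]
    raise ValueError("marker not found in input_ids")


def masked_cross_entropy(logits: List[List[float]], labels: List[int]) -> float:
    """Mean cross entropy over positions whose label is not IGNORE_INDEX
    (no label shifting; one logits row per label)."""
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # ignored positions contribute nothing
        mx = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = mx + math.log(sum(math.exp(x - mx) for x in row))
        total += log_sum_exp - row[label]
        count += 1
    if count == 0:
        raise ValueError("all labels are ignored")
    return total / count


# Hypothetical IDs: 11, 12, 13 are the prefix; 50, 51 tokenize "question:".
labels = mask_labels_before([11, 12, 13, 50, 51, 7, 8], [50, 51])
print(labels)  # [-100, -100, -100, 50, 51, 7, 8]
```

With the real model you would build `labels` this way, convert it to a tensor, and pass it as the `labels` argument. The model's built-in loss then skips the masked prefix.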
Also, you don't need to explicitly set arguments such as `position_ids`, `head_mask`, etc. to `None`. They already default to `None`, so it's fine not to pass them, and leaving them out makes the code cleaner.