This seems correct. One more thing to add: you can calculate the loss only on the question: ... part.
To do this, set the labels to -100 for the tokens before the question: part, so that cross entropy ignores them.
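Here is a minimal sketch of that masking in plain Python, assuming the labels start out as a copy of the input ids and that the marker tokenizes to a known sequence of ids. The helper name `mask_labels_before` and the token ids are made up for illustration. -100 matches the default `ignore_index` of PyTorch's cross-entropy loss.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def mask_labels_before(input_ids, marker_ids, ignore_index=IGNORE_INDEX):
    """Hypothetical helper: copy input_ids into labels and replace every
    position before the first occurrence of marker_ids with ignore_index."""
    n, m = len(input_ids), len(marker_ids)
    for start in range(n - m + 1):
        if input_ids[start:start + m] == marker_ids:
            # Keep the marker and everything after it; mask what comes before.
            return [ignore_index] * start + list(input_ids[start:])
    raise ValueError("marker_ids not found in input_ids")


# Toy ids (made up): three context tokens, a two-token marker (7, 8),
# then the tokens the loss should be computed on.
input_ids = [11, 12, 13, 7, 8, 21, 22]
labels = mask_labels_before(input_ids, marker_ids=[7, 8])
print(labels)  # [-100, -100, -100, 7, 8, 21, 22]
```

The resulting list can be turned into a tensor and passed as `labels`; the masked positions then contribute nothing to the loss.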
Also, you don't need to explicitly set some arguments (position_ids, head_mask, etc.) to None.
They are None by default, so it's fine not to pass them. This will make the code cleaner.
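To illustrate the point about defaults, here is a small sketch using a hypothetical stand-in function rather than a real model. The `forward` function below is made up and only reports which optional arguments ended up as None, to show that the two call styles are equivalent.

```python
def forward(input_ids, attention_mask=None, position_ids=None,
            head_mask=None, labels=None):
    """Hypothetical stand-in for a model's forward method: returns the
    names of the optional arguments that are None."""
    optional = {
        "attention_mask": attention_mask,
        "position_ids": position_ids,
        "head_mask": head_mask,
        "labels": labels,
    }
    return sorted(name for name, value in optional.items() if value is None)


input_ids = [11, 12, 13]
labels = [-100, 12, 13]

# Passing None explicitly ...
verbose = forward(input_ids, position_ids=None, head_mask=None, labels=labels)
# ... is the same as leaving the arguments out.
concise = forward(input_ids, labels=labels)
print(verbose == concise)  # True
```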