OK, so I tracked down the crash. The problem was the position embeddings. I had `max_seq_length == max_position_embeddings`, which results in a position index greater than `max_position_embeddings` for any sequence that is truncated (i.e. fills the full length). This is because `create_position_ids_from_input_ids` in `modeling_roberta.py` (shown below) adds `padding_idx` to the cumsum - if there are no padding (masked) input_ids, the largest position index will be greater than `max_seq_length`:
```python
# excerpt from create_position_ids_from_input_ids in modeling_roberta.py
mask = input_ids.ne(padding_idx).int()
incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
return incremental_indices.long() + padding_idx
```
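
To make the off-by-`padding_idx` issue concrete, here is a minimal standalone sketch (not the library code) that applies the same three lines to a fully-filled sequence, assuming RoBERTa's usual `padding_idx = 1` and a toy `max_position_embeddings = 8`:

```python
import torch

padding_idx = 1
max_position_embeddings = 8  # toy value; also used as max_seq_length here

# A truncated / fully-filled sequence: no padding tokens anywhere.
input_ids = torch.full((1, max_position_embeddings), 42)

mask = input_ids.ne(padding_idx).int()
incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
position_ids = incremental_indices.long() + padding_idx

print(position_ids)        # tensor([[2, 3, 4, 5, 6, 7, 8, 9]])
print(position_ids.max())  # 9, but valid embedding indices are 0..7,
                           # so the position embedding lookup goes out of range
```

So with `max_seq_length == max_position_embeddings`, any sequence with no padding produces an index past the end of the position embedding table, which is the crash I was seeing.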