My guess would be that you have a specific sample in your dataset that is very long. Your collate function (not shown) is then probably padding each batch up to the length of its longest sequence. That means that, for instance, your first ~9k steps see inputs of size 128 x 64 (seq_len x batch_size), which does not lead to an OOM. But around step 9k a batch happens to contain a very long sample, producing (for instance) a 384 x 64 input, which does lead to an OOM.
So check the length distribution of your dataset, and check the collate function. You may want to specify a max_length that is smaller than the model's max length after all.
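As a minimal sketch of what that could look like (MAX_LEN and the synthetic batch are placeholders for your own values and data): truncate each sample before padding, so a single outlier can no longer blow up the batch shape.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

MAX_LEN = 128  # choose something smaller than the model max length

def collate_fn(batch):
    # batch: list of 1-D LongTensors of varying length.
    # Truncate first, so padding is bounded by MAX_LEN rather than
    # by the longest sample that happens to land in this batch.
    truncated = [seq[:MAX_LEN] for seq in batch]
    # pad_sequence (batch_first=False) returns seq_len x batch_size
    return pad_sequence(truncated, padding_value=0)

# Synthetic stand-in for a batch: mostly short samples plus one outlier
batch = [torch.ones(64, dtype=torch.long) for _ in range(7)]
batch.append(torch.ones(384, dtype=torch.long))  # the long sample

padded = collate_fn(batch)
print(padded.shape)  # torch.Size([128, 8]) instead of 384 x 8
```

Printing something like `sorted(len(s) for s in dataset)[-10:]` is a quick way to confirm whether such outliers exist before changing the collate function.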