This trick of loading the model outside of _map_fn is awesome! It should save some memory. In pytorch-xla the model and the dataset are loaded in every process (8 processes in the case of 8 TPU cores), so it ends up taking a lot of memory. Lazily loading the dataset should also reduce RAM usage.
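A minimal sketch of what I mean, assuming pytorch-xla's `xmp.MpModelWrapper` and a seq2seq model (the T5 checkpoint and hyperparameters here are placeholders):

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from transformers import T5ForConditionalGeneration

# Load the model once in the parent process. xmp.MpModelWrapper keeps the
# weights in shared memory, so each of the 8 child processes moves the same
# copy to its own TPU core instead of loading a fresh copy (which would
# roughly multiply host RAM usage by 8).
WRAPPED_MODEL = xmp.MpModelWrapper(
    T5ForConditionalGeneration.from_pretrained("t5-base")  # placeholder checkpoint
)

def _map_fn(index, flags):
    device = xm.xla_device()
    model = WRAPPED_MODEL.to(device)  # materialize on this core's device
    # ... build the (lazily loaded) dataset and run the training loop here ...

if __name__ == "__main__":
    # 'fork' is required so the children inherit the shared model memory.
    xmp.spawn(_map_fn, args=({},), nprocs=8, start_method="fork")
```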
On a v3-8, I was able to use a batch size of 8 per device with max_source_length 512 and max_target_length 64.