I have a custom method on my module (which was passed through the accelerator) that calls `forward` on a modified batch and does some post-processing. If I call this method directly, will each process run it on its own batch, the same way it would if I called `model(**batch)`?
For more context: this happens during an evaluation loop. I want to compute a metric that requires passing non-masked data through the model and collecting the hidden states. The hidden states from the entire eval set are then treated as a dataset for clustering and metric computation.
Thanks for any help.