Trying to train BERT using input_embeds as data fails on multiple GPUs

Trying to finetune/train a simple BERT model where we use the input_embeds parameter of the forward function exposed fails to train on multiple GPU’s. The same works on a single GPU. My understanding is the Trainer object should transparently handle single vs multiple gpu training.

Code to reproduce:

class MyLM(torch.nn.Module):
  def __init__(self, pt_model):
    super(MyLM, self).__init__()
    self.pt_model = pt_model
    self.register_buffer('graph', None)

  def forward(self, inputs, output_hidden_states=None): 
    pt_input_embs_table=pt_model.get_input_embeddings()
    input_embs = pt_input_embs_table(inputs['input_ids'])
    if ('labels' in inputs):
      pt_output=self.pt_model(inputs_embeds=input_embs, labels=inputs['labels'], output_hidden_states=output_hidden_states)
    else:
      pt_output=self.pt_model(input_ids=inputs['input_ids'], output_hidden_states=output_hidden_states)
    return pt_output



def train_model(model, data_collator, dataset, max_epochs=2, max_steps=-1):
  training_args = TrainingArguments(
      output_dir="./working_output",
      overwrite_output_dir=True,
      per_gpu_train_batch_size=BATCH_SIZE,
      num_train_epochs = max_epochs,
      save_steps=10_000,
      save_total_limit=2,
      prediction_loss_only = True,
      warmup_steps=500,                # number of warmup steps for learning rate scheduler
      weight_decay=0.01,# strength of weight decay
  )

  trainer = Trainer(
      model=model,
      args=training_args,
      data_collator=data_collator,
      train_dataset=dataset,
  )
  trainer.train()

pt_model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
my_model = MyLM(pt_model = pt_model)
my_model = torch.nn.DataParallel(my_model)
my_model.train()
train_model(my_model, data_collator, dataset, max_epochs=2)

The above fails with error:

untimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "<ipython-input-29-b9739f559081>", line 11, in forward
    input_embs = pt_input_embs_table(inputs['input_ids'])
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/sparse.py", line 160, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py", line 2043, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking arugment for argument index in method wrapper_index_select)
​