Trying to finetune/train a simple BERT model where we use the input_embeds parameter of the forward function exposed fails to train on multiple GPU’s. The same works on a single GPU. My understanding is the Trainer object should transparently handle single vs multiple gpu training.
Code to reproduce:
class MyLM(torch.nn.Module):
def __init__(self, pt_model):
super(MyLM, self).__init__()
self.pt_model = pt_model
self.register_buffer('graph', None)
def forward(self, inputs, output_hidden_states=None):
pt_input_embs_table=pt_model.get_input_embeddings()
input_embs = pt_input_embs_table(inputs['input_ids'])
if ('labels' in inputs):
pt_output=self.pt_model(inputs_embeds=input_embs, labels=inputs['labels'], output_hidden_states=output_hidden_states)
else:
pt_output=self.pt_model(input_ids=inputs['input_ids'], output_hidden_states=output_hidden_states)
return pt_output
def train_model(model, data_collator, dataset, max_epochs=2, max_steps=-1):
training_args = TrainingArguments(
output_dir="./working_output",
overwrite_output_dir=True,
per_gpu_train_batch_size=BATCH_SIZE,
num_train_epochs = max_epochs,
save_steps=10_000,
save_total_limit=2,
prediction_loss_only = True,
warmup_steps=500, # number of warmup steps for learning rate scheduler
weight_decay=0.01,# strength of weight decay
)
trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=dataset,
)
trainer.train()
pt_model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
my_model = MyLM(pt_model = pt_model)
my_model = torch.nn.DataParallel(my_model)
my_model.train()
train_model(my_model, data_collator, dataset, max_epochs=2)
The above fails with error:
untimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "<ipython-input-29-b9739f559081>", line 11, in forward
input_embs = pt_input_embs_table(inputs['input_ids'])
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/sparse.py", line 160, in forward
self.norm_type, self.scale_grad_by_freq, self.sparse)
File "/opt/conda/lib/python3.7/site-packages/torch/nn/functional.py", line 2043, in embedding
return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking arugment for argument index in method wrapper_index_select)
​