How to fix "TypeError: zip argument #1 must support iteration" when training on multiple GPUs

I am creating a custom PyTorch layer on top of a Hugging Face model and training it with the Trainer API.

When I run it on a single GPU, it trains fine. But when I train on multiple GPUs, it throws the following error:

TypeError: zip argument #1 must support iteration

Training Code

import torch
import torch.nn as nn
from torch.nn.functional import log_softmax as log_soft  # log_soft(emission, 2) below = log-softmax over the label dimension
from torchcrf import CRF  # pytorch-crf
from transformers import BertForTokenClassification, Trainer, TrainingArguments

bert_model = BertForTokenClassification.from_pretrained(
    model_checkpoint, id2label=id2label, label2id=label2id
)
bert_model.config.output_hidden_states = True


class BERT_CUSTOM(nn.Module):

    def __init__(self, bert_model, id2label, num_labels):
        super(BERT_CUSTOM, self).__init__()
        self.bert = bert_model
        self.config = self.bert.config
        self.dropout = nn.Dropout(0.25)
        self.classifier = nn.Linear(768, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None, token_type_ids=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        # mean of the last four hidden states (output_hidden_states=True above)
        sequence_output = torch.stack(
            (outputs[1][-1], outputs[1][-2], outputs[1][-3], outputs[1][-4])
        ).mean(dim=0)
        sequence_output = self.dropout(sequence_output)
        emission = self.classifier(sequence_output)  # logits of shape [32, 256, 21]

        if labels is not None:
            labels = labels.reshape(attention_mask.size()[0], attention_mask.size()[1])
            loss = -self.crf(log_soft(emission, 2), labels,
                             mask=attention_mask.type(torch.uint8), reduction='mean')
            prediction = self.crf.decode(emission, mask=attention_mask.type(torch.uint8))
            return [loss, prediction]
        else:
            prediction = self.crf.decode(emission, mask=attention_mask.type(torch.uint8))
            prediction = [id2label[k] for k in prediction]
            return prediction
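
In case the return types matter: with labels, the forward hands back a 0-dim loss tensor together with the output of crf.decode, which is a plain Python list of lists of label ids (not a tensor). Below is a minimal sketch that shows just those types, using pytorch-crf with made-up sizes (batch of 2, sequence length 8, 21 labels) and random emissions instead of my real data:

import torch
from torchcrf import CRF  # pytorch-crf

# Made-up sizes, only to inspect the types the custom forward returns
num_labels = 21
crf = CRF(num_labels, batch_first=True)

emission = torch.randn(2, 8, num_labels)      # stand-in for the classifier output
mask = torch.ones(2, 8, dtype=torch.uint8)
labels = torch.zeros(2, 8, dtype=torch.long)

loss = -crf(emission.log_softmax(2), labels, mask=mask, reduction='mean')
prediction = crf.decode(emission, mask=mask)

print(type(loss), loss.shape)   # <class 'torch.Tensor'> torch.Size([])  (0-dim tensor)
print(type(prediction))         # <class 'list'>  (list of lists of ints)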

Training API

model = BERT_CUSTOM(bert_model, id2label,num_labels=len(label2id))
model.to(device)

args = TrainingArguments(
    "model",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=2,
    weight_decay=0.01,
    per_device_train_batch_size=32,
    fp16=True,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_data,
    tokenizer=tokenizer)

trainer.train()

Complete traceback:

Traceback (most recent call last):
  File "spanbert_model_check.py", line 263, in <module>
    trainer.train()
  File "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py", line 1531, in train
    ignore_keys_for_eval=ignore_keys_for_eval,
  File "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py", line 1775, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py", line 2523, in training_step
    loss = self.compute_loss(model, inputs)
  File "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py", line 2555, in compute_loss
    outputs = model(**inputs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 162, in forward
    return self.gather(outputs, self.output_device)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 174, in gather
    return gather(outputs, output_device, dim=self.dim)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 68, in gather
    res = gather_map(outputs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/parallel/scatter_gather.py", line 63, in gather_map
    return type(out)(map(gather_map, zip(*outputs)))
  [Previous line repeated 1 more time]
TypeError: zip argument #1 must support iteration
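
From the traceback, the failure happens in DataParallel's gather step, after each GPU replica has already run the forward. As far as I can tell, the same TypeError can be reproduced in isolation by gathering two made-up crf.decode outputs (the values below are invented, only the list-of-lists-of-ints structure matters):

from torch.nn.parallel import gather

# Made-up per-replica predictions, shaped like crf.decode output:
# a plain Python list of lists of label ids per GPU replica
pred_gpu0 = [[1, 2, 3], [4, 5]]
pred_gpu1 = [[6, 7], [8, 9, 10]]

# gather() recurses into non-tensor outputs via
# type(out)(map(gather_map, zip(*outputs))) (scatter_gather.py line 63 above);
# once it reaches the bare integer label ids, zip() raises the same error
gather([pred_gpu0, pred_gpu1], 0)
# TypeError: zip argument #1 must support iteration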