Unable to train BERT when splitting the model across GPUs

Hi,
The BERT model I am training has a very large vocabulary, so the model itself is very large and I run out of GPU memory (OOM) during training. I am now trying to split the model across 2 GPUs so that model.bert.embeddings lives on one GPU while model.bert.encoder and model.cls live on the other. Below is the code I am using and the error I am getting.

import torch

class bert_parall(torch.nn.Module):
    def __init__(self, model):
        super(bert_parall, self).__init__()
        # Embedding layer --> cuda:1
        self.embedding = model.bert.embeddings.to('cuda:1')
        # Encoder layer --> cuda:0
        self.encoder = model.bert.encoder.to('cuda:0')
        # Classifier (MLM head) --> cuda:0
        self.classifier = model.cls.to('cuda:0')

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        # Move input_ids to cuda:1, since the embedding layer is on cuda:1
        emb_out = self.embedding(input_ids.to('cuda:1'))
        # Move the embedding output to cuda:0, where the encoder lives
        enc_out = self.encoder(emb_out.to('cuda:0'))
        classifier_out = self.classifier(enc_out[0])
        return classifier_out
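
For reference, here is a quick sanity check (just a sketch, not part of my actual training script; print_param_devices is a throwaway helper) to see which device each sub-module's parameters end up on after the .to() calls above:

def print_param_devices(module, name):
    # Collect the set of devices holding this sub-module's parameters
    devices = {str(p.device) for p in module.parameters()}
    print(name, "->", sorted(devices))

bert = bert_parall(model)  # model: the BertForMaskedLM instance built below
print_param_devices(bert.embedding, "embeddings")
print_param_devices(bert.encoder, "encoder")
print_param_devices(bert.classifier, "cls head")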

def train_model(model, data_loader, loss_fn, optimizer):
    model = model.train()  # explicitly put the model in training mode
    losses = []
    correct_predictions = 0

    for batch in data_loader:

        input_ids = apply_mask(torch.stack(batch.input_ids).t())
        attention_mask = torch.stack(batch.attention_mask).t()
        targets = batch.labels
        print("inputs ready")
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        _, preds = torch.max(outputs, dim=2)
        print("output")

        # Compute the loss, moving the targets to cuda:1
        loss = loss_fn(outputs, targets.to('cuda:1'))

        correct_predictions += torch.sum(preds == targets.to('cuda:1'))
        losses.append(loss.item())

        loss.backward()
        print("backprop")

        # Clip the gradients to prevent them from exploding
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()
        

Training loop:

from transformers import BertConfig, BertForMaskedLM, AdamW  # AdamW could also come from torch.optim

model_config = BertConfig(vocab_size=len(tokenizer.get_vocab()), hidden_size=108, max_position_embeddings=params.max_len)
model = BertForMaskedLM(config=model_config)

bert = bert_parall(model)
optimizer = AdamW(bert.parameters())
# Loss function; the targets are moved to cuda:1 inside train_model
loss_fn = torch.nn.CrossEntropyLoss().to('cuda:1')

for epoch in range(epochs):
    print("epoch:", epoch)
    train_model(bert, loader, loss_fn, optimizer)
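
For completeness, here is a minimal dummy forward pass that exercises the same code path without the data loader (the batch size, sequence length, and the fact that the inputs start on CPU are illustrative placeholders, not my real values):

# Random token ids of shape (batch, seq_len), kept below the configured vocab size
dummy_input_ids = torch.randint(0, model_config.vocab_size, (2, 16))
dummy_attention_mask = torch.ones_like(dummy_input_ids)

with torch.no_grad():
    out = bert(input_ids=dummy_input_ids, attention_mask=dummy_attention_mask)
print(out.shape)  # expected: (2, 16, vocab_size) once the device issue is resolved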

Error:

Traceback (most recent call last):
  File "#path hidden#", line 149, in <module>
    train_model(bert, loader, loss_fn, optimizer)
  File "#path hidden#", line 94, in train_model
    outputs= model(input_ids = input_ids, attention_mask = attention_mask)
  File "#path hidden#/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "#path hidden#", line 73, in forward
    emb_out = self.embedding(input_ids.to('cuda:1')) 
  File "#path hidden#/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "#path hidden#/lib/python3.7/site-packages/transformers/models/bert/modeling_bert.py", line 235, in forward
    inputs_embeds = self.word_embeddings(input_ids)
  File "#path hidden#/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "#path hidden#/lib/python3.7/site-packages/torch/nn/modules/sparse.py", line 160, in forward
    self.norm_type, self.scale_grad_by_freq, self.sparse)
  File "#path hidden#/lib/python3.7/site-packages/torch/nn/functional.py", line 2183, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument index in method wrapper__index_select)