DP and DDP error with CLIP fine-tune

I want to train CLIP in parallelism, however, I tried DP and DDP, and both encountered some problems.

DP

Here is my code for model, I followed the huggingface’s implementation of CLIP

class CLIP(torch.nn.Module):
    def __init__(self, vision_model, text_model, vision_projection, text_projection, logit_scale):
        super(CLIP, self).__init__()

        self.vision_model = vision_model
        self.text_model = text_model
        self.vision_projection = vision_projection
        self.text_projection = text_projection
        self.logit_scale = logit_scale


    def forward(self, images, texts, return_loss=True):
        vision_output = self.vision_model(**images)[1]
        text_output = self.text_model(**texts)[1]
        image_embeds = self.vision_projection(vision_output)
        text_embeds = self.text_projection(text_output)
        
        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        loss = None
        if return_loss:
            loss = self.clip_loss(logits_per_text)
        
        return {
            'logits_per_text': logits_per_text,
            'logits_per_image': logits_per_image,
            'loss': loss,
            'vision_output': vision_output,
            'text_output': text_output,
            'image_embeds': image_embeds,
            'text_embeds': text_embeds
        }

and here is my main code:


def train_1epoch(dataloader, model, vision_processor, text_processor, optimizer, scheduler, device):
    model.train()
    total_loss = 0
    t = tqdm(dataloader)
    for i, (images, texts) in enumerate(t):
        images = vision_processor(images=images, return_tensors='pt')
        texts = text_processor(text=texts, padding='max_length', truncation=True, return_tensors='pt', max_length=77)
        images = images.to(device)
        texts = texts.to(device)
        optimizer.zero_grad()

        output = model(images, texts, return_loss=True)
        loss = output['loss']

        loss.backward()
        optimizer.step()
        scheduler.step()


def main():

    vision_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
    text_processor = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
    visual_projection = clip_model.visual_projection
    text_projection = clip_model.text_projection
    vision_model = clip_model.vision_model
    text_model = clip_model.text_model
    logit_scale = clip_model.logit_scale
    dataset = ImageTextDataset()
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=8)

    # fine-tune
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # device = 'cpu'
    model = CLIP(vision_model, text_model, visual_projection, text_projection, logit_scale)
    if torch.cuda.device_count() > 1 and device != 'cpu':
        print('Using ', torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6, weight_decay=1e-4)

But I met this error:

Traceback (most recent call last):
  File "run_clip.py", line 86, in <module>
    main()
  File "run_clip.py", line 82, in main
    train_1epoch(dataloader, model, vision_processor, text_processor, optimizer, scheduler, device)
  File "run_clip.py", line 33, in train_1epoch
    output = model(images, texts, return_loss=True)
  File "/opt/conda/envs/ga3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/ga3/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 168, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/opt/conda/envs/ga3/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 178, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/opt/conda/envs/ga3/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/opt/conda/envs/ga3/lib/python3.8/site-packages/torch/_utils.py", line 457, in reraise
    raise exception
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/opt/conda/envs/ga3/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/envs/ga3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/work/workspace/GA3/utils/models.py", line 28, in forward
    vision_output = self.vision_model(**images)[1]
  File "/opt/conda/envs/ga3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/work/.local/transformers/models/clip/modeling_clip.py", line 843, in forward
    hidden_states = self.embeddings(pixel_values)
  File "/opt/conda/envs/ga3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/work/.local/transformers/models/clip/modeling_clip.py", line 182, in forward
    patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
  File "/opt/conda/envs/ga3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/envs/ga3/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 447, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/envs/ga3/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 443, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument weight in method wrapper__cudnn_convolution)

DDP

For DDP, I tried accelerate, but I also met one error:
can't find dynamic library /home/opt/nvidia/lib64/libnvidia-ml original.so

Really appreciate it if anyone could help me with these problems.