OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 MiB. GPU

I am experiencing a CUDA out-of-memory error on my workstation (Ubuntu 20.04, 4x NVIDIA RTX 3090).

Below is my code:

import torch
from transformers import pipeline
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain_openai import OpenAIEmbeddings
from langchain.vectorstores.chroma import Chroma
from langchain.prompts import ChatPromptTemplate
import os

# Set OpenAI API key
os.environ['OPENAI_API_KEY'] = 'openai_sfsadfafsafsfsfdsdfsadf'

CHROMA_PATH = "chroma"
DATA_PATH = "docs"

access_token = "hf_asfafasfsadfasdfasdfasdf"
login(token=access_token, add_to_git_credential=True)

PROMPT_TEMPLATE = """
Answer the question based only on the following context:
{context}
---
Answer the question based on the above context:
{question}
"""

def parallel_model(model, device_ids=None):
    """
    Sets up the model for parallel processing on the specified device IDs.

    Args:
        model (torch.nn.Module): The PyTorch model to be parallelized.
        device_ids (list, optional): List of device IDs to be used for parallel processing. If None, use all available GPUs.

    Returns:
        torch.nn.Module: The model set up for parallel processing.
    """
    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))
    
    if len(device_ids) > 1:
        print(f"DataParallel on devices {device_ids} is used")
        model = torch.nn.DataParallel(model, device_ids=device_ids)
    
    device = torch.device(f"cuda:{device_ids[0]}" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    return model

# Function to query the RAG model and retrieve relevant context
def retrieve_context(query_text):
    # Prepare the DB
    embedding_function = OpenAIEmbeddings()
    db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function)

    # Search the DB
    results = db.similarity_search_with_relevance_scores(query_text, k=5)
    if len(results) == 0 or results[0][1] < 0.7:
        return None
    else:
        context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results])
        sources = [doc.metadata.get("source", None) for doc, _score in results]
        return context_text, sources

# Function to process input chunks on separate GPUs
def process_chunks_on_gpus(model, tokenizer, prompt, device_ids):
    """
    Processes the input prompt in chunks across multiple GPUs.

    Args:
        model (torch.nn.Module): The PyTorch model to be used.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer to be used.
        prompt (str): The input prompt to be processed.
        device_ids (list): List of device IDs to be used for parallel processing.

    Returns:
        str: The combined output from all GPUs.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    chunk_size = len(inputs["input_ids"][0]) // len(device_ids)
    
    outputs = []
    for i, device_id in enumerate(device_ids):
        start_idx = i * chunk_size
        end_idx = start_idx + chunk_size
        input_chunk = {key: value[:, start_idx:end_idx].to(f"cuda:{device_id}") for key, value in inputs.items()}
        
        with torch.no_grad():
            output_chunk = model.generate(**input_chunk)
        
        outputs.append(tokenizer.decode(output_chunk[0], skip_special_tokens=True))

    return " ".join(outputs)

# Function to combine the RAG model with the foundation model
def generate_answer_with_foundation_model(query_text, context_text):
    # Load LLaMA3 model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

    device_ids = list(range(torch.cuda.device_count()))
    model = parallel_model(model, device_ids)
    
    # Prepare the input for the foundation model
    prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)
    prompt = prompt_template.format(context=context_text, question=query_text)
    
    # Process the input chunks on separate GPUs
    response_text = process_chunks_on_gpus(model, tokenizer, prompt, device_ids)

    return response_text

# Define the query text
query_text = "Explain oral gut axis"

# Retrieve context using the RAG model
context_result = retrieve_context(query_text)

if context_result:
    context_text, sources = context_result
    # Generate answer with the foundation model
    answer = generate_answer_with_foundation_model(query_text, context_text)
    formatted_response = f"Response: {answer}\nSources: {sources}"
    print(formatted_response)
else:
    print("Unable to find matching results.")

Error

---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
Cell In[3], line 104
    102 context_text, sources = context_result
    103 # Generate answer with the foundation model
--> 104 answer = generate_answer_with_foundation_model(query_text, context_text)
    105 formatted_response = f"Response: {answer}\nSources: {sources}"
    106 print(formatted_response)

Cell In[3], line 84, in generate_answer_with_foundation_model(query_text, context_text)
     81 model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
     83 device_ids = list(range(torch.cuda.device_count()))
---> 84 model = parallel_model(model, device_ids)
     86 # Prepare the input for the foundation model
     87 prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)

Cell In[3], line 28, in parallel_model(model, device_ids)
     25     model = torch.nn.DataParallel(model, device_ids=device_ids)
     27 device = torch.device(f"cuda:{device_ids[0]}" if torch.cuda.is_available() else "cpu")
---> 28 model = model.to(device)
     30 return model

File ~/anaconda3/envs/tm/lib/python3.11/site-packages/torch/nn/modules/module.py:1173, in Module.to(self, *args, **kwargs)
   1170         else:
   1171             raise
-> 1173 return self._apply(convert)

File ~/anaconda3/envs/tm/lib/python3.11/site-packages/torch/nn/modules/module.py:779, in Module._apply(self, fn, recurse)
    777 if recurse:
    778     for module in self.children():
--> 779         module._apply(fn)
    781 def compute_should_use_set_data(tensor, tensor_applied):
    782     if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    783         # If the new tensor has compatible tensor type as the existing tensor,
    784         # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    789         # global flag to let the user control whether they want the future
    790         # behavior of overwriting the existing tensor or not.

File ~/anaconda3/envs/tm/lib/python3.11/site-packages/torch/nn/modules/module.py:779, in Module._apply(self, fn, recurse)
    777 if recurse:
    778     for module in self.children():
--> 779         module._apply(fn)
    781 def compute_should_use_set_data(tensor, tensor_applied):
    782     if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    783         # If the new tensor has compatible tensor type as the existing tensor,
    784         # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    789         # global flag to let the user control whether they want the future
    790         # behavior of overwriting the existing tensor or not.

    [... skipping similar frames: Module._apply at line 779 (3 times)]

File ~/anaconda3/envs/tm/lib/python3.11/site-packages/torch/nn/modules/module.py:779, in Module._apply(self, fn, recurse)
    777 if recurse:
    778     for module in self.children():
--> 779         module._apply(fn)
    781 def compute_should_use_set_data(tensor, tensor_applied):
    782     if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
    783         # If the new tensor has compatible tensor type as the existing tensor,
    784         # the current behavior is to change the tensor in-place using `.data =`,
   (...)
    789         # global flag to let the user control whether they want the future
    790         # behavior of overwriting the existing tensor or not.

File ~/anaconda3/envs/tm/lib/python3.11/site-packages/torch/nn/modules/module.py:804, in Module._apply(self, fn, recurse)
    800 # Tensors stored in modules are graph leaves, and we don't want to
    801 # track autograd history of `param_applied`, so we have to use
    802 # `with torch.no_grad():`
    803 with torch.no_grad():
--> 804     param_applied = fn(param)
    805 p_should_use_set_data = compute_should_use_set_data(param, param_applied)
    807 # subclasses may have multiple child tensors so we need to use swap_tensors

File ~/anaconda3/envs/tm/lib/python3.11/site-packages/torch/nn/modules/module.py:1159, in Module.to.<locals>.convert(t)
   1152     if convert_to_format is not None and t.dim() in (4, 5):
   1153         return t.to(
   1154             device,
   1155             dtype if t.is_floating_point() or t.is_complex() else None,
   1156             non_blocking,
   1157             memory_format=convert_to_format,
   1158         )
-> 1159     return t.to(
   1160         device,
   1161         dtype if t.is_floating_point() or t.is_complex() else None,
   1162         non_blocking,
   1163     )
   1164 except NotImplementedError as e:
   1165     if str(e) == "Cannot copy out of meta tensor; no data!":

OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 MiB. GPU 

It seems the model parallelism only allocates VRAM on GPU 0, not on the other three GPUs.
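For reference, this is roughly how I am checking the per-GPU usage while the script runs (watching nvidia-smi shows the same picture):

import torch

# Print allocated/reserved memory for every visible GPU
for i in range(torch.cuda.device_count()):
    allocated = torch.cuda.memory_allocated(i) / 1024**3
    reserved = torch.cuda.memory_reserved(i) / 1024**3
    print(f"cuda:{i}: allocated {allocated:.2f} GiB, reserved {reserved:.2f} GiB")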

Any idea how I can get this code to run properly?
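Would it be better to drop DataParallel entirely and load the model sharded across the GPUs with device_map="auto"? Below is a minimal sketch of what I have in mind (assuming accelerate is installed; the float16 dtype, max_new_tokens value, and placeholder prompt are just illustrative choices, not tested yet):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # half precision instead of the default fp32
    device_map="auto",          # shard the layers across all visible GPUs
)

prompt = "Explain oral gut axis"  # placeholder; in my code this is the formatted RAG prompt
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(output[0], skip_special_tokens=True))

Is that the right direction for a 4-GPU setup like mine, or is there a way to make the DataParallel approach above work?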