OOM error with Standard_NC24ads_A100_v4

I am trying to load and fine-tune a model based on the Llama 7B base model; however, no matter how large I make the cluster, I get an OOM error. In my latest iteration, I spun up a Standard_NC24ads_A100_v4 cluster on Azure, which has 220 GB of memory. I am hesitant to simply keep creating larger clusters, since I am fairly certain this cluster ought to be large enough to load and fine-tune the model, especially given that I have brought the batch size down to a single input string and still get an OOM error. I have been fighting this for a while and am frankly at a loss as to why I am still hitting OOM even after beefing up the cluster to 220 GB of memory. The pertinent code that generates the OOM error is as follows:

import torch
import transformers
from transformers import AutoTokenizer
from transformers import pipeline

name = 'FelixChao/vicuna-7B-chemical'
tokenizer = AutoTokenizer.from_pretrained(name, use_fast=True, padding_side="left")
tokenizer.pad_token = 2

config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
config.init_device = 'cuda:0'
config.update({
    "max_new_tokens": 512,
    "eos_token_id": 2,
    "pad_token_id": 2,
    "early_stopping": False,
    "num_beams": 1,
    "torch_dtype": torch.float16,
    "trust_remote_code": True
})

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = transformers.AutoModelForCausalLM.from_pretrained(
    name,
    config=config,
    trust_remote_code=True
).to(device)

encoded_tensors = tokenizer(['This is an example input string',
                             'This is a second example input string'],
                            padding=True, return_tensors="pt")

from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=1e-5)

from transformers import get_scheduler
from torch.utils.data import DataLoader

num_epochs = 60
batch_size = 1
example_loader = DataLoader(encoded_tensors['input_ids'], shuffle=False, batch_size=batch_size)
n_batches = len(example_loader)
num_training_steps = num_epochs * n_batches
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

criterion = torch.nn.CrossEntropyLoss()
from tqdm.auto import tqdm
progress_bar = tqdm(range(num_training_steps))
for epoch in range(num_epochs):
    model.train()
    for batch in example_loader:
        n_tokens = batch.shape[1]
        outputs = model(batch.to(device))
        # next-token prediction: logits at positions 0..n-2 scored against tokens at 1..n-1
        loss = criterion(outputs['logits'][:, :-1, :].reshape((batch_size * (n_tokens - 1), -1)).to(device),
                         batch[:, 1:].ravel().to(device))
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

(screenshot of the OOM error)
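For what it's worth, here is my rough back-of-the-envelope accounting for why I expected 220 GB to be enough for full fine-tuning of a 7B-parameter model with AdamW. The per-parameter byte counts below are my own rough assumptions (fp32 weights, fp32 gradients, and two fp32 AdamW moments), and I am ignoring activations and framework overhead:

# Rough memory estimate for full fine-tuning of a 7B-parameter model with AdamW.
# Assumed per-parameter costs: 4 bytes weights + 4 bytes gradients + 8 bytes AdamW moments (m, v).
# Activations and framework overhead are ignored.
n_params = 7e9
bytes_per_param = 4 + 4 + 8
total_gb = n_params * bytes_per_param / 1e9
print(f"~{total_gb:.0f} GB")  # ~112 GB, well under the cluster's 220 GB

Even allowing generous headroom for activations, that estimate seems like it should fit in 220 GB, which is why the OOM has me stumped.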