I am trying to fine-tune BERT for a multi-label classification task (Jigsaw toxic comments). I created a custom dataset and DataLoader as follows:
import numpy as np
import torch
from torch.utils.data import Dataset

class CustomDataSet(Dataset):
    def __init__(self,
                 features: np.ndarray,
                 labels: np.ndarray,
                 token_max: int,
                 tokenizer):
        self.features = features
        self.labels = labels
        self.tokenizer = tokenizer
        self.token_max = token_max

    def __len__(self):
        return len(self.features)

    def __getitem__(self, index: int):
        # Each feature row holds (comment_id, comment_text).
        comment_id, comment_text = self.features[index]
        labels = self.labels[index]
        encoding = self.tokenizer.encode_plus(
            comment_text,
            add_special_tokens=True,
            max_length=self.token_max,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt')
        return dict(
            comment_text=comment_text,
            comment_id=comment_id,
            input_ids=encoding['input_ids'].squeeze(0),
            attention_mask=encoding['attention_mask'].squeeze(0),
            labels=torch.Tensor(labels))
I use the following tokenizer:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
Then I create my dataset using the custom class, along with the corresponding DataLoader:

from torch.utils.data import DataLoader

train_dataset = CustomDataSet(X_train, y_train, tokenizer=tokenizer, token_max=256)
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    pin_memory=True,
    num_workers=16,
    persistent_workers=False,
)
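As a sanity check, iterating the loader on its own (no model) measures what the worker processes can deliver by themselves. A minimal sketch; the 50-batch cutoff is arbitrary and assumes the dataset has at least 50 full batches:

import time

start = time.perf_counter()
for batch_idx, item in enumerate(train_loader):
    # Only tokenization + collation in the workers is measured here.
    if batch_idx == 50:
        break
elapsed = time.perf_counter() - start
print(f"{50 * 32 / elapsed:.1f} samples/s from the DataLoader alone")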
My model is the following:
from transformers import BertModel

class MultiLabelBERT(torch.nn.Module):
    def __init__(self, num_labels):
        super().__init__()
        # The checkpoint must match the tokenizer above ('bert-base-cased');
        # mixing a cased tokenizer with an uncased model scrambles the token ids.
        self.bert = BertModel.from_pretrained(
            "bert-base-cased",
            torch_dtype=torch.float16,
            attn_implementation="sdpa",
        )
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, num_labels)
        # Cast the head to fp16 so it matches the fp16 BERT outputs.
        self.classifier = self.classifier.to(torch.float16)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        logits = self.classifier(pooled_output)
        return logits
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
first_BERT = MultiLabelBERT(6).to(device)
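A single dummy batch confirms the fp16 forward path end to end (a sketch; the shapes assume token_max=256 and 6 labels, and fp16 inference really wants a CUDA device):

# Hypothetical smoke test: one fake batch through the fp16 model.
dummy_ids = torch.randint(0, tokenizer.vocab_size, (2, 256), device=device)
dummy_mask = torch.ones(2, 256, dtype=torch.long, device=device)
with torch.no_grad():
    out = first_BERT(input_ids=dummy_ids, attention_mask=dummy_mask)
print(out.shape, out.dtype)  # expected: torch.Size([2, 6]) torch.float16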
I first want to iterate through all batches of my training set and perform a forward pass:
for batch_idx, item in enumerate(train_loader):
    input_ids = item['input_ids'].to(device)
    attention_mask = item['attention_mask'].to(device)
    labels = item['labels'].to(device)
    logits = first_BERT(input_ids=input_ids,
                        attention_mask=attention_mask)
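Since this pass is forward-only, the same loop can also be written with torch.inference_mode() and non-blocking copies (a sketch; non_blocking=True only pays off because pin_memory=True is set on the loader):

with torch.inference_mode():
    for batch_idx, item in enumerate(train_loader):
        # Pinned host memory lets these copies overlap with GPU compute.
        input_ids = item['input_ids'].to(device, non_blocking=True)
        attention_mask = item['attention_mask'].to(device, non_blocking=True)
        labels = item['labels'].to(device, non_blocking=True)
        logits = first_BERT(input_ids=input_ids, attention_mask=attention_mask)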
Even though I set num_workers=16 in my DataLoader, only one CPU core is used to load data onto my GPU, which slows the process down significantly. When I comment out the forward pass, the DataLoader uses all CPU workers as expected; with the forward pass in place, the whole pipeline appears bottlenecked. Here's what I've tried:

- Reducing the batch size.
- Reducing the number of tokens (token_max).
- Tokenizing the entire dataset beforehand, to rule the tokenizer out as the source of the bottleneck (a sketch of that pre-tokenization follows this list).
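A minimal sketch of that pre-tokenization, assuming the second column of X_train holds the raw comment strings:

texts = [str(t) for t in X_train[:, 1]]
# One batched tokenizer call instead of one encode_plus call per __getitem__.
encodings = tokenizer(
    texts,
    add_special_tokens=True,
    max_length=256,
    padding='max_length',
    truncation=True,
    return_tensors='pt',
)
pretokenized = torch.utils.data.TensorDataset(
    encodings['input_ids'],
    encodings['attention_mask'],
    torch.Tensor(y_train),
)

I have the following GPU setup: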
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120 Driver Version: 550.120 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 4060 ... Off | 00000000:01:00.0 Off | N/A |
| N/A 40C P0 588W / 115W | 9MiB / 8188MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 1285 G /usr/lib/xorg/Xorg 4MiB |
Does anyone have an idea of what could be causing this?