Here is the code I am using for the classification task. Is there a way to speed up the process? I am running the project in Colab on an A100 GPU.
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
import gc
import pandas as pd
import subprocess
import sys
import argparse
import re
from sklearn.metrics import classification_report, confusion_matrix
import transformers
# Install xlsxwriter if not already installed
try:
    import xlsxwriter
except ModuleNotFoundError:
    print("Installing xlsxwriter...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "xlsxwriter"])
    import xlsxwriter

# Log versions of key dependencies
with open("version_log.txt", "w") as f:
    f.write(f"Transformers version: {transformers.__version__}\n")
    f.write(f"PyTorch version: {torch.__version__}\n")
    f.write(f"CUDA version: {torch.version.cuda}\n")
print("Versions logged to version_log.txt")
# Argument parser for parameterization
parser = argparse.ArgumentParser(description="Inference script for vulnerability detection")
parser.add_argument("--model_name", type=str, default="deepseek-ai/DeepSeek-R1-Distill-Qwen-14B", help="Model name or path")
parser.add_argument("--functions_file", type=str, default="/content/processed_dataset/test_functions_long.json", help="Path to functions JSON")
parser.add_argument("--labels_file", type=str, default="/content/processed_dataset/test_labels_long.json", help="Path to labels JSON")
parser.add_argument("--output_file", type=str, default="/content/processed_dataset/results_long.xlsx", help="Path to output Excel file")
parser.add_argument("--batch_size", type=int, default=1, help="Batch size for inference")
parser.add_argument("--max_length", type=int, default=16384, help="Maximum sequence length")
parser.add_argument("--num_workers", type=int, default=4, help="Number of DataLoader workers")

# Remove Colab-specific arguments (the notebook kernel passes "-f <connection file>")
if "-f" in sys.argv:
    f_index = sys.argv.index("-f")
    del sys.argv[f_index:f_index + 2]
args = parser.parse_args()

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
tokenizer.pad_token = tokenizer.eos_token  # Set pad token to eos token if not defined

print("Loading model. This may take a few minutes...")
model = AutoModelForCausalLM.from_pretrained(
    args.model_name,
    torch_dtype=torch.float16,  # Use fp16 to save memory
    device_map="auto"           # Automatically determine device mapping
)
# device_map="auto" already places the weights, so an explicit model.to(device) is not needed
model.eval()
print(f"Model {args.model_name} loaded successfully.")
# Clean memory function
def clean_memory():
    print("\nRunning garbage collection...\n")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        print(f"GPU memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        print(f"GPU memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

# Clean memory after loading
clean_memory()
# Define dataset class with updated prompt
class InferenceDataset(Dataset):
    def __init__(self, functions_data, tokenizer, max_length=args.max_length):
        self.functions_data = functions_data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.functions_data)

    def __getitem__(self, idx):
        item = self.functions_data[idx]
        code = item['func']
        idx = item['idx']
        prompt = (
            f"\nYou are a cybersecurity expert analyzing code for security vulnerabilities. "
            f"Analyze the following code and determine if it contains security vulnerabilities. "
            f"Begin your response with '<think>\n' and reason step by step, concisely but thoroughly, ending with '</think>'. "
            f"After completing your analysis, end your response with exactly this format: "
            f"[[**Prediction: yes**]] if the code is vulnerable, or "
            f"[[**Prediction: no**]] if the code is not vulnerable. "
            f"Do not add any extra text after the pattern. "
            f"Code:\n{code}"
        )
        encodings = self.tokenizer(
            prompt,
            truncation=True,             # Ensure truncation is enabled
            max_length=self.max_length,  # Set maximum sequence length
            padding="max_length",        # Pad sequences to max_length
            return_tensors="pt"
        )
        return {
            'input_ids': encodings['input_ids'].squeeze(0),
            'attention_mask': encodings['attention_mask'].squeeze(0),
            'idx': idx
        }
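# For reference, a sketch of what one item looks like (assuming the JSON entries
# carry 'func' and 'idx' keys as above): because of padding="max_length", every
# item comes back padded to max_length tokens, e.g.
#   sample = InferenceDataset(functions_data, tokenizer)[0]
#   sample['input_ids'].shape       # torch.Size([16384]) with the default max_length
#   sample['attention_mask'].shape  # torch.Size([16384])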
# Load data function
def load_data(functions_file, labels_file):
    with open(functions_file, 'r', encoding='utf-8') as f:
        functions_data = json.load(f)
    with open(labels_file, 'r', encoding='utf-8') as f:
        labels_data = json.load(f)
    idx_to_target = {item['idx']: item['target'] for item in labels_data}
    return functions_data, labels_data, idx_to_target

# Build dataset and dataloader with optimizations
def build_dataset_dataloader(functions_data, tokenizer):
    dataset = InferenceDataset(functions_data, tokenizer, max_length=args.max_length)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,  # Default of 4 is a safe value for Colab
        prefetch_factor=2
    )
    return dataloader
# Run inference with full text, thought process, and regex parsing
def run_inference(model, dataloader, tokenizer, device, idx_to_target):
    predictions = []
    true_labels = []
    idx_list = []
    sample_outputs = []
    batch_idx = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Processing batches", unit="batch"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            idxs = batch['idx'].tolist()
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=2048,
                do_sample=True,  # Required for temperature/top_p to take effect
                temperature=0.6,
                top_p=0.95,
                pad_token_id=tokenizer.pad_token_id
            )
            for i, output in enumerate(outputs):
                generated_text = tokenizer.decode(output, skip_special_tokens=True)
                # print(f"\nFull Generated Text for idx {idxs[i]}:\n{generated_text}\n{'-' * 50}")
                # print(f"DEBUG: Raw text to parse:\n{repr(generated_text)}\n{'-' * 50}")  # Debug line
                # Create a pattern that matches all desired formats
                pattern = r"(?:\[\[\*\*Prediction:\s(yes|no)\*\*\]\]|\*\*\*\*Prediction:\s(yes|no)\*\*\*\*|\*\*Prediction:\s*(yes|no)\*\*)"
                answer_match = re.search(pattern, generated_text, re.IGNORECASE)
                if answer_match:
                    # Get the first non-None group (whichever pattern matched)
                    answer = next(group for group in answer_match.groups() if group is not None).lower()
                    pred = 1 if answer == 'yes' else 0
                else:
                    pred = None
                    print(f"Warning: No prediction pattern found for idx {idxs[i]}")
                    continue  # Skip this sample
                predictions.append(pred)
                true_labels.append(idx_to_target[idxs[i]])
                idx_list.append(idxs[i])
                if len(sample_outputs) < 5:
                    sample_outputs.append({
                        'idx': idxs[i],
                        'generated_text': generated_text,
                        'prediction': pred,
                        'true_label': idx_to_target[idxs[i]]
                    })
            clean_memory()
            batch_idx += 1
            if batch_idx % 10 == 0:
                gpu_util = subprocess.check_output("nvidia-smi --query-gpu=utilization.gpu --format=csv,noheader", shell=True).decode().strip()
                print(f"GPU Utilization at batch {batch_idx}: {gpu_util}")
    return predictions, true_labels, idx_list, sample_outputs
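# Hypothetical sanity check of the prediction pattern above (illustrative strings,
# not real model output), kept as a comment so it does not run during inference:
#   for s in ("[[**Prediction: yes**]]", "**Prediction: no**"):
#       m = re.search(pattern, s, re.IGNORECASE)
#       print(next(g for g in m.groups() if g is not None))  # -> "yes", then "no"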
def compute_metrics(predictions, true_labels):
    # Filter out None predictions
    valid_indices = [i for i, p in enumerate(predictions) if p is not None]
    valid_predictions = [predictions[i] for i in valid_indices]
    valid_true_labels = [true_labels[i] for i in valid_indices]
    if not valid_predictions:
        print("No valid predictions to compute metrics.")
        return 0, 0, 0, 0, 0.0, 0.0, 0.0, 0.0
    cm = confusion_matrix(valid_true_labels, valid_predictions)
    tn, fp, fn, tp = cm.ravel()
    accuracy = (tp + tn) / len(valid_predictions) if len(valid_predictions) > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    print("\nClassification Report:")
    print(classification_report(valid_true_labels, valid_predictions, target_names=['Non-Vulnerable', 'Vulnerable']))
    return tp, fp, fn, tn, accuracy, precision, recall, f1
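# Worked example of the metric formulas with hypothetical counts (not real results):
# tp=8, fp=2, fn=4, tn=6 -> accuracy=(8+6)/20=0.70, precision=8/(8+2)=0.80,
# recall=8/(8+4)=0.67, f1=2*0.80*0.67/(0.80+0.67)=0.73.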
def write_outputs(idx_list, predictions, true_labels, tp, fp, fn, tn, accuracy, precision, recall, f1, output_file):
    # Filter out None predictions for the DataFrame
    valid_data = [(idx, pred, true) for idx, pred, true in zip(idx_list, predictions, true_labels) if pred is not None]
    if not valid_data:
        print("No valid data to write.")
        return
    idx_list_valid, predictions_valid, true_labels_valid = zip(*valid_data)
    results_df = pd.DataFrame({
        'idx': idx_list_valid,
        'Prediction': predictions_valid,
        'True_Label': true_labels_valid
    })
    metrics_df = pd.DataFrame({
        'Metric': ['True Positives (TP)', 'True Negatives (TN)', 'False Positives (FP)', 'False Negatives (FN)', 'Accuracy', 'Precision', 'Recall', 'F1-Score'],
        'Value': [tp, tn, fp, fn, accuracy, precision, recall, f1]
    })
    with pd.ExcelWriter(output_file, engine='xlsxwriter') as writer:
        results_df.to_excel(writer, sheet_name='Predictions', index=False)
        metrics_df.to_excel(writer, sheet_name='Metrics', index=False)
    print(f"\nResults saved to {output_file}")
# Main execution flow
if __name__ == "__main__":
    # Load data
    functions_data, labels_data, idx_to_target = load_data(args.functions_file, args.labels_file)
    # Build dataset and dataloader
    dataloader = build_dataset_dataloader(functions_data, tokenizer)
    # Run inference
    predictions, true_labels, idx_list, sample_outputs = run_inference(model, dataloader, tokenizer, device, idx_to_target)
    # Compute metrics
    tp, fp, fn, tn, accuracy, precision, recall, f1 = compute_metrics(predictions, true_labels)
    # Write outputs
    write_outputs(idx_list, predictions, true_labels, tp, fp, fn, tn, accuracy, precision, recall, f1, args.output_file)
    # Clean memory one last time
    clean_memory()