Hey there, I need some help with the performance of my classifier.

I trained a text classification model on top of the pre-trained bert-base-uncased BERT model, and I created a class that handles the transformation from raw text to model-ready data, using the bert-base-uncased tokenizer to process the raw text. When I run my code in a Jupyter notebook or on my local machine, I get predictions on new, unseen data in a matter of seconds. When I deploy the model to a Docker container, however, it is extremely slow, to the point where the container times out before even producing the predictions. From debugging my code, I can see that it gets stuck in the tokenizer part. I have already tried batch encoding (snippet further down), single encoding, and the fast tokenizer via AutoTokenizer, and no matter what I do or change, the performance does not improve. I am not using GPUs, and I am running inference on a relatively small dataset (approx. 100 data points).

This is how I defined the class that transforms my raw dataset into a model-ready dataset:
# Imports used across the snippets below
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, BertModel, BertTokenizer

class textDataset(Dataset):
    # Constructor
    def __init__(self, features, tokenizer, max_len):
        self.features = features
        self.tokenizer = tokenizer
        self.max_len = max_len

    # Length method
    def __len__(self):
        return len(self.features)

    # Get item method
    def __getitem__(self, item):
        feature = str(self.features[item])
        # Encoded format to be returned
        encoding = self.tokenizer.encode_plus(
            feature,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',  # pad_to_max_length=True is deprecated
            return_attention_mask=True,
            truncation=True,
            return_tensors='pt',
        )
        return {
            'feature': feature,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
        }
Function for creating the PyTorch dataset and wrapping it in a DataLoader:
def create_data_loader(df, tokenizer, max_len, batch_size):
    ds = textDataset(
        features=df.feature.to_numpy(),
        tokenizer=tokenizer,
        max_len=max_len
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
        # num_workers=4
    )
Tokenizers used (I tried a few to see if performance would improve):
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
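The batch-encoding attempt I mentioned looked roughly like this (simplified; one tokenizer call over all the texts instead of per-item encoding in __getitem__):

# Batch encoding (simplified): tokenize all ~100 texts in one call
texts = features_transformed_df.feature.astype(str).tolist()
encodings = tokenizer(
    texts,
    add_special_tokens=True,
    max_length=max_len,
    padding='max_length',
    truncation=True,
    return_tensors='pt',
)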
Class for model instance:
class textClassifier(nn.Module):
    # Constructor
    def __init__(self, n_classes):
        super(textClassifier, self).__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased", return_dict=False)
        self.drop = nn.Dropout(p=0.3)
        self.out = nn.Linear(self.bert.config.hidden_size, n_classes)

    # Forward propagation
    def forward(self, input_ids, attention_mask):
        _, pooled_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # Add a dropout layer
        output = self.drop(pooled_output)
        return self.out(output)
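In case it helps, here is a minimal way I can sanity-check the model in isolation (the all-ones inputs and n_classes=2 are arbitrary example values); locally this runs without issues:

# Dummy forward pass to check the model on its own
dummy_ids = torch.ones((1, 128), dtype=torch.long)
dummy_mask = torch.ones((1, 128), dtype=torch.long)
sanity_model = textClassifier(n_classes=2)
with torch.no_grad():
    logits = sanity_model(input_ids=dummy_ids, attention_mask=dummy_mask)
print(logits.shape)  # torch.Size([1, 2])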
Function for predicting new values (which is where it gets stuck):
def get_predictions(model, data_loader):
    model = model.eval()
    feature_texts = []
    predictions = []
    prediction_probs = []
    with torch.no_grad():
        for d in data_loader:
            texts = d["feature"]
            input_ids = d["input_ids"]
            attention_mask = d["attention_mask"]
            # Get outputs (raw logits, not softmax probabilities)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            _, preds = torch.max(outputs, dim=1)
            feature_texts.extend(texts)
            predictions.extend(preds)
            prediction_probs.extend(outputs)
    predictions = torch.stack(predictions).cpu()
    prediction_probs = torch.stack(prediction_probs).cpu()
    return feature_texts, predictions, prediction_probs
And finally, the main part where I call all the functions:

# Create data loader from processed df to use for model inference
# Instantiate the PyTorch model
model = textClassifier(len(class_names))

batch_size = 16
prod_data_loader = create_data_loader(
    features_transformed_df,
    tokenizer,
    max_len,
    batch_size
)

feature_texts, pred_ids, pred_probs = get_predictions(model, prod_data_loader)
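When I wrap rough timing around the two stages, the tokenization/loading step is what dominates inside the container. A simplified version of the timing code I used (not my exact debug script):

import time

t0 = time.perf_counter()
batch = next(iter(prod_data_loader))  # pulling a batch triggers tokenization in __getitem__
t1 = time.perf_counter()
with torch.no_grad():
    _ = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
t2 = time.perf_counter()
print(f"tokenization/loading: {t1 - t0:.2f}s, forward pass: {t2 - t1:.2f}s")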
I am relatively new to Transformers/PyTorch, and I am having a hard time figuring out why my code is so slow when deployed to a Docker container or an AWS Lambda. Does anyone know the reason for this performance issue? Am I doing anything in the code above that compromises performance? Thanks a lot in advance.