@sgugger Thanks for the amazing course. I am trying to use DataCollatorWithPadding in the following code and wanted to check whether I am on the right path.
!pip -q install transformers datasets accelerate sentence-transformers iterative-stratification umap-learn wandb hdbscan altair altair-data-server
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook
import os, gc, shutil, re, warnings
import pickle
warnings.filterwarnings("ignore")
# set the max columns to none
pd.set_option('display.max_columns', None)
import random
SEED=75
random.seed(SEED)
import joblib
from sklearn.manifold import TSNE
from umap import UMAP
from torch.nn.functional import normalize
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
import torch
import torch.nn as nn
import transformers
from transformers import (
AutoModel, AutoConfig,
AutoTokenizer, logging,
AdamW, get_linear_schedule_with_warmup,
DataCollatorWithPadding,
Trainer, TrainingArguments
)
from transformers.modeling_outputs import SequenceClassifierOutput
# NOTE(review): `logging` here is the object imported from `transformers`.
# These two calls run back to back, so presumably the second one overrides
# the first and the effective verbosity is "warning" -- confirm which level
# is actually wanted and keep only that call.
logging.set_verbosity_error()
logging.set_verbosity_warning()
import wandb
#### plots
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors
from IPython.core.display import display, HTML
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plotly.offline import init_notebook_mode
### Plotly settings
# Shared figure template: Ubuntu 14pt font, 600px tall, light-grey
# plot and paper backgrounds.
temp = {
    "layout": go.Layout(
        font={"family": "Ubuntu", "size": 14},
        height=600,
        plot_bgcolor="#ededed",
        paper_bgcolor="#ededed",
    )
}
# Load data from huggingface
from datasets import load_dataset
# Fetch the dataset by its Hub identifier.
dataset = load_dataset("cdsi-nlp-workshops/arxiv_classification")
# Bare expression: presumably a notebook cell output showing the dataset
# summary (it has no effect when run as a plain script).
dataset
# Keep the "train" split and pull its 'text' column out for cleaning.
raw_train_dataset = dataset["train"]
text_list = raw_train_dataset['text']
# Bare expression: presumably displays the split's column schema in a notebook.
raw_train_dataset.features
import re
import string
def preprocess_text(text: str) -> str:
    """Normalize a raw text string before tokenization.

    Steps, in order: lowercase; collapse every whitespace run (newlines
    included) to a single space and trim the ends; drop a leading
    "abstract" heading word; delete ASCII punctuation; collapse whitespace
    once more.

    Args:
        text: Raw input string.

    Returns:
        The cleaned string, with single spaces between tokens and no
        leading or trailing whitespace.
    """
    # Convert to lowercase
    text = text.lower()
    # Collapse whitespace runs (this also covers newlines) and trim the ends
    text = re.sub(r"\s+", " ", text).strip()
    # Remove "abstract" part
    text = re.sub(r"^abstract\s+", "", text)
    # Remove punctuation
    text = text.translate(str.maketrans("", "", string.punctuation))
    # Deleting a token that was pure punctuation (e.g. " - " or a trailing
    # " .") leaves doubled or dangling spaces, so normalize whitespace again.
    return " ".join(text.split())
# Run every raw "text" entry through the cleaner, preserving order
sentences = [preprocess_text(raw_text) for raw_text in text_list]
# Show the header line followed by the full cleaned list
print("Cleaned text:", sentences, sep="\n")
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
# Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    """Average token embeddings per sequence, weighted by the attention mask.

    Args:
        model_output: Object exposing ``last_hidden_state``, the per-token
            embeddings.
        attention_mask: Mask tensor; it is unsqueezed on the last axis and
            expanded to the shape of ``last_hidden_state``.

    Returns:
        Sum of the masked token embeddings over dimension 1 divided by the
        mask sum over dimension 1 (clamped to at least 1e-9 so a fully
        masked row does not divide by zero).
    """
    hidden_states = model_output.last_hidden_state
    # Broadcast the mask across the embedding axis as floats
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    masked_total = (hidden_states * mask).sum(dim=1)
    mask_total = mask.sum(dim=1).clamp(min=1e-9)
    return masked_total / mask_total
# Load model from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')

# Batch size and sequence-length cap used to keep memory bounded
max_batch_size = 8
max_sequence_length = 128
num_sentences = len(sentences)

# Tokenize WITHOUT padding. Let the tokenizer enforce the length cap via
# `max_length` so that truncated sequences still end with their special
# token (slicing the padded tensor afterwards would cut it off).
encoded_input = tokenizer(sentences, truncation=True, max_length=max_sequence_length)

# DataCollatorWithPadding expects a list of per-example feature dicts
features = [
    {key: values[index] for key, values in encoded_input.items()}
    for index in range(num_sentences)
]

# The collator pads each batch only up to the longest sequence in THAT
# batch (dynamic padding) and returns PyTorch tensors.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors='pt')

# Compute token embeddings
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

batch_embeddings = []
with torch.no_grad():
    # Walk over ALL sentences in chunks instead of keeping only the first
    # `max_batch_size` rows and silently dropping the rest.
    for start in range(0, num_sentences, max_batch_size):
        batch = data_collator(features[start:start + max_batch_size])
        batch = {key: value.to(device) for key, value in batch.items()}
        model_output = model(**batch)
        # Perform pooling
        pooled = mean_pooling(model_output, batch['attention_mask'])
        # Normalize embeddings
        batch_embeddings.append(F.normalize(pooled, p=2, dim=1))

# One row per sentence, in the same order as `sentences`
sentence_embeddings = torch.cat(batch_embeddings, dim=0)
print("Sentence embeddings:")
print(sentence_embeddings)
Please advise on how I can incorporate it into my code.
Looking forward to hearing from you.
Thanks,
Andy