Hi, I am trying to read custom video datasets with their labels in a csv file and transfer train them using Huggingface Trainer with google/vit-base-patch16-224-in21k model. Here is what I have implemented so far but encountered a Key 42 error code.
import torch
import torchvision.transforms as transforms
import torchvision.io.video as video
from torch.utils.data import Dataset, DataLoader
from transformers import ViTFeatureExtractor, ViTForImageClassification, TrainingArguments, Trainer
from sklearn.model_selection import train_test_split
import pandas as pd
import glob
import av
from tqdm import tqdm
# Define classes
classes = ['class_0', 'class_1']
# Define a custom dataset for video frames
class VideoDataset(Dataset):
def __init__(self, video_paths, labels, transform=None):
self.video_paths = video_paths
self.labels = labels
self.transform = transform
def __len__(self):
return len(self.video_paths)
def __getitem__(self, idx):
frames, _, _ = video.read_video(self.video_paths[idx], pts_unit='sec')
sample = {'frames': frames, 'label': self.labels[idx]}
if self.transform:
sample['frames'] = [self.transform(frame) for frame in sample['frames']]
return sample
# Define transformations for preprocessing frames
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
# Function to preprocess each frame
def preprocess_frame(frame):
frame = transforms.ToPILImage()(frame)
frame = transform(frame)
return frame
# Function to preprocess the dataset
def preprocess_dataset(dataset):
return dataset
# Define training arguments
training_args = TrainingArguments(
output_dir="./trained",
per_device_train_batch_size=4,
num_train_epochs=3,
logging_dir='./logs',
logging_steps=100,
)
# Load the ViT model and feature extractor
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k")
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
#Load training Videos
train_videos = glob.glob("./dataset100/train/all/*.mp4")
#Load Labels
train_df = pd.read_csv('./train100.csv',names=header_list)
labels = train_df['tag']
#Split dataset into train and validation sets
train_paths, val_paths,train_labels, val_labels= train_test_split(train_videos, labels, test_size=0.2, random_state=42)
# Create datasets and preprocess them
train_dataset = VideoDataset(train_paths, train_labels, transform=transform)
val_dataset = VideoDataset(val_paths, val_labels, transform=transform)
train_dataset = preprocess_dataset(train_dataset)
val_dataset = preprocess_dataset(val_dataset)
print("Number of training videos:", len(train_paths))
print("Number of validation videos:", len(val_paths))
print("Number of training labels:", len(train_labels))
print("Number of validation labels:", len(val_labels))
print("Training videos:", train_videos[:5])
print("Training labels:", train_labels[:5])
# Define function to compute accuracy
def compute_accuracy(pred):
labels = pred.label_ids
preds = pred.predictions.argmax(-1)
return {"accuracy": (preds == labels).mean()}
# Define Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
compute_metrics=compute_accuracy,`your text`
)
# Train the model
trainer.train()
Here is the error message I received and I would really appreciate some help.
KeyError Traceback (most recent call last)
File c:\User\Code\venv\Lib\site-packages\pandas\core\indexes\base.py:3791, in Index.get_loc(self, key)
3790 try:
-> 3791 return self._engine.get_loc(casted_key)
3792 except KeyError as err:
File index.pyx:152, in pandas._libs.index.IndexEngine.get_loc()
File index.pyx:181, in pandas._libs.index.IndexEngine.get_loc()
File pandas\_libs\hashtable_class_helper.pxi:2606, in pandas._libs.hashtable.Int64HashTable.get_item()
File pandas\_libs\hashtable_class_helper.pxi:2630, in pandas._libs.hashtable.Int64HashTable.get_item()
KeyError: 42
The above exception was the direct cause of the following exception:
KeyError Traceback (most recent call last)
Cell In[31], line 121
112 trainer = Trainer(
113 model=model,
114 args=training_args,
(...)
...
3801 # InvalidIndexError. Otherwise we fall through and re-raise
3802 # the TypeError.
3803 self._check_indexing_error(key)
KeyError: 42