I have a very large (30 GB) dataset with 700k+ images and 2,780 classes. When running
dataset = load_dataset("imagefolder", data_dir=r"Desktop/Finished")
I have to wait 9 hours for the dataset to fully load, which I'd be okay with if it loaded the dataset correctly. However, after multiple attempts and clearing the cache, I get the same result: "438 classes found". I've checked my data multiple times, and my images are all in the correct format.
Another issue: when I run
dataset = load_dataset("imagefolder", data_files=r"Desktop/Finished.zip")
it finds all the images and classes, but when I start training the model, it estimates 60,000 hours of training. Despite my batch size of 32, I am only on step "2 / 160,875" after waiting 15 minutes for the progress bar to appear.
When I trained my model months ago using ViT, it worked perfectly fine. Back then, however, my dataset was loaded into .cache/huggingface/datasets/imagefolder by copying the entire dataset, instead of being converted into Arrow and Parquet format. Is there a way to downgrade my environment to get that behavior back, since it seemed to work much faster for me?
So far, the only explanations I can come up with are that either the MobileViT processing is making it slow, or that I'm using a server with hard drives rather than an SSD.
Here is my code:
# Optional sanity check: list the GPUs TensorFlow can see. Training below is done
# with PyTorch; `tf` is not used anywhere else in this notebook.
import tensorflow as tf
tf.config.list_physical_devices('GPU')
model_checkpoint = "apple/mobilevitv2-1.0-imagenet1k-256" # pre-trained model from which to fine-tune
batch_size = 32 # batch size for training and evaluation
# Log in to the Hugging Face Hub from the notebook (presumably prompts for a token).
from huggingface_hub import notebook_login
notebook_login()
%%capture
!git config --global credential.helper store
# NOTE(review): `%%capture` is an IPython cell magic and is only valid as the first
# line of a notebook cell; this code appears to be several cells pasted together.
# NOTE(review): git's "store" credential helper saves credentials unencrypted on disk.
# Sends example-notebook usage telemetry; presumably has no effect on training.
from transformers.utils import send_example_telemetry
send_example_telemetry("image_classification_notebook", framework="pytorch")
from datasets import load_dataset

# Alternative that was also tried: load the same images from a zip archive.
# dataset = load_dataset("imagefolder", data_files=r"Desktop/Finished.zip")

# NOTE(review): this path ("Desktop/Finished(2)") differs from "Desktop/Finished"
# used in the question text - confirm the intended folder is being loaded.
# NOTE(review): "imagefolder" infers class labels (and possibly split names) from
# directory names. If far fewer classes are reported than there are folders, check
# `print(dataset)` for unexpected split names: a class folder whose name contains a
# split keyword such as "test", "valid" or "dev" may be treated as a split and the
# remaining folders ignored. Confirm before re-running the multi-hour load.
dataset = load_dataset("imagefolder", data_dir=r"Desktop/Finished(2)")

from datasets import load_metric

# NOTE(review): load_metric is deprecated in datasets 2.x (removed in 3.0); the
# `evaluate` library (evaluate.load("accuracy")) is its replacement.
metric = load_metric("accuracy")
dataset
import PIL

# Peek at one training record and the split's schema (notebook display lines).
example = dataset["train"][10]
example
dataset["train"].features
example['image']
example['image'].resize((200, 200))
example['label']

# Two-way mappings between class-name strings and integer ids, taken from the
# label feature's names so that id i is exactly the i-th entry of `labels`.
labels = dataset["train"].features["label"].names
label2id = {name: idx for idx, name in enumerate(labels)}
id2label = dict(enumerate(labels))
id2label[0]
# Load the preprocessing configuration published with the checkpoint; its `size`
# settings drive the torchvision transforms built below.
from transformers import AutoImageProcessor
image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)
image_processor
# NOTE(review): a bare `pip install` relies on IPython automagic; `%pip install`
# targets the running kernel's environment more reliably. torchvision 0.18.0 already
# appears in the pip list, so this line is presumably a no-op.
pip install torchvision
import torchvision
torchvision.__version__
from torchvision.transforms import (
CenterCrop,
Compose,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
Resize,
ToTensor,
)
# Derive resize/crop geometry from the checkpoint's image processor so that the
# torchvision pipelines below produce the resolution the model expects.
if "height" in image_processor.size:
    # Fixed-box processors: size is {"height": H, "width": W}.
    size = (image_processor.size["height"], image_processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in image_processor.size:
    # Shortest-edge processors: resize the short side to `size`, crop a square.
    size = image_processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = image_processor.size.get("longest_edge")
else:
    # Fail here with a clear message instead of a NameError on `size` further down.
    raise ValueError(f"Unsupported image_processor.size format: {image_processor.size!r}")

# Some processors resize to `size` and then center-crop to a separate, smaller
# `crop_size` (presumably the case for this MobileViTV2 checkpoint - confirm by
# displaying `image_processor`). When present, train and evaluate at that crop
# resolution: it matches the processor's own inference pipeline and costs fewer
# FLOPs per image than training at the larger resize resolution.
processor_crop = getattr(image_processor, "crop_size", None)
if isinstance(processor_crop, dict) and "height" in processor_crop and "width" in processor_crop:
    crop_size = (processor_crop["height"], processor_crop["width"])

# NOTE(review): pixel normalization is intentionally not applied. ToTensor() already
# rescales to [0, 1]; confirm the processor defines no image_mean/image_std.
# NOTE(review): MobileViT processors may also flip RGB->BGR (do_flip_channel_order);
# these pipelines keep RGB order. Fine-tuning can adapt, but confirm this is intended.
train_transforms = Compose(
    [
        RandomResizedCrop(crop_size),
        RandomHorizontalFlip(),
        ToTensor(),
    ]
)

val_transforms = Compose(
    [
        Resize(size),
        CenterCrop(crop_size),
        ToTensor(),
    ]
)
def preprocess_train(example_batch):
    """Attach augmented training tensors to a batch of examples.

    Every PIL image in ``example_batch["image"]`` is converted to RGB and passed
    through ``train_transforms``; the resulting tensors are stored, in the same
    order, under the ``"pixel_values"`` key. The batch is updated in place and
    also returned, as ``set_transform`` expects.
    """
    rgb_images = (img.convert("RGB") for img in example_batch["image"])
    example_batch["pixel_values"] = list(map(train_transforms, rgb_images))
    return example_batch
def preprocess_val(example_batch):
    """Attach deterministic evaluation tensors to a batch of examples.

    Each image in ``example_batch["image"]`` is converted to RGB and run through
    ``val_transforms``; the tensors are written, in order, to
    ``example_batch["pixel_values"]``. The batch is updated in place and returned.
    """
    tensors = []
    for img in example_batch["image"]:
        tensors.append(val_transforms(img.convert("RGB")))
    example_batch["pixel_values"] = tensors
    return example_batch
# Split the single "train" split into training (90%) and validation (10%).
# A fixed seed makes the split identical across kernel restarts, so resuming from a
# checkpoint or re-evaluating later never mixes validation images into training.
# (With 2,780 classes, stratify_by_column="label" is worth considering so that every
# class appears in validation.)
splits = dataset["train"].train_test_split(test_size=0.1, seed=42)
train_ds = splits["train"]
val_ds = splits["test"]

# set_transform applies the preprocessing on the fly whenever examples are accessed,
# so images are decoded and augmented on every access (i.e. every epoch).
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)
train_ds[0]
import torch
# Use the first CUDA GPU when one is visible, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
from transformers import AutoImageProcessor, MobileViTV2ForImageClassification
# Load the pretrained MobileViTV2 weights with a classification head sized for this
# dataset's labels. ignore_mismatched_sizes=True lets the checkpoint's original head
# (presumably ImageNet-1k, per the checkpoint name) be replaced instead of raising a
# size-mismatch error.
model = MobileViTV2ForImageClassification.from_pretrained(
model_checkpoint,
label2id=label2id,
id2label=id2label,
ignore_mismatched_sizes = True,
).to(device)
# NOTE(review): Trainer places the model on its own device, so .to(device) is
# presumably redundant (harmless).
# NOTE(review): installing/upgrading packages after transformers has already been
# imported generally needs a kernel restart before the new versions take effect.
! pip install -U accelerate
import accelerate
accelerate.__version__
pip install transformers[torch]
import transformers
transformers.__version__
import os

model_name = model_checkpoint.split("/")[-1]
args = TrainingArguments(
    f"{model_name}",  # output_dir: checkpoints and metrics are written here
    # Keep the raw "image"/"label" columns: the on-the-fly transforms and collate_fn
    # read them, and Trainer would otherwise drop columns the model's forward()
    # does not accept.
    remove_unused_columns=False,
    # transformers 4.40 spells this `evaluation_strategy`; later releases rename it
    # to `eval_strategy`.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1e-4,
    # Effective batch per optimizer step = batch_size * gradient_accumulation_steps.
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=2,
    per_device_eval_batch_size=batch_size,
    # Gradient checkpointing trades compute for memory (activations are recomputed in
    # the backward pass), so it slows every step. A model of this size at batch 32
    # presumably fits without it; set True (the kwargs below then apply) only on OOM.
    gradient_checkpointing=False,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    fp16=True,
    num_train_epochs=30,
    warmup_ratio=0.1,
    # Must be non-negative: a negative value makes AdamW's decoupled weight decay
    # grow the weights every step instead of shrinking them.
    weight_decay=0.01,
    logging_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",  # key returned by compute_metrics
    push_to_hub=False,
    # Decode and augment images in background worker processes. With the default (0),
    # every image is read from disk and transformed in the training process between
    # GPU steps, which is typically the bottleneck for ~700k images on hard drives.
    # NOTE(review): on Windows notebooks, workers cannot pickle notebook-defined
    # transforms; use 0 there.
    dataloader_num_workers=min(8, os.cpu_count() or 1),
)
import numpy as np
def compute_metrics(eval_pred):
    """Return top-1 accuracy for one Trainer evaluation pass.

    Args:
        eval_pred: the object Trainer passes to ``compute_metrics``. Only its
            ``predictions`` (logits, one row per example, or a tuple whose first
            element is the logits) and ``label_ids`` attributes are read.

    Returns:
        ``{"accuracy": float}``: the fraction of rows whose arg-max logit equals
        the reference label.
    """
    logits = eval_pred.predictions
    # Models that return several outputs hand Trainer a tuple; logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=1)
    references = np.asarray(eval_pred.label_ids)
    # Computed locally rather than via the deprecated datasets.load_metric("accuracy"):
    # the mean of exact matches is the same value that metric reports.
    return {"accuracy": float(np.mean(predictions == references))}
import torch
def collate_fn(examples):
    """Combine per-example dicts into one model-ready batch.

    Stacks each example's ``"pixel_values"`` tensor along a new leading batch
    dimension and gathers the integer ``"label"`` values into a tensor stored
    under ``"labels"`` (the key the classification model's forward pass takes).
    """
    images, targets = [], []
    for item in examples:
        images.append(item["pixel_values"])
        targets.append(item["label"])
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}
# Wire the model, data, collator and metric into a Trainer. The image processor goes
# in through `tokenizer=` so that it is saved together with the model.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

# Fine-tune, then persist the final weights, the training summary and trainer state.
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

# Final evaluation on the held-out split.
metrics = trainer.evaluate()
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
And here is my pip list:
*absl-py 2.1.0
*accelerate 0.29.3
*aiohttp 3.9.5
*aiosignal 1.3.1
*asttokens 2.4.1
*astunparse 1.6.3
*async-timeout 4.0.3
*attrs 23.2.0
*cachetools 5.3.3
*certifi 2024.2.2
*charset-normalizer 3.3.2
*comm 0.2.2
*datasets 2.19.0
*debugpy 1.8.1
*decorator 5.1.1
*dill 0.3.8
*exceptiongroup 1.2.1
*executing 2.0.1
*filelock 3.13.4
*flatbuffers 24.3.25
*frozenlist 1.4.1
*fsspec 2023.9.2
*gast 0.4.0
*google-auth 2.29.0
*google-auth-oauthlib 1.0.0
*google-pasta 0.2.0
*grpcio 1.62.2
*h5py 3.11.0
*huggingface-hub 0.22.2
*idna 3.7
*importlib_metadata 7.1.0
*ipykernel 6.29.4
*ipython 8.18.1
*ipywidgets 8.1.2
*jax 0.4.26
*jedi 0.19.1
*Jinja2 3.1.3
*joblib 1.4.0
*jupyter_client 8.6.1
*jupyter_core 5.7.2
*jupyterlab_widgets 3.0.10
*keras 2.12.0
*libclang 18.1.1
*Markdown 3.68
*MarkupSafe 2.1.5
*matplotlib-inline 0.1.7
*ml-dtypes 0.4.0
*mpmath 1.3.0
*multidict 6.0.5
*multiprocess 0.70.16
*nest-asyncio 1.6.0
*networkx 3.2.1
*numpy 1.24.3
*nvidia-cublas-cu11 11.11.3.6
*nvidia-cublas-cu12 12.1.3.1
*nvidia-cuda-cupti-cu12 12.1.105
*nvidia-cuda-nvrtc-cu12 12.1.105
*nvidia-cuda-runtime-cu12 12.1.105
*nvidia-cudnn-cu11 8.6.0.163
*nvidia-cudnn-cu12 8.9.2.26
*nvidia-cufft-cu12 11.0.2.54
*nvidia-curand-cu12 10.3.2.106
*nvidia-cusolver-cu12 11.4.5.107
*nvidia-cusparse-cu12 12.1.0.106
*nvidia-nccl-cu12 2.20.5
*nvidia-nvjitlink-cu12 12.4.127
*nvidia-nvtx-cu12 12.1.105
*oauthlib 3.2.2
*opt-einsum 3.3.0
*packaging 24.0
*pandas 2.2.2
*parso 0.8.4
*pexpect 4.9.0
*pillow 10.3.0
*pip 24.0
*platformdirs 4.2.1
*prompt-toolkit 3.0.43
*protobuf 4.25.3
*psutil 5.9.8
*ptyprocess 0.7.0
*pure-eval 0.2.2
*pyarrow 16.0.0
*pyarrow-hotfix 0.6
*pyasn1 0.6.0
*pyasn1_modules 0.4.0
*Pygments 2.17.2
*python-dateutil 2.9.0.post0
*pytz 2024.1
*PyYAML 6.0.1
*pyzmq 26.0.2
*regex 2024.4.16
*requests 2.31.0
*requests-oauthlib 2.0.0
*rsa 4.9
*safetensors 0.4.3
*scikit-learn 1.4.2
*scipy 1.13.0
*setuptools 68.2.2
*six 1.16.0
*stack-data 0.6.3
*sympy 1.12
*tensorboard 2.12.3
*tensorboard-data-server 0.7.2
*tensorflow 2.12.1
*tensorflow-estimator 2.12.0
*tensorflow-io-gcs-filesystem 0.36.0
*termcolor 2.4.0
*threadpoolctl 3.4.0
*tokenizers 0.19.1
*torch 2.3.0
*torchvision 0.18.0
*tornado 6.4
*tqdm 4.66.2
*traitlets 5.14.3
*transformers 4.40.1
*triton 2.3.0
*typing_extensions 4.11.0
*tzdata 2024.1
*urllib3 2.2.1
*wcwidth 0.2.13
*Werkzeug 3.0.2
*wheel 0.41.2
*widgetsnbextension 4.0.10
*wrapt 1.14.1
*xxhash 3.4.1
*yarl 1.9.4
*zipp 3.18.1