TLDR: what is the correct way to extract features from a HF ViT model: outputs.pooler_output or outputs.last_hidden_state[:, 0], where outputs: BaseModelOutputWithPooling = self.model(pixel_values=batch_xs)?
Given only the ViT model, it's not clear what one should do to solve a vision classification problem. I eventually converged to this answer (but I am no longer sure it is correct or the best; I will provide the full code at the end):
outputs: BaseModelOutputWithPooling = self.model(pixel_values=batch_xs)
output: Tensor = self.dropout(outputs.last_hidden_state[:, 0])
logits: Tensor = self.cls(output)
Intuitively it makes sense: we want to extract the features at the cls token position. However, once I printed all the layers of the ViTModel, I would personally have chosen a different layer, because it is right before the cls layer AND because the printed activations seem to be in a better range, to be honest. I would have chosen the outputs of the (pooler): ViTPooler(...)
layer, right after the Tanh()
. Doing that results in this:
outputs: BaseModelOutputWithPooling = self.model(pixel_values=batch_xs)
outputs.pooler_output
tensor([[-0.3976, -0.8454, -0.0601, ..., -0.2804, -0.1822, 0.1917],
[-0.3392, -0.0248, 0.1346, ..., -0.5822, 0.8779, 0.4147],
[-0.2980, -0.8038, -0.1146, ..., 0.2431, -0.0963, 0.7844],
...,
[-0.1237, -0.7514, 0.7388, ..., -0.8551, 0.1512, 0.6157],
[ 0.5351, -0.9040, 0.0387, ..., -0.0773, 0.2704, -0.0311],
[ 0.2142, -0.3138, 0.0426, ..., -0.5943, 0.2873, 0.4420]],
grad_fn=<TanhBackward>)
outputs.last_hidden_state[:, 0]
tensor([[ 5.7313e-01, -2.1335e+00, 2.0491e-01, ..., -1.2373e-01,
-2.0056e-01, -4.8167e-01],
[ 5.3309e-02, -1.6563e+00, 1.5719e+00, ..., -1.3617e+00,
-3.0064e-01, -2.0056e-01],
[-2.0633e-02, -2.1370e+00, 9.9927e-01, ..., -2.3584e+00,
8.6123e-01, -1.2759e+00],
...,
[ 3.9583e-01, -1.3500e+00, 1.7638e+00, ..., -9.9536e-01,
1.0843e+00, -4.4368e-01],
[ 1.6026e+00, -6.4654e-01, 2.4882e+00, ..., -1.0347e+00,
-1.3160e-03, -2.4357e+00],
[-1.2769e-02, -9.6574e-01, 1.6432e+00, ..., -7.9090e-01,
6.1669e-01, 3.2990e-01]], grad_fn=<SelectBackward>)
and the sums as sanity checks:
outputs.pooler_output.sum()
tensor(3.8430, grad_fn=<SumBackward0>)
outputs.last_hidden_state[:, 0].sum()
tensor(-6.4373e-06, grad_fn=<SumBackward0>)
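Actually, the near-zero sum might not be weird at all: if ViTModel layer-norms the final hidden states before returning them (which I believe it does, based on skimming the source, though that is my assumption), each cls vector is zero-mean, so the sum collapses. A quick standalone sketch of that effect (not the HF code itself):
import torch
from torch import nn

# LayerNorm normalizes each 768-dim vector to zero mean (and unit variance),
# so summing over the hidden dim (and then the batch) gives ~0,
# just like the -6.4373e-06 above.
ln = nn.LayerNorm(768)
v = torch.randn(25, 768) * 3 + 1  # arbitrary scale/shift
print(ln(v).sum())  # prints a value near 0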
and the shapes:
outputs.pooler_output.shape
torch.Size([25, 768])
outputs.last_hidden_state[:, 0].shape
torch.Size([25, 768])
Both have the same shape, but the outputs.pooler_output values look much better behaved (bounded by the Tanh). Yet my forward pass uses outputs.last_hidden_state[:, 0]
for some reason.
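For context, my reading of the transformers source (modeling_vit.py) is that the pooler is just a Linear + Tanh applied to the cls token, so the two candidates are related roughly like this (a minimal sketch of my understanding, not the library code verbatim):
import torch
from torch import nn

hidden_size = 768
dense = nn.Linear(hidden_size, hidden_size)  # ViTPooler's dense layer
tanh = nn.Tanh()                             # ViTPooler's activation

last_hidden_state = torch.randn(25, 197, hidden_size)  # (batch, 1 cls + 196 patches, hidden)
cls_token = last_hidden_state[:, 0]      # option 1: what my forward pass uses
pooler_output = tanh(dense(cls_token))   # option 2: what outputs.pooler_output returns
If that is right, pooler_output is just an extra learned projection of the cls token squashed into [-1, 1], which would explain the nicer-looking range.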
Which one should I have used?
Full code:
from typing import Callable, Optional

import torch
from torch import nn, Tensor
from transformers import ViTConfig, ViTModel
from transformers.modeling_outputs import BaseModelOutputWithPooling


class ViTForImageClassificationUU(nn.Module):
    def __init__(self,
                 num_classes: int,
                 image_size: int,  # 224 inet, 32 cifar, 84 mi, 28 mnist, omni...
                 criterion: Optional[Callable] = None,  # Note: the USL agent usually holds the criterion, not the model, e.g. a loss like nn.CrossEntropyLoss()
                 cls_p_dropout: float = 0.0,
                 pretrained_name: str = None,
                 vitconfig: ViTConfig = None,
                 ):
        """
        :param num_classes:
        :param pretrained_name: e.g. 'google/vit-base-patch16-224-in21k'  # what is the diff with "google/vit-base-patch16-224"?
        """
        super().__init__()
        if vitconfig is not None:
            raise NotImplementedError  # a given config would make every other param be ignored
            # self.vitconfig = vitconfig
        elif pretrained_name is not None:
            raise NotImplementedError  # a given pretrained name would be ignored if a vitconfig were also given
            # self.model = ViTModel.from_pretrained(pretrained_name)  # e.g. 'google/vit-base-patch16-224-in21k'
        else:
            self.num_classes = num_classes
            self.image_size = image_size
            self.vitconfig = ViTConfig(image_size=self.image_size)
            self.model = ViTModel(self.vitconfig)
        assert cls_p_dropout == 0.0, 'Error: for now the dropout p for the cls head must be zero, until we ' \
                                     'figure out whether all the other dropout layers need changing too.'
        self.dropout = nn.Dropout(cls_p_dropout)
        self.cls = nn.Linear(self.model.config.hidden_size, num_classes)
        self.criterion = criterion

    def forward(self, batch_xs: Tensor, labels: Tensor = None) -> Tensor:
        """
        Forward pass of the ViT. I added the "missing" cls head (and a dropout layer before it) acting on the
        embedding at the first (cls) token position. The remaining token embeddings are ignored/not used.
        I think the feature extractor only normalizes the data for you and doesn't even seem to make the input
        into a sequence, see:
        ...
        so idk why it's needed, but examples using it can be found here:
        - colab https://colab.research.google.com/drive/1Z1lbR_oTSaeodv9tTm11uEhOjhkUx1L4?usp=sharing#scrollTo=cGDrb1Q4ToLN
        - blog with trainer https://huggingface.co/blog/fine-tune-vit
        - single PIL notebook https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Quick_demo_of_HuggingFace_version_of_Vision_Transformer_inference.ipynb
        """
        outputs: BaseModelOutputWithPooling = self.model(pixel_values=batch_xs)
        output: Tensor = self.dropout(outputs.last_hidden_state[:, 0])  # features at the cls token position
        logits: Tensor = self.cls(output)
        if labels is None:
            assert logits.dtype == torch.float32
            return logits  # this is what my usl agent does ;)
        else:
            raise NotImplementedError  # computing the loss inside the model is not supported yet
            # assert labels.dtype == torch.long
            # loss = self.criterion(logits, labels)  # or: self.criterion(logits.view(-1, self.num_classes), labels.view(-1))
            # return loss, logits

    def get_embedding(self, batch_xs: Tensor) -> Tensor:
        """
        Get the feature embedding at the first (cls) token.

        Details: by inspecting the printed layers, the (pooler): ViTPooler(...) module ends in a Tanh()
        activation. From playing around, outputs.pooler_output has grad_fn=<TanhBackward>, so it seems to be
        the right one. Plus, printing
            outputs.pooler_output.sum()
            tensor(3.8430, grad_fn=<SumBackward0>)
        looks more sensible than getting the features at the cls position manually:
            outputs.last_hidden_state[:, 0, :].sum()
            tensor(-6.4373e-06, grad_fn=<SumBackward0>)
        which looked weird.
        """
        outputs: BaseModelOutputWithPooling = self.model(pixel_values=batch_xs)
        feat = outputs.pooler_output
        # Alternative: get the cls token's features (position 0) manually:
        # feat = outputs.last_hidden_state[:, 0, :]
        return feat

    def _assert_its_random_model(self):
        from uutils.torch_uu import norm
        pre_trained_model = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        print(f'----> {norm(pre_trained_model)=}')
        print(f'----> {norm(self)=}')
        assert norm(pre_trained_model) > norm(self), f'Random models usually have a smaller weight norm, but got ' \
                                                     f'{norm(pre_trained_model)=} vs {norm(self)=}'


def get_vit_get_vit_model_and_model_hps(vitconfig: ViTConfig = None,
                                        num_classes: int = 5,
                                        image_size: int = 84,  # 224 inet, 32 cifar, 84 mi, 28 mnist, omni...
                                        criterion: Optional[Callable] = None,  # for me, the agent holds it
                                        cls_p_dropout: float = 0.0,
                                        pretrained_name: str = None,
                                        ) -> tuple[nn.Module, dict]:
    """Get a ViT for mini-imagenet; only num_classes=5 and image_size=84 are needed."""
    model_hps: dict = dict(vitconfig=vitconfig,
                           num_classes=num_classes,
                           image_size=image_size,
                           criterion=criterion,
                           cls_p_dropout=cls_p_dropout,
                           pretrained_name=pretrained_name)
    model: nn.Module = ViTForImageClassificationUU(**model_hps)
    print("It's recommended to set args.allow_unused = True for ViT models.")
    return model, model_hps


def vit_forward_pass():
    # - for determinism
    import random
    import numpy as np
    random.seed(0)
    torch.manual_seed(0)
    np.random.seed(0)

    # - options for number of tasks/meta-batch size
    device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")

    # - get my vit model (passing a vitconfig raises NotImplementedError, so let the model build its own config)
    # model = ViTForImageClassificationUU(num_classes=64 + 1100, image_size=84)
    model, model_hps = get_vit_get_vit_model_and_model_hps(num_classes=64 + 1100, image_size=84)
    criterion = nn.CrossEntropyLoss()

    # - to device
    model.to(device)
    criterion.to(device)

    # - forward pass (move the data to the same device as the model)
    x = torch.rand(5, 3, 84, 84).to(device)
    y = torch.randint(0, 64 + 1100, (5,)).to(device)
    logits = model(x)
    loss = criterion(logits, y)
    print(f'{loss=}')
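To run the sanity check above as a script, I use a minimal entry point (my assumption for how it is invoked):


if __name__ == '__main__':
    vit_forward_pass()
    print('Done!')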