Correct way to get ResNet features from a pre-trained model?

Hi all,

I tried to load a pre-trained ResNetModel; however, I'm getting the following weird warning:

Some weights of the model checkpoint at microsoft/resnet-50 were not used when initializing ResNetModel: ['classifier.1.bias', 'classifier.1.weight']

  • This IS expected if you are initializing ResNetModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
  • This IS NOT expected if you are initializing ResNetModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

My code for loading the model is:
model = ResNetModel.from_pretrained("microsoft/resnet-50", force_download=True)

And I wanted to use the following code to extract pre-trained features:
model(images)["pooler_output"]  # images is a tensor of images of the correct shape

Is my code incorrect in some way?
From reading related questions I worry that somehow not all parts of the model were loaded correctly.

Thanks,
Tom

Hi,

If you’re doing:

from transformers import ResNetModel

model = ResNetModel.from_pretrained("microsoft/resnet-50")

Then it will load the base ResNet model without any head on top. Hence the warning: “Some weights of the model checkpoint at microsoft/resnet-50 were not used when initializing ResNetModel: ['classifier.1.bias', 'classifier.1.weight']”.

To load the classification head on top as well, you need to instantiate a ResNetForImageClassification, which adds a head on top of ResNetModel. In that case, there won’t be a warning.
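
For reference, that would look roughly like this (same checkpoint; the head predicts the 1000 ImageNet-1k classes):

from transformers import ResNetForImageClassification
import torch

# loads the checkpoint together with its classification head,
# so all weights are used and no warning is printed
model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(pixel_values).logits
print(logits.shape)  # (1, 1000), one score per ImageNet class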

However, as you want to use ResNet to get features (rather than performing image classification), it makes sense to load ResNetModel (as you don’t need the head). To get features, you can indeed do the following:

from transformers import ResNetModel
import torch

model = ResNetModel.from_pretrained("microsoft/resnet-50")

pixel_values = torch.randn(1, 3, 224, 224)

outputs = model(pixel_values)
features = outputs.pooler_output.squeeze()

This will give you a feature vector (in this case of length 2048) that you can use as a “representation” of the image. Note that in computer vision, one typically pulls features from different stages of the backbone to get multi-scale features. This can be achieved by passing output_hidden_states=True, which returns the intermediate hidden states (feature maps):

from transformers import ResNetModel
import torch

model = ResNetModel.from_pretrained("microsoft/resnet-50")

pixel_values = torch.randn(1, 3, 224, 224)

outputs = model(pixel_values, output_hidden_states=True)

# one typically skips the initial embeddings
for feature_map in outputs.hidden_states[1:]:
    print(feature_map.shape)

Great, thank you @nielsr!

A follow-up question: indeed, I'm trying to learn representations, so for now I only use ResNetModel.
In your example above you assumed that the image values are already normalized, but in my case I start from an RGB image. From my understanding, the correct way to use the model would be with a feature extractor:

feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
image = ...  # RGB image
resnet_input = [feature_extractor(image).data["pixel_values"][0]]
representation_vec = model(resnet_input).pooler_output.squeeze()

The problem is that representation_vec keeps changing because feature_extractor applies random cropping (in my scenario I have a 224x224 image, so I want a deterministic representation vector).

I tried disabling cropping by setting
feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50", do_resize=False)

But then I get inf values in representation_vec (before that the vector was inconsistent, but at least the values were not inf).

Any idea on how to make this part work?

Thanks,
Tom

Hi, normally the outputs should be deterministic, as ResNet uses the same preprocessing pipeline as ConvNeXt. This is just resizing with a certain crop percentage plus normalization; no random cropping is involved.
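
For reference, here's a minimal sketch of passing the preprocessed image to the model as a batched tensor (the file name example.jpg is just a placeholder); running it twice on the same image should produce identical feature vectors:

from transformers import AutoFeatureExtractor, ResNetModel
from PIL import Image
import torch

feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-50")
model = ResNetModel.from_pretrained("microsoft/resnet-50")
model.eval()  # batch norm uses running statistics, so outputs are deterministic

image = Image.open("example.jpg").convert("RGB")  # placeholder path

# return_tensors="pt" yields a batched tensor of shape (1, 3, 224, 224)
inputs = feature_extractor(image, return_tensors="pt")

with torch.no_grad():
    first = model(**inputs).pooler_output.squeeze()
    second = model(**feature_extractor(image, return_tensors="pt")).pooler_output.squeeze()

print(torch.equal(first, second))  # expected: True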