Hi, I’m a 15-year-old high school student. I’m following this Gradio Pictionary tutorial: Building A Pictionary App.

I tried running the code on Colab and ran into a problem on Step 1. I think the architecture of my `model` and the pretrained checkpoint `pytorch_model.bin` that I’m using do not match up. The `pytorch_model.bin` linked on the page is a Git LFS pointer file, so to work around that I downloaded the actual file from nateraw/quickdraw at main, specifically https://huggingface.co/nateraw/quickdraw/resolve/main/pytorch_model.bin

Can you please let me know where I can get the correct `pytorch_model.bin` file, or how I should change the model architecture in the code below?

Here is the relevant portion of the code, along with the error I’m getting.
```python
!wget https://huggingface.co/nateraw/quickdraw/resolve/main/pytorch_model.bin

import torch
import torch.nn as nn

# LABELS is defined earlier in my notebook, following the tutorial
model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(1152, 256),
    nn.ReLU(),
    nn.Linear(256, len(LABELS)),
)

state_dict = torch.load("pytorch_model.bin", map_location='cpu')
model.load_state_dict(state_dict, strict=False)
```
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-57-192dd46058ba> in <module>()
     18 )
     19 state_dict = torch.load("pytorch_model.bin", map_location='cpu')
---> 20 model.load_state_dict(state_dict, strict=False)
     21 model.eval()

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1603         if len(error_msgs) > 0:
   1604             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1605                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1606         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1607

RuntimeError: Error(s) in loading state_dict for Sequential:
	size mismatch for 12.weight: copying a param with shape torch.Size([100, 256]) from checkpoint, the shape in current model is torch.Size([2, 256]).
	size mismatch for 12.bias: copying a param with shape torch.Size([100]) from checkpoint, the shape in current model is torch.Size([2]).
```
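From the error, the checkpoint’s final layer has shape `[100, 256]` while my `len(LABELS)` is only 2, so I’m guessing the checkpoint was trained on 100 classes. As a sanity check (the 100 placeholder label names below are just my guess, not something from the tutorial), rebuilding the same architecture with a 100-entry label list gives a final layer with the shape the checkpoint expects:

```python
import torch
import torch.nn as nn

# Placeholder: pretend LABELS has 100 entries, matching the checkpoint's
# final-layer shape [100, 256] reported in the error message.
LABELS = [f"class_{i}" for i in range(100)]

model = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(32, 64, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding='same'),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(1152, 256),
    nn.ReLU(),
    nn.Linear(256, len(LABELS)),
)

# Module index 12 is the final Linear layer (the "12.weight" in the error)
print(tuple(model[12].weight.shape))  # prints (100, 256)
```

So it seems my real `LABELS` list is wrong (only 2 entries) rather than the checkpoint. Is there a labels file with the 100 class names that goes with this checkpoint?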