Vision Transformer: can I put a multi-task classifier head on it for fine-tuning?

I want to see if I can track a tennis ball: https://github.com/nyck33/tennis-tracking-vt/blob/master/tennis_ball_tracking_notebook_colab.ipynb

Like TrackNet did: https://nol.cs.nctu.edu.tw:234/open-source/TrackNet/
I am using their dataset.
But now I’m getting the error you can see under my training loop cell in the notebook.

So I asked ChatGPT about it, and it said:

Errors involving 'BaseModelOutputWithPooling' in the context of multi-task classification with a Vision Transformer are not widely documented.

Here are some general insights that might be helpful:

1. Understanding 'BaseModelOutputWithPooling': This class is part of the Hugging Face Transformers library and is used to store the output of models that have a pooling layer. It contains attributes like last_hidden_state, pooler_output, etc. Understanding the structure of this object and how it interacts with your specific multi-task setup might help in diagnosing the issue.

2. Multi-Task Learning with Transformers: Multi-task learning involves training a model on multiple tasks simultaneously. This can be achieved by having shared layers for common features and task-specific layers for each individual task. You might want to ensure that the architecture is correctly defined for multi-task learning, and that the outputs are being handled appropriately.

3. Customizing the Model: If the standard implementation is not suitable for your specific use case, you might need to extend or modify the existing classes in the Hugging Face library. This could involve writing custom forward methods, handling outputs differently, etc.
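
Following the first point, here is my rough mental model of what the ViT forward pass returns, based on printing it; this is just how I read it and I may well be misreading it. It mirrors my classifier construction below with a fake batch:

import torch
from transformers import ViTConfig, ViTModel

# Same construction as my classifier code below, with a fake batch,
# just to inspect what the forward pass returns
config = ViTConfig.from_pretrained('google/vit-base-patch16-224')
model = ViTModel(config)

dummy = torch.randn(2, 3, 224, 224)      # fake batch of 2 RGB frames
outputs = model(pixel_values=dummy)

print(type(outputs))                     # BaseModelOutputWithPooling
print(outputs.last_hidden_state.shape)   # torch.Size([2, 197, 768]): [CLS] + 14x14 patch tokens
print(outputs.pooler_output.shape)       # torch.Size([2, 768]): tanh-pooled [CLS] vector

So pooler_output looks to me like the natural thing to feed any task heads, but I'm not sure.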
 

My classifier is in the notebook and looks like:

import torch
import torch.nn as nn
from transformers import ViTModel, ViTConfig

# Load the ViT config and build the model from it
config = ViTConfig.from_pretrained('google/vit-base-patch16-224')
model = ViTModel(config)

# Modify the final layer
model.classifier = nn.Sequential(
    nn.Linear(config.hidden_size, 2),  # For the x, y coordinate prediction task
    nn.Linear(config.hidden_size, 3),  # For the event type prediction task
    nn.Softmax(dim=1)
)

# Move the model to the GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

The labels are (x, y) for the ball location, which I believe is a regression task.
The classification task is for [visibility, ball state], where ball state is 0 = flying, 1 = being hit, or 2 = bouncing.
So that's the multi-task setup. Can I do this? I've sketched below what I think the head would have to look like.
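
To make the intent concrete, here is a sketch of the kind of two-headed module I think I need: a shared ViT backbone, a 2-unit regression head for (x, y), and a 3-unit classification head for the ball state. All the names here (MultiTaskViT, coord_head, state_head) are mine, and visibility would presumably get its own small head in the same way:

import torch.nn as nn
from transformers import ViTConfig, ViTModel

class MultiTaskViT(nn.Module):
    # Sketch of my intent: shared ViT backbone, one regression head, one classification head
    def __init__(self):
        super().__init__()
        config = ViTConfig.from_pretrained('google/vit-base-patch16-224')
        self.backbone = ViTModel(config)
        hidden = config.hidden_size                # 768 for ViT-base
        self.coord_head = nn.Linear(hidden, 2)     # (x, y) ball location -> regression
        self.state_head = nn.Linear(hidden, 3)     # 0=flying, 1=being hit, 2=bouncing -> logits

    def forward(self, pixel_values):
        pooled = self.backbone(pixel_values=pixel_values).pooler_output   # (batch, 768)
        return self.coord_head(pooled), self.state_head(pooled)

My guess is the losses would then be MSELoss on the (x, y) output and CrossEntropyLoss on the state logits (no Softmax in the head, since CrossEntropyLoss expects raw logits), summed into a single loss, but that combination is exactly the part I'm unsure about.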
And separately, what does this error mean:

Epochs:   0%|                                                                                        | 0/10 [00:00<?, ?it/s]
Training:   0%|                                                                                     | 0/992 [00:00<?, ?it/s]
Epochs:   0%|                                                                                        | 0/10 [00:10<?, ?it/s]

BaseModelOutputWithPooling(last_hidden_state=tensor([[[ 1.4075e+00,  1.0877e+00,  2.0836e+00,  ..., -4.9522e-01,
           1.0107e+00, -1.4469e-01],
         ...,
         [ 1.7237e+00,  2.1265e+00,  2.0250e+00,  ..., -7.6357e-01,
          -1.3331e-01, -1.1458e+00]]], grad_fn=<NativeLayerNormBackward>), pooler_output=tensor([[-0.7080,  0.7782,  0.4632,  ...,  0.1265, -0.2022, -0.5439],
        ...,
        [ 0.1257,  0.6983,  0.5794,  ...,  0.4954,  0.1278, -0.7507]],
       grad_fn=<TanhBackward>), hidden_states=None, attentions=None)


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[10], line 31
     29 else:
     30     print(outputs)
---> 31 loss = criterion(outputs, labels)
     33 # Backward pass and optimize
     34 loss.backward()

File ~/miniconda3/envs/vision_transformer/lib/python3.8/site-packages/torch/nn/modules/module.py:1051, in Module._call_impl(self, *input, **kwargs)
   1047 # If we don't have any hooks, we want to skip the rest of the logic in
   1048 # this function, and just call forward.
   1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051     return forward_call(*input, **kwargs)
   1052 # Do not call functions when jit is used
   1053 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/vision_transformer/lib/python3.8/site-packages/torch/nn/modules/loss.py:528, in MSELoss.forward(self, input, target)
    527 def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 528     return F.mse_loss(input, target, reduction=self.reduction)

File ~/miniconda3/envs/vision_transformer/lib/python3.8/site-packages/torch/nn/functional.py:3079, in mse_loss(input, target, size_average, reduce, reduction)
   3075 if has_torch_function_variadic(input, target):
   3076     return handle_torch_function(
   3077         mse_loss, (input, target), input, target, size_average=size_average, reduce=reduce, reduction=reduction
   3078     )
-> 3079 if not (target.size() == input.size()):
   3080     warnings.warn(
   3081         "Using a target size ({}) that is different to the input size ({}). "
   3082         "This will likely lead to incorrect results due to broadcasting. "
   3083         "Please ensure they have the same size.".format(target.size(), input.size()),
   3084         stacklevel=2,
   3085     )
   3086 if size_average is not None or reduce is not None:

AttributeError: 'BaseModelOutputWithPooling' object has no attribute 'size'
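
My reading of the traceback, which may be off: criterion here is nn.MSELoss, and I'm handing it the entire BaseModelOutputWithPooling object rather than a tensor, so F.mse_loss fails as soon as it calls .size() on it. If that's right, does my training step need to look something like the snippet below instead, using the MultiTaskViT sketch from above? (images, coord_labels, state_labels and optimizer are placeholder names standing in for what my loop already has.)

import torch.nn as nn

model = MultiTaskViT().to(device)          # sketch class from above
mse = nn.MSELoss()
ce = nn.CrossEntropyLoss()

# images: (batch, 3, 224, 224) float; coord_labels: (batch, 2) float; state_labels: (batch,) long
coord_pred, state_logits = model(images)
loss = mse(coord_pred, coord_labels) + ce(state_logits, state_labels)

loss.backward()
optimizer.step()

Or is the whole approach of bolting a multi-task head onto ViTModel wrong to begin with?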