Mistral - Sentence classification - mat1 and mat2 shapes cannot be multiplied

Hi all,

I already have a CamemBERT model running for a sentence classification task (mail classification), and I would like to compare its results with those of the Mistral 7B model. Since this model is quite different from BERT, it is relatively hard to reuse the “code template” I wrote for BERT.

Hence, I am first trying to get a grasp of some tutorials, such as the following. However, it takes 10 hours to run, and there is no `model.cuda()` call anywhere in the code. When I add one, I get the following error:

/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py:429: UserWarning: torch.utils.checkpoint: please pass in use_reentrant=True or use_reentrant=False explicitly. The default value of use_reentrant will be updated to be False in the future. To maintain current behavior, pass use_reentrant=True. It is recommended that you use use_reentrant=False. Refer to docs for more details on the differences between the two variants.
  warnings.warn(
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-14-4c99aa47a15f> in <cell line: 10>()
      8 )
      9 
---> 10 trainer.train()

42 frames
/usr/local/lib/python3.10/dist-packages/bitsandbytes/autograd/_functions.py in forward(ctx, A, B, out, bias, quant_state)
    514         # 1. Dequantize
    515         # 2. MatmulnN
--> 516         output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
    517 
    518         # 3. Save state

RuntimeError: mat1 and mat2 shapes cannot be multiplied (8192x4096 and 1x8388608)
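
A quick check of the shapes in the message. The left operand is presumably the flattened activations (8192 rows, i.e. batch size × sequence length, by 4096, the hidden size of Mistral 7B). The right operand has 8,388,608 = 4096 × 4096 / 2 elements, which is exactly the size of a 4096 × 4096 weight stored as packed 4-bit values (two per byte). That suggests the layer receives the weight still in its packed form rather than as a dequantized 4096 × 4096 matrix. The snippet below only reproduces this arithmetic in plain Python; the hidden size and the packing interpretation are assumptions, not something the traceback states.

```python
# Shapes copied from the RuntimeError above.
mat1_shape = (8192, 4096)   # presumably (batch * seq_len) x hidden_size
mat2_shape = (1, 8388608)   # right-hand operand reported in the error


def can_matmul(a, b):
    """True if an (m x k) matrix can be multiplied by a (k x n) matrix."""
    return a[1] == b[0]


hidden_size = 4096  # assumption: hidden size of Mistral 7B
# A 4096 x 4096 weight quantized to 4 bits packs two values per byte.
packed_4bit_bytes = hidden_size * hidden_size // 2

print(can_matmul(mat1_shape, mat2_shape))                  # False
print(packed_4bit_bytes == mat2_shape[1])                  # True
print(can_matmul(mat1_shape, (hidden_size, hidden_size)))  # True
```

In other words, the multiplication would succeed with a proper 4096 × 4096 weight; it fails only because the second operand still has the packed layout.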

Could anyone help me with this issue? I would really appreciate it.

Note that the same error also occurs with other tutorials and with code snippets I wrote myself.
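
Since the error appears once `model.cuda()` is added, one pattern worth trying is to let `from_pretrained` place the 4-bit model on the GPU through `device_map` and never move it afterwards. The sketch below assumes the `mistralai/Mistral-7B-v0.1` checkpoint, an NF4 `BitsAndBytesConfig` with float16 compute, and a placeholder label count; `load_mistral_classifier` and `move_to_device_if_safe` are hypothetical helper names, and this is not confirmed to resolve the error on this exact setup. The guard relies on the `is_loaded_in_4bit` / `is_loaded_in_8bit` attributes that transformers sets on bitsandbytes-quantized models.

```python
def load_mistral_classifier(model_id="mistralai/Mistral-7B-v0.1", num_labels=2):
    """Load a 4-bit Mistral classifier, letting device_map handle GPU placement.

    model_id and num_labels are placeholders (assumptions); adjust to your setup.
    """
    # Imported lazily so that defining this helper does not require the GPU stack.
    import torch
    from transformers import (
        AutoModelForSequenceClassification,
        AutoTokenizer,
        BitsAndBytesConfig,
    )

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Mistral's tokenizer has no pad token by default; reuse EOS so batches can be padded.
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        num_labels=num_labels,
        quantization_config=bnb_config,
        device_map="auto",  # weights land on the GPU here; no model.cuda() afterwards
    )
    # The classification head needs to know which token id marks padding.
    model.config.pad_token_id = tokenizer.pad_token_id
    return tokenizer, model


def move_to_device_if_safe(model, device="cuda"):
    """Move `model` to `device`, but leave bitsandbytes-quantized models untouched."""
    quantized = getattr(model, "is_loaded_in_4bit", False) or getattr(
        model, "is_loaded_in_8bit", False
    )
    if quantized:
        # Already placed by device_map at load time; skip the extra move.
        return model
    return model.to(device)
```

Training a 4-bit model with `Trainer` additionally requires trainable adapters (for example LoRA through `peft`, as the tutorials do); that step is left out of this sketch.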

The code runs on GC with the following environment:

- `transformers` version: 4.38.0.dev0
- Platform: Linux-6.1.58+-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.20.3
- Safetensors version: 0.4.2
- Accelerate version: 0.26.1
- Accelerate config: 	not found
- PyTorch version (GPU?): 2.1.0+cu121 (True)
- Tensorflow version (GPU?): 2.15.0 (True)
- Flax version (CPU?/GPU?/TPU?): 0.8.0 (cpu)
- Jax version: 0.4.23
- JaxLib version: 0.4.23
- Using GPU in script?: YES
- Using distributed or parallel set-up in script?: NO
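
The list above does not include `bitsandbytes`, even though the failing call sits in `bitsandbytes/autograd/_functions.py`, so its version is likely relevant too. A minimal stdlib sketch for collecting it; `peft` and `trl` are included only as an assumption about the libraries the tutorials use, and `package_versions` is a hypothetical helper name.

```python
from importlib.metadata import PackageNotFoundError, version


def package_versions(names):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# bitsandbytes appears in the traceback; peft and trl are assumptions.
for name, ver in package_versions(["bitsandbytes", "peft", "trl"]).items():
    print(f"- {name} version: {ver or 'not installed'}")
```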

Thanks in advance for your answers!