TimeSeriesTransformer - mat1 and mat2 shapes cannot be multiplied

Hi, I'm new to this HF model.

I am using basic, generic stock data with three columns: timestamp, price, and volume. I build a time series dataset that uses the past 20 timestamps to predict the next two. The full code is in the images below; here is the error:
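For reference, the windowing described above (20 past steps in, 2 future steps out) can be sketched in plain Python. This is a minimal sketch, not the actual code from the images: the helper name `make_windows` and the toy price list are hypothetical, and a stride of 1 between windows is assumed.

```python
from typing import List, Tuple


def make_windows(
    series: List[float], past_len: int = 20, future_len: int = 2
) -> List[Tuple[List[float], List[float]]]:
    """Split a 1-D series into (past, future) pairs with a stride of 1.

    Hypothetical helper: each pair uses `past_len` consecutive values as the
    input window and the next `future_len` values as the prediction target.
    """
    windows = []
    total = past_len + future_len
    # Stop early enough that every window has a full future target.
    for start in range(len(series) - total + 1):
        past = series[start:start + past_len]
        future = series[start + past_len:start + total]
        windows.append((past, future))
    return windows


prices = [float(i) for i in range(25)]  # toy stand-in for the price column
pairs = make_windows(prices)
print(len(pairs))     # 4
print(pairs[0][1])    # [20.0, 21.0]
```

Stacking the `past` lists of a batch of 128 such windows gives a `[128, 20]` tensor, which is the `past_values` shape reported in this post.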

File ~/opt/miniforge3/envs/temp/lib/python3.11/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x22 and 20x64)

I suspect it is related to context_length and lags_sequence: if I change those values, the same error appears with different inner dimensions that always differ by 2 (e.g. 19 and 17 instead of the current 22 and 20).
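One way to test this suspicion is to work out the input width that the model's first Linear layer expects. The sketch below encodes my understanding (an assumption, to be checked against the installed transformers source) of how `TimeSeriesTransformerConfig` derives `feature_size`: `input_size * len(lags_sequence)`, plus embedding, dynamic, time, and static features, plus two loc/scale features per target variable. The function name and the 18-lag example are hypothetical; the lag count is picked only so the numbers line up with the error, since the real config isn't shown here.

```python
from typing import Sequence


def expected_feature_size(
    input_size: int,
    lags_sequence: Sequence[int],
    num_time_features: int = 0,
    num_dynamic_real_features: int = 0,
    num_static_real_features: int = 0,
    embedding_dimension: Sequence[int] = (),
) -> int:
    """Approximate TimeSeriesTransformerConfig.feature_size (assumed formula)."""
    # One lagged copy of each target variable per entry in lags_sequence.
    lagged = input_size * len(lags_sequence)
    # Extra per-time-step features; the last term is assumed to be the
    # loc and scale statistics appended for each target variable.
    extra = (
        sum(embedding_dimension)
        + num_dynamic_real_features
        + num_time_features
        + num_static_real_features
        + input_size * 2
    )
    return lagged + extra


lags = list(range(1, 19))  # hypothetical: 18 lags
configured = expected_feature_size(1, lags, num_time_features=0)
actual = expected_feature_size(1, lags, num_time_features=2)
print(configured, actual)  # 20 22
```

Under this assumed formula, a constant gap of 2 would be consistent with the 2 time features being passed in the batch but not counted in the config's `num_time_features`; that is a hypothesis, not a confirmed diagnosis.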

I am passing in tensors with the following shapes:

past_values:          torch.Size([128, 20])
past_time_features:   torch.Size([128, 20, 2])
future_values:        torch.Size([128, 2])
future_time_features: torch.Size([128, 2, 2])

which I believe is correct per the documentation for a univariate time series. This model would benefit greatly from a Medium article or more documentation on how to use it. As for today's issue, I don't believe the code below is incorrect.
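A quick way to compare these shapes against what the docs describe is a small checker. This is a sketch under stated assumptions: it takes the documented rule that the past window must be `context_length + max(lags_sequence)` steps long, and the example values (`context_length=20`, the lag list `[1, ..., 7]`, `num_time_features=2`) are hypothetical stand-ins for the real config. The helper name `expected_shapes` is also hypothetical.

```python
from typing import Dict, Sequence, Tuple


def expected_shapes(
    batch_size: int,
    context_length: int,
    prediction_length: int,
    lags_sequence: Sequence[int],
    num_time_features: int,
) -> Dict[str, Tuple[int, ...]]:
    """Shapes a univariate TimeSeriesTransformer batch should have (assumed)."""
    # The past window must also cover the largest lag, not just the context.
    past_len = context_length + max(lags_sequence)
    return {
        "past_values": (batch_size, past_len),
        "past_time_features": (batch_size, past_len, num_time_features),
        "future_values": (batch_size, prediction_length),
        "future_time_features": (batch_size, prediction_length, num_time_features),
    }


# Shapes reported in this post.
given = {
    "past_values": (128, 20),
    "past_time_features": (128, 20, 2),
    "future_values": (128, 2),
    "future_time_features": (128, 2, 2),
}
want = expected_shapes(128, context_length=20, prediction_length=2,
                       lags_sequence=[1, 2, 3, 4, 5, 6, 7], num_time_features=2)
for name, shape in given.items():
    status = "ok" if shape == want[name] else f"expected {want[name]}"
    print(f"{name}: {shape} -> {status}")
```

With these hypothetical settings the future tensors match, while the past tensors would need 27 steps rather than 20, because the lags reach back 7 steps beyond the context window.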

predictions = model(
    past_values=batch['past_values'],
    past_time_features=batch['past_time_features'],
    past_observed_mask=None,
    future_values=batch['future_values'],
    future_time_features=batch['future_time_features'],
)

I'll leave the rest of the photos in the comments.

Much love!

Environment:

  • transformers version: 4.35.2
  • Platform: macOS-13.5.1-arm64-arm-64bit
  • Python version: 3.11.5
  • Huggingface_hub version: 0.19.4
  • Safetensors version: 0.4.0
  • Accelerate version: not installed
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.0 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Further notes:

Other transformer models with similar matrix issues: