PatchTSMixerForPrediction error with prediction of len 1

`PatchTSMixerForPrediction` with `loss='nll'` and `prediction_length=1` raises an error because of this code:

class StudentTOutput(DistributionOutput):
    """
    Student-T distribution output class.
    """

    # Each of the three distribution arguments is declared with dimension 1.
    args_dim: Dict[str, int] = {"df": 1, "loc": 1, "scale": 1}
    distribution_class: type = StudentT

    @classmethod
    def domain_map(cls, df: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor):
        """Transform the raw ``df``, ``loc`` and ``scale`` tensors before use.

        ``scale`` is passed through ``cls.squareplus`` and clamped from below
        at the machine epsilon of its dtype. ``df`` becomes
        ``2.0 + cls.squareplus(df)``. ``loc`` is passed through unchanged.
        ``squareplus`` is defined on the base class (not shown here);
        presumably it maps onto positive values -- confirm.

        Returns:
            The tuple ``(df, loc, scale)``, each with ``squeeze(-1)`` applied.
        """
        scale = cls.squareplus(scale).clamp_min(torch.finfo(scale.dtype).eps)
        df = 2.0 + cls.squareplus(df)
        # squeeze(-1) removes the last axis only when that axis has size 1.
        # NOTE(review): prediction_length=1 is reported to fail with this
        # squeeze; presumably the trailing axis is then the prediction axis and
        # is dropped unintentionally -- confirm against the calling model head.
        return df.squeeze(-1), loc.squeeze(-1), scale.squeeze(-1)

I don't know why the squeeze was added. Without it, training works correctly for prediction_length >= 1, but with it I get an error when prediction_length = 1.
Can you help me understand it, or do I need to use regression for prediction_length = 1?

So the squeeze is added to initialize a [Batch, prediction_length, 1] univariate distribution, I believe.

So the failure with prediction_length = 1 is possibly an edge case that needs debugging…

Can you cook up a small example that shows the issue? Thanks!

import math
import torch
import transformers

import lightning as L
import pandas as pd
import numpy as np

from torch.utils.data import Dataset, DataLoader
from transformers import PatchTSMixerConfig, PatchTSMixerForPrediction

# Toy frame: the integers 0..59 laid out row-major as 30 rows x 2 columns.
test_df = pd.DataFrame(np.arange(60).reshape(30, 2), columns=['labels', 'features'])

class PatchTSTDataset(Dataset):
    """Fixed-size, non-overlapping windows cut from a dataframe.

    The dataframe is split into consecutive chunks of
    ``input_width + label_width`` rows. Within each chunk the first
    ``input_width`` rows become ``past_values`` and the remaining
    ``label_width`` rows become ``future_values``. Trailing rows that do not
    fill a whole chunk are dropped.

    Args:
        df: dataframe from which the windows are taken.
        input_width: number of time steps in the past-data window.
        label_width: number of time steps to predict.
        past_columns: columns that supply both ``past_values`` and
            ``future_values``.
        past_observed_mask: optional column names whose values mark which past
            values were observed (1) and which were missing (0). Interpreted as
            column names following the ``list[str]`` annotation. When ``None``
            or empty, every past value is marked as observed.
        label_columns: target columns; only used to derive
            ``prediction_channel_indices``.
    """

    def __init__(self,
                 df: pd.DataFrame,
                 input_width: int,
                 label_width: int,
                 past_columns: list[str],
                 past_observed_mask: list[str],
                 label_columns: list[str]):
        self.df = df
        self.input_width = input_width
        self.label_width = label_width

        self.past_columns = past_columns
        self.past_observed_mask = past_observed_mask
        # Goes through the property setter, which needs ``self.df`` to be set.
        self.label_columns = label_columns

        self.total_window_size = input_width + label_width
        self.input_slice = slice(0, input_width)
        self.label_start = self.total_window_size - self.label_width
        self.labels_slice = slice(self.label_start, None)
        # Not read anywhere in this class; kept for external users.
        self.decoder_slice = slice(self.label_start - 1, self.total_window_size - 1)
        self.__prepare_sequences()

    def __len__(self):
        """Return the number of complete windows."""
        return len(self.sequences)

    def __getitem__(self, index):
        """Return one window as a dict of float32 tensors.

        Keys are ``'past_values'``, ``'observed_mask'`` and ``'future_values'``.
        """
        window = self.df.loc[self.sequences[index], :]
        values = window[self.past_columns].to_numpy(dtype=np.float32)
        past_values = torch.from_numpy(values[self.input_slice])

        mask_columns = self.past_observed_mask
        if mask_columns is None or len(mask_columns) == 0:
            # No mask columns given: mark every past value as observed.
            observed_mask = torch.ones_like(past_values)
        else:
            # Read the mask from the dataframe instead of returning the raw
            # list of column names, which is not a usable mask.
            observed_mask = torch.from_numpy(
                window[mask_columns].to_numpy(dtype=np.float32)[self.input_slice]
            )

        return {
            'past_values': past_values,
            'observed_mask': observed_mask,
            'future_values': torch.from_numpy(values[self.labels_slice]),
        }

    def __prepare_sequences(self, mode='overlap'):
        """Build ``self.sequences``, one index slice per complete window.

        ``mode`` is accepted but not used: windows are always consecutive and
        non-overlapping.
        """
        self.sequences = [
            self.df.index[index * self.total_window_size: (index + 1) * self.total_window_size]
            for index in range(len(self.df) // self.total_window_size)
        ]

    @property
    def label_columns(self):
        """Target column names, stored as a list."""
        return self._label_columns

    @label_columns.setter
    def label_columns(self, label_columns):
        self._label_columns = list(label_columns)
        # Positions of the label columns within ``self.df.columns``.
        self.prediction_channel_indices = [int(i) for i in self.df.columns.get_indexer(label_columns)]

INPUT_WIDTH = 5
# A single-step horizon: the prediction_length=1 case this example is meant to exercise.
OUTPUT_WIDTH = 1
batch_size = 4

ds = PatchTSTDataset(
    df=test_df,
    input_width=INPUT_WIDTH,
    label_width=OUTPUT_WIDTH,
    past_columns=test_df.columns,
    past_observed_mask=None,
    label_columns=['labels', 'features'],
)

# Deliberately tiny model so the example runs quickly.
model_config = PatchTSMixerConfig(
    num_input_channels=len(ds.past_columns),
    context_length=ds.input_width,
    loss='nll',
    patch_len=3,
    patch_stride=1,
    prediction_length=ds.label_width,
    d_model=2,
    dropout=0.0,
    expansion_factor=1,
    num_layers=1,
    scaling=None,
    distribution_output="normal",
)

model = PatchTSMixerForPrediction(
    config=model_config
)

train_loader = DataLoader(ds, batch_size=batch_size)
first_batch = next(iter(train_loader))

# Forward pass on one batch; expected to reproduce the reported
# prediction_length=1 error.
model(**first_batch)