Hmm… Like this?
# PatchTSTForClassification — minimal working example (PyTorch)
# Deps (pin or newer):
# transformers>=4.57.1 # docs: https://huggingface.co/docs/transformers/v4.57.1/model_doc/patchtst
# torch>=2.2 # install info: https://pytorch.org/get-started/locally/
#
# Key API references:
# - Forward args (past_values, target_values) and outputs (prediction_logits):
# https://huggingface.co/docs/transformers/v4.57.1/model_doc/patchtst#transformers.PatchTSTForClassification.forward
# - Config param is patch_stride (not "stride"):
# https://huggingface.co/docs/transformers/v4.57.1/model_doc/patchtst#transformers.PatchTSTConfig
# - Common pitfall: training loss requires target_values:
# https://discuss.huggingface.co/t/valueerror-when-using-patchtstforclassification/76082
# - Source code (class + output keys live here):
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/patchtst/modeling_patchtst.py
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import PatchTSTConfig, PatchTSTForClassification
# ---- Device: CPU / CUDA safe ----
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
# ---- Toy dataset: multivariate TS -> class id ----
class ToyTSCls(Dataset):
    # Returns the exact keys HF expects:
    # {"past_values": float32[B,L,C], "target_values": long[B]}
    def __init__(self, n=1024, seq_len=128, channels=3, n_classes=4, seed=0):
        g = torch.Generator().manual_seed(seed)
        x = torch.randn(n, seq_len, channels, generator=g)
        # Label signal from ch0: mean over last quarter, then bucketize into K bins
        score = x[:, -seq_len // 4 :, 0].mean(dim=1)
        qs = torch.linspace(0, 1, n_classes + 1)[1:-1]
        bins = torch.quantile(score, qs)
        y = torch.bucketize(score, bins)
        self.x = x.float()
        self.y = y.long()
    def __len__(self): return self.x.size(0)
    def __getitem__(self, i): return {"past_values": self.x[i], "target_values": self.y[i]}
# ---- Shapes / task ----
SEQ_LEN = 128 # context_length
C = 3 # num_input_channels
K = 4 # num_targets (classes)
BATCH = 32
# ---- Model config ----
# Use 'patch_stride' per the config docs (there is no 'stride' field).
# With use_cls_token=True the classification head pools via the CLS token;
# pooling_type ('mean' here) only applies as the fallback when use_cls_token is False.
config = PatchTSTConfig(
    num_input_channels=C,
    num_targets=K,
    context_length=SEQ_LEN,
    patch_length=16,
    patch_stride=8,        # correct field per config docs
    use_cls_token=True,
    pooling_type="mean",
)
model = PatchTSTForClassification(config).to(device)
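# Optional sanity check (adds nothing to the printed output below): my reading
# of the patchifier in modeling_patchtst.py is that it yields
# (context_length - patch_length) // patch_stride + 1 patches per channel;
# verify against the source link above if you change patch_length/patch_stride.
num_patches = (SEQ_LEN - config.patch_length) // config.patch_stride + 1
assert num_patches == 15  # with 128 / 16 / 8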
# ---- Data / Optimizer ----
train_ds = ToyTSCls(n=1024, seq_len=SEQ_LEN, channels=C, n_classes=K, seed=0)
train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, drop_last=True)
optim = torch.optim.AdamW(model.parameters(), lr=3e-4)
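# Quick shape check on one batch before training; this catches an accidental
# [B, C, L] layout early (the model expects [batch, sequence_length, channels]).
_sample = next(iter(train_loader))
assert _sample["past_values"].shape == (BATCH, SEQ_LEN, C)
assert _sample["target_values"].shape == (BATCH,)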
# ---- Train (fast on CPU) ----
model.train()
for epoch in range(2):
    total_loss, seen = 0.0, 0
    for batch in train_loader:
        past = batch["past_values"].to(device)        # float32 [B,L,C]
        targets = batch["target_values"].to(device)   # int64 [B]
        # Pass target_values to receive .loss (see forum URL above)
        out = model(past_values=past, target_values=targets)
        loss = out.loss
        optim.zero_grad(set_to_none=True)
        loss.backward()
        optim.step()
        bs = past.size(0)
        total_loss += loss.item() * bs
        seen += bs
    print(f"epoch {epoch} - loss {total_loss / seen:.4f}")
# ---- Inference (no labels) ----
model.eval()
with torch.no_grad():
    x = torch.randn(4, SEQ_LEN, C, device=device)
    out = model(past_values=x)        # returns .prediction_logits
    logits = out.prediction_logits    # float32 [B,K]
    preds = logits.argmax(dim=-1)     # int64 [B]
print("logits shape:", tuple(logits.shape))
print("pred ids:", preds.cpu().tolist())
# Notes:
# - Inputs must be float32; labels are int64 class indices (not one-hot).
# - Ensure context_length == the sequence length you feed at runtime.
# - Tune patch_length/patch_stride for speed vs resolution.
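# - Saving / reloading uses the standard PreTrainedModel API; "patchtst-toy-cls"
#   below is just an example directory name, not anything PatchTST-specific:
#     model.save_pretrained("patchtst-toy-cls")
#     reloaded = PatchTSTForClassification.from_pretrained("patchtst-toy-cls")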
"""
epoch 0 - loss 1.3933
epoch 1 - loss 1.3807
logits shape: (4, 4)
pred ids: [1, 0, 3, 2]
"""