Plan
- Parse CLI args (paths, shapes, channels, lr, epochs, etc.).
- Dataset:
  - If --data exists: load .npz files with X: (T,H,W,C_in) and Y: (H,W,C_out); see the sketch after this list.
  - Else: synthesize toy data with a known mapping (for sanity checks).
  - Return tensors shaped (C_in,T,H,W) and (C_out,H,W).
- Model:
  - Encoder (2D convs): encode each frame.
  - ConvLSTM: aggregate features across time.
  - Decoder: project to C_out output channels.
- Training:
  - AdamW + cosine LR schedule, optional mixed precision, gradient clipping; save the best checkpoint.
- Eval/infer:
  - Quick PSNR/MAE; visualize grid (optional save).
- Edge cases:
  - Variable sequence lengths; mixed-precision off switch; gradient clipping.
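A compatible .npz sample could be written like this (a minimal sketch; the shapes match the CLI defaults below and the file path is illustrative):

    import numpy as np
    T, H, W, C_in, C_out = 8, 64, 64, 18, 3
    X = np.random.rand(T, H, W, C_in).astype(np.float32)  # per-frame features
    Y = np.random.rand(H, W, C_out).astype(np.float32)    # target image/features
    np.savez("npzs/sample_0000.npz", X=X, Y=Y)             # illustrative path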
# file: src/video2image_main.py
import argparse
import math
import os
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
# ---- Utils ------------------------------------------------------------------
def seed_everything(seed: int) -> None:
# Ensures reproducibility across loaders/devices
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True; torch.backends.cudnn.benchmark = False
def psnr(x: torch.Tensor, y: torch.Tensor, data_range: float = 1.0) -> float:
mse = F.mse_loss(x, y).item()
return 99.0 if mse == 0 else 20.0 * math.log10(data_range) - 10.0 * math.log10(mse)
# ---- Data -------------------------------------------------------------------
class SequenceToImageNPZ(Dataset):
"""
Expects a directory of .npz files with keys:
- 'X': float32 array (T, H, W, C_in) # per-frame features
- 'Y': float32 array (H, W, C_out) # target image/features
Converts to tensors shaped (C_in, T, H, W) and (C_out, H, W).
"""
def __init__(self, root: Path):
self.files = sorted([Path(root) / f for f in os.listdir(root) if f.endswith(".npz")])
if not self.files:
raise FileNotFoundError(f"No .npz files found in {root}")
def __len__(self) -> int:
return len(self.files)
def __getitem__(self, idx: int):
d = np.load(self.files[idx])
X = d["X"].astype(np.float32) # (T,H,W,C)
Y = d["Y"].astype(np.float32) # (H,W,C_out)
        X = np.moveaxis(X, -1, 0)  # (T,H,W,C_in) -> (C_in,T,H,W)
Y = np.moveaxis(Y, -1, 0) # (C_out,H,W)
return torch.from_numpy(X), torch.from_numpy(Y)
class SyntheticSequenceToImage(Dataset):
"""
Generates sequences with (x,y, extra features) and a deterministic target.
Target combines smoothed temporal median + spatial transform to make learning feasible.
"""
    def __init__(self, length: int, T: int, H: int, W: int, C_in: int, C_out: int):
        self.n = length; self.T = T; self.H = H; self.W = W; self.C_in = C_in; self.C_out = C_out
        # Precompute static coordinate grids
        y, x = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W), indexing="ij")
        self.xy = torch.stack([x, y], dim=0)  # (2,H,W)
        # Fixed projection weights shared by all samples, so X -> Y is a single learnable mapping
        g = torch.Generator().manual_seed(0)
        self.Wp = torch.randn(C_out, C_in, 1, 1, generator=g) * 0.3
def __len__(self) -> int:
return self.n
    def __getitem__(self, idx: int):
        g = torch.Generator().manual_seed(idx)  # deterministic per-sample without touching the global RNG
        # base per-frame signal
        frames = []
        for t in range(self.T):
            phase = 2 * math.pi * (t / self.T)
            noise = 0.05 * torch.randn(self.C_in - 2, self.H, self.W, generator=g)
extra = torch.stack([
torch.sin(phase + self.xy[0]*math.pi),
torch.cos(phase + self.xy[1]*math.pi)
], dim=0)
# Build channels: xy + (rest)
rest = torch.cat([extra, noise], dim=0)
pad = max(0, (self.C_in - 2) - rest.shape[0])
if pad > 0:
rest = torch.cat([rest, torch.zeros(pad, self.H, self.W)], dim=0)
frame = torch.cat([self.xy, rest[:self.C_in-2]], dim=0) # (C_in,H,W)
frames.append(frame)
X = torch.stack(frames, dim=1) # (C_in,T,H,W)
        # Target: temporal median, spatially smoothed, projected to C_out with the fixed weights
        med = X.median(dim=1).values       # (C_in,H,W)
        filt = F.avg_pool2d(med, 3, 1, 1)  # 3x3 box blur, stride 1, padding 1
        Y = F.conv2d(filt.unsqueeze(0), self.Wp).squeeze(0)  # (C_out,H,W)
return X.float(), Y.float()
# ---- Model ------------------------------------------------------------------
class ConvLSTMCell(nn.Module):
"""Minimal ConvLSTM cell. Keeps parameters small for a simple baseline."""
def __init__(self, in_ch: int, hid_ch: int, k: int = 3):
super().__init__()
p = k // 2
self.conv = nn.Conv2d(in_ch + hid_ch, 4 * hid_ch, k, padding=p)
def forward(self, x, state):
h, c = state
y = self.conv(torch.cat([x, h], dim=1))
i, f, g, o = torch.chunk(y, 4, dim=1)
i = torch.sigmoid(i); f = torch.sigmoid(f); o = torch.sigmoid(o); g = torch.tanh(g)
c = f * c + i * g
h = o * torch.tanh(c)
return h, c
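# Gate equations implemented by the cell above (standard ConvLSTM; every product of weights
# and inputs is a 2D convolution over the concatenated [x, h]):
#   i, f, o = sigmoid(conv([x, h]));  g = tanh(conv([x, h]))
#   c' = f * c + i * g;               h' = o * tanh(c')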
class ConvLSTM(nn.Module):
"""Single-layer ConvLSTM unrolled over time dimension."""
def __init__(self, in_ch: int, hid_ch: int):
super().__init__()
self.cell = ConvLSTMCell(in_ch, hid_ch)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, C, T, H, W)
B, C, T, H, W = x.shape
h = x.new_zeros(B, self.cell.conv.out_channels // 4, H, W)
c = x.new_zeros(B, self.cell.conv.out_channels // 4, H, W)
for t in range(T):
h, c = self.cell(x[:, :, t], (h, c))
return h # (B, hidden, H, W)
class Encoder(nn.Module):
def __init__(self, c_in: int, c_mid: int):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(c_in, c_mid, 3, padding=1),
nn.GELU(),
nn.Conv2d(c_mid, c_mid, 3, padding=1),
nn.GELU(),
)
def forward(self, x):
return self.net(x)
class Decoder(nn.Module):
def __init__(self, c_in: int, c_out: int):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(c_in, c_in, 3, padding=1),
nn.GELU(),
nn.Conv2d(c_in, c_out, 1),
)
def forward(self, x):
return self.net(x)
class Video2ImageNet(nn.Module):
"""
Minimal encoder(2D) -> ConvLSTM (time) -> decoder(2D).
"""
def __init__(self, c_in: int, c_out: int, hidden: int = 64, enc: int = 64):
super().__init__()
self.encoder = Encoder(c_in, enc)
self.agg = ConvLSTM(enc, hidden)
self.decoder = Decoder(hidden, c_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (B,C,T,H,W)
B, C, T, H, W = x.shape
x = x.permute(0, 2, 1, 3, 4) # (B,T,C,H,W)
x = x.reshape(B*T, C, H, W) # merge time for 2D convs
x = self.encoder(x) # (B*T, enc, H, W)
x = x.view(B, T, -1, H, W).permute(0, 2, 1, 3, 4) # (B, enc, T, H, W)
h = self.agg(x) # (B, hidden, H, W)
y = self.decoder(h) # (B, C_out, H, W)
return y
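# Shape trace (illustrative, using the CLI defaults below: c_in=18, c_out=3, T=8, 64x64, batch B):
#   input (B, 18, 8, 64, 64) -> Encoder per frame -> (B, 64, 8, 64, 64)
#   -> ConvLSTM over T -> (B, 64, 64, 64) -> Decoder -> (B, 3, 64, 64)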
# ---- Training ---------------------------------------------------------------
@dataclass
class TrainConfig:
data_dir: Optional[Path]
synth_len: int
time: int
height: int
width: int
c_in: int
c_out: int
batch: int
epochs: int
lr: float
weight_decay: float
loss: str
amp: bool
clip: float
seed: int
save_dir: Path
def make_dataloaders(cfg: TrainConfig) -> Tuple[DataLoader, DataLoader]:
if cfg.data_dir:
ds = SequenceToImageNPZ(cfg.data_dir)
        # sanity-check the time dimension from the first sample (pass --time <= 0 to skip)
        x0, _ = ds[0]
        T = x0.shape[1]
        assert cfg.time <= 0 or T == cfg.time, f"Data T={T} differs from --time={cfg.time}"
else:
ds = SyntheticSequenceToImage(
length=cfg.synth_len, T=cfg.time, H=cfg.height, W=cfg.width, C_in=cfg.c_in, C_out=cfg.c_out
)
n_val = max(1, int(0.1 * len(ds)))
n_train = len(ds) - n_val
tr, va = random_split(ds, [n_train, n_val], generator=torch.Generator().manual_seed(cfg.seed))
train_loader = DataLoader(tr, batch_size=cfg.batch, shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(va, batch_size=cfg.batch, shuffle=False, num_workers=2, pin_memory=True)
return train_loader, val_loader
def get_loss(name: str):
name = name.lower()
if name == "l1":
return nn.L1Loss()
if name in ("mse", "l2"):
return nn.MSELoss()
if name in ("huber", "smoothl1"):
return nn.SmoothL1Loss(beta=0.01)
raise ValueError(f"Unknown loss {name}")
def train(cfg: TrainConfig) -> None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed_everything(cfg.seed)
train_loader, val_loader = make_dataloaders(cfg)
# Infer channels from data if not synthetic
if cfg.data_dir:
x0, y0 = next(iter(train_loader))
        c_in = x0.shape[1]   # batched input is (B, C_in, T, H, W)
        c_out = y0.shape[1]  # batched target is (B, C_out, H, W)
else:
c_in, c_out = cfg.c_in, cfg.c_out
model = Video2ImageNet(c_in=c_in, c_out=c_out, hidden=64, enc=64).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)
    use_amp = cfg.amp and device.type == "cuda"  # AMP only applies on CUDA here
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
best_val = float("inf")
cfg.save_dir.mkdir(parents=True, exist_ok=True)
criterion = get_loss(cfg.loss)
for epoch in range(1, cfg.epochs + 1):
model.train()
total = 0.0
for X, Y in train_loader:
X = X.to(device) # (B,C,T,H,W)
Y = Y.to(device) # (B,C,H,W)
opt.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
pred = model(X)
loss = criterion(pred, Y)
scaler.scale(loss).backward()
if cfg.clip > 0:
scaler.unscale_(opt)
nn.utils.clip_grad_norm_(model.parameters(), cfg.clip) # keeps training stable
scaler.step(opt); scaler.update()
total += loss.item() * X.size(0)
sched.step()
train_loss = total / len(train_loader.dataset)
# Validation
model.eval()
vloss = 0.0; vpsnr = 0.0
with torch.no_grad():
for X, Y in val_loader:
X = X.to(device); Y = Y.to(device)
pred = model(X)
vloss += criterion(pred, Y).item() * X.size(0)
vpsnr += psnr(pred.clamp(0,1), Y.clamp(0,1)) * X.size(0)
vloss /= len(val_loader.dataset)
vpsnr /= len(val_loader.dataset)
print(f"epoch {epoch:03d} | train {train_loss:.4f} | val {vloss:.4f} | psnr {vpsnr:.2f}")
if vloss < best_val:
best_val = vloss
ckpt = cfg.save_dir / "best.pt"
torch.save({"model": model.state_dict(), "cfg": cfg.__dict__}, ckpt)
print(f"saved {ckpt} (val {vloss:.4f})")
# Export ONNX for downstream use
try:
model.eval()
X, Y = next(iter(val_loader))
dummy = X[:1].to(device)
onnx_path = cfg.save_dir / "video2image.onnx"
torch.onnx.export(
model, dummy, onnx_path.as_posix(),
input_names=["video"], output_names=["image"],
opset_version=17, dynamic_axes={"video": {0: "B", 2: "T"}, "image": {0: "B"}}
)
print(f"exported ONNX to {onnx_path}")
except Exception as e:
print(f"ONNX export skipped: {e}")
# ---- CLI --------------------------------------------------------------------
def parse_args() -> TrainConfig:
p = argparse.ArgumentParser(description="ConvLSTM baseline: sequence-of-feature-frames -> single image")
p.add_argument("--data", type=str, default="", help="dir of .npz files with X:(T,H,W,C_in), Y:(H,W,C_out)")
p.add_argument("--synth_len", type=int, default=1024, help="synthetic dataset size if no --data")
p.add_argument("--time", type=int, default=8, help="frames per sample (T)")
p.add_argument("--height", type=int, default=64)
p.add_argument("--width", type=int, default=64)
p.add_argument("--c_in", type=int, default=18, help="input channels (e.g., x,y + 16 features)")
p.add_argument("--c_out", type=int, default=3, help="output channels")
p.add_argument("--batch", type=int, default=8)
p.add_argument("--epochs", type=int, default=20)
p.add_argument("--lr", type=float, default=3e-4)
p.add_argument("--weight_decay", type=float, default=0.01)
p.add_argument("--loss", type=str, default="huber", choices=["l1","mse","huber"])
p.add_argument("--no_amp", action="store_true", help="disable mixed precision")
p.add_argument("--clip", type=float, default=1.0)
p.add_argument("--seed", type=int, default=42)
p.add_argument("--save_dir", type=str, default="checkpoints")
args = p.parse_args()
return TrainConfig(
data_dir=Path(args.data) if args.data else None,
synth_len=args.synth_len, time=args.time, height=args.height, width=args.width,
c_in=args.c_in, c_out=args.c_out,
batch=args.batch, epochs=args.epochs, lr=args.lr, weight_decay=args.weight_decay,
loss=args.loss, amp=not args.no_amp, clip=args.clip, seed=args.seed,
save_dir=Path(args.save_dir),
)
if __name__ == "__main__":
cfg = parse_args()
train(cfg)
Run on real data:

    python src/video2image_main.py --data /path/to/npzs --time T --c_in N --c_out M
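Without --data, a quick synthetic smoke test uses the flags defined above (values here are only an example):

    python src/video2image_main.py --synth_len 256 --epochs 3 --batch 4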