Sample Python code for a generative (almost) video neural network

I am looking for simple Python code to train a generative video neural network. I plan to feed in a sequence of images and would like to get one image back. Each point in the image will have x and y coordinates and 15-20 other features associated with it (this is slightly different from a traditional video image, where each point has 3 colors associated with it).

I have been searching the web for examples, but so far I have only found fairly complicated ones.
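
For concreteness, I imagine one training example looking roughly like this (shapes only; the values and channel counts are placeholders):

import numpy as np

T, H, W = 8, 64, 64        # frames in the sequence, image height/width
C_in, C_out = 18, 3        # x, y plus ~16 extra features in; a few channels out

X = np.zeros((T, H, W, C_in), dtype=np.float32)   # input sequence of feature "frames"
Y = np.zeros((H, W, C_out), dtype=np.float32)     # the single image I want back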


Some partially reusable code seems to be available online…?


Plan

  • Parse CLI args (paths, shapes, channels, lr, epochs, etc.).

  • Dataset:

    • If --data exists: load .npz files: X: (T,H,W,C_in), Y: (H,W,C_out) (a data-preparation sketch follows this plan).

    • Else: synthesize toy data with known mapping (for sanity).

    • Return tensors as (C_in,T,H,W) and (C_out,H,W).

  • Model:

    • ConvEncoder: downsample + encode each frame.

    • ConvLSTM: aggregate features across time.

    • ConvDecoder: upsample to output C_out channels.

  • Training:

    • MSE/L1/Huber selectable; AdamW; cosine schedule; AMP.

    • Save best checkpoint, export ONNX.

  • Eval/infer:

    • Quick PSNR/MAE; visualize grid (optional save).

  • Edge cases:

    • Variable sequence lengths (a padding collate_fn sketch follows the main script); mixed precision off switch; gradient clipping.
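
To prepare data in the format the loader expects, a minimal, hedged sketch for writing compatible .npz files (illustrative only: the file names and random placeholder arrays are assumptions, and it presumes your per-point features are already gridded into H x W arrays):

# file: src/make_npz_example.py  (illustrative helper, not part of the main script)
import numpy as np
from pathlib import Path

def save_sample(out_path: Path, X: np.ndarray, Y: np.ndarray) -> None:
    # X: (T, H, W, C_in) per-frame features; Y: (H, W, C_out) target image/features
    assert X.ndim == 4 and Y.ndim == 3, "unexpected shapes"
    np.savez_compressed(out_path, X=X.astype(np.float32), Y=Y.astype(np.float32))

if __name__ == "__main__":
    out_dir = Path("data/train")
    out_dir.mkdir(parents=True, exist_ok=True)
    T, H, W, C_in, C_out = 8, 64, 64, 18, 3
    for i in range(4):  # replace the random arrays with your real samples
        X = np.random.rand(T, H, W, C_in)
        Y = np.random.rand(H, W, C_out)
        save_sample(out_dir / f"sample_{i:05d}.npz", X, Y)

The main script:
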
# file: src/video2image_main.py
import argparse
import math
import os
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

# ---- Utils ------------------------------------------------------------------

def seed_everything(seed: int) -> None:
    # Best-effort reproducibility; cuDNN determinism may still vary by op and hardware
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True; torch.backends.cudnn.benchmark = False

def psnr(x: torch.Tensor, y: torch.Tensor, data_range: float = 1.0) -> float:
    mse = F.mse_loss(x, y).item()
    return 99.0 if mse == 0 else 20.0 * math.log10(data_range) - 10.0 * math.log10(mse)

# ---- Data -------------------------------------------------------------------

class SequenceToImageNPZ(Dataset):
    """
    Expects a directory of .npz files with keys:
      - 'X': float32 array (T, H, W, C_in)  # per-frame features
      - 'Y': float32 array (H, W, C_out)    # target image/features
    Converts to tensors shaped (C_in, T, H, W) and (C_out, H, W).
    """
    def __init__(self, root: Path):
        self.files = sorted([Path(root) / f for f in os.listdir(root) if f.endswith(".npz")])
        if not self.files:
            raise FileNotFoundError(f"No .npz files found in {root}")

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, idx: int):
        d = np.load(self.files[idx])
        X = d["X"].astype(np.float32)  # (T,H,W,C)
        Y = d["Y"].astype(np.float32)  # (H,W,C_out)
        X = np.moveaxis(X, -1, 0)      # (T,H,W,C) -> (C_in,T,H,W)
        Y = np.moveaxis(Y, -1, 0)      # (C_out,H,W)
        return torch.from_numpy(X), torch.from_numpy(Y)

class SyntheticSequenceToImage(Dataset):
    """
    Generates sequences with (x,y, extra features) and a deterministic target.
    Target combines smoothed temporal median + spatial transform to make learning feasible.
    """
    def __init__(self, length: int, T: int, H: int, W: int, C_in: int, C_out: int):
        self.n = length; self.T=T; self.H=H; self.W=W; self.C_in=C_in; self.C_out=C_out
        # Precompute static coordinate grids
        y, x = torch.meshgrid(torch.linspace(-1,1,H), torch.linspace(-1,1,W), indexing="ij")
        self.xy = torch.stack([x, y], dim=0)  # (2,H,W)
        # Fixed projection weights, shared across all samples, so the X -> Y mapping is consistent and learnable
        g = torch.Generator().manual_seed(0)
        self.Wp = torch.randn(self.C_out, self.C_in, 1, 1, generator=g) * 0.3

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int):
        torch.manual_seed(idx)  # deterministic per-sample
        # base per-frame signal
        frames = []
        for t in range(self.T):
            phase = 2 * math.pi * (t / self.T)
            noise = 0.05 * torch.randn(self.C_in-2, self.H, self.W)
            extra = torch.stack([
                torch.sin(phase + self.xy[0]*math.pi),
                torch.cos(phase + self.xy[1]*math.pi)
            ], dim=0)
            # Build channels: xy + (rest)
            rest = torch.cat([extra, noise], dim=0)
            pad = max(0, (self.C_in - 2) - rest.shape[0])
            if pad > 0:
                rest = torch.cat([rest, torch.zeros(pad, self.H, self.W)], dim=0)
            frame = torch.cat([self.xy, rest[:self.C_in-2]], dim=0)  # (C_in,H,W)
            frames.append(frame)
        X = torch.stack(frames, dim=1)  # (C_in,T,H,W)

        # Target mixes temporal median + spatial warp
        med = X.median(dim=1).values  # (C_in,H,W)
        filt = F.avg_pool2d(med, 3, 1, 1)
        # Project to C_out with the fixed weights created in __init__ to ensure signal exists
        Y = F.conv2d(filt.unsqueeze(0), self.Wp).squeeze(0)  # (C_out,H,W)
        return X.float(), Y.float()

# ---- Model ------------------------------------------------------------------

class ConvLSTMCell(nn.Module):
    """Minimal ConvLSTM cell. Keeps parameters small for a simple baseline."""
    def __init__(self, in_ch: int, hid_ch: int, k: int = 3):
        super().__init__()
        p = k // 2
        self.conv = nn.Conv2d(in_ch + hid_ch, 4 * hid_ch, k, padding=p)

    def forward(self, x, state):
        h, c = state
        y = self.conv(torch.cat([x, h], dim=1))
        i, f, g, o = torch.chunk(y, 4, dim=1)
        i = torch.sigmoid(i); f = torch.sigmoid(f); o = torch.sigmoid(o); g = torch.tanh(g)
        c = f * c + i * g
        h = o * torch.tanh(c)
        return h, c

class ConvLSTM(nn.Module):
    """Single-layer ConvLSTM unrolled over time dimension."""
    def __init__(self, in_ch: int, hid_ch: int):
        super().__init__()
        self.cell = ConvLSTMCell(in_ch, hid_ch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, T, H, W)
        B, C, T, H, W = x.shape
        h = x.new_zeros(B, self.cell.conv.out_channels // 4, H, W)
        c = x.new_zeros(B, self.cell.conv.out_channels // 4, H, W)
        for t in range(T):
            h, c = self.cell(x[:, :, t], (h, c))
        return h  # (B, hidden, H, W)

class Encoder(nn.Module):
    def __init__(self, c_in: int, c_mid: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(c_in, c_mid, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(c_mid, c_mid, 3, padding=1),
            nn.GELU(),
        )

    def forward(self, x):
        return self.net(x)

class Decoder(nn.Module):
    def __init__(self, c_in: int, c_out: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(c_in, c_in, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(c_in, c_out, 1),
        )

    def forward(self, x):
        return self.net(x)

class Video2ImageNet(nn.Module):
    """
    Minimal encoder(2D) -> ConvLSTM (time) -> decoder(2D).
    """
    def __init__(self, c_in: int, c_out: int, hidden: int = 64, enc: int = 64):
        super().__init__()
        self.encoder = Encoder(c_in, enc)
        self.agg = ConvLSTM(enc, hidden)
        self.decoder = Decoder(hidden, c_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B,C,T,H,W)
        B, C, T, H, W = x.shape
        x = x.permute(0, 2, 1, 3, 4)              # (B,T,C,H,W)
        x = x.reshape(B*T, C, H, W)               # merge time for 2D convs
        x = self.encoder(x)                       # (B*T, enc, H, W)
        x = x.view(B, T, -1, H, W).permute(0, 2, 1, 3, 4)  # (B, enc, T, H, W)
        h = self.agg(x)                           # (B, hidden, H, W)
        y = self.decoder(h)                       # (B, C_out, H, W)
        return y

# ---- Training ---------------------------------------------------------------

@dataclass
class TrainConfig:
    data_dir: Optional[Path]
    synth_len: int
    time: int
    height: int
    width: int
    c_in: int
    c_out: int
    batch: int
    epochs: int
    lr: float
    weight_decay: float
    loss: str
    amp: bool
    clip: float
    seed: int
    save_dir: Path

def make_dataloaders(cfg: TrainConfig) -> Tuple[DataLoader, DataLoader]:
    if cfg.data_dir:
        ds = SequenceToImageNPZ(cfg.data_dir)
        # infer shapes from first sample
        x0, y0 = ds[0]
        T = x0.shape[1]
        assert T == cfg.time or cfg.time <= 0, f"Data T={T} differs from --time={cfg.time}"
    else:
        ds = SyntheticSequenceToImage(
            length=cfg.synth_len, T=cfg.time, H=cfg.height, W=cfg.width, C_in=cfg.c_in, C_out=cfg.c_out
        )
    n_val = max(1, int(0.1 * len(ds)))
    n_train = len(ds) - n_val
    tr, va = random_split(ds, [n_train, n_val], generator=torch.Generator().manual_seed(cfg.seed))
    train_loader = DataLoader(tr, batch_size=cfg.batch, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(va, batch_size=cfg.batch, shuffle=False, num_workers=2, pin_memory=True)
    return train_loader, val_loader

def get_loss(name: str):
    name = name.lower()
    if name == "l1":
        return nn.L1Loss()
    if name in ("mse", "l2"):
        return nn.MSELoss()
    if name in ("huber", "smoothl1"):
        return nn.SmoothL1Loss(beta=0.01)
    raise ValueError(f"Unknown loss {name}")

def train(cfg: TrainConfig) -> None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seed_everything(cfg.seed)
    train_loader, val_loader = make_dataloaders(cfg)

    # Infer channels from data if not synthetic
    if cfg.data_dir:
        x0, y0 = next(iter(train_loader))
        c_in = x0.shape[1]    # batched input: (B, C_in, T, H, W)
        c_out = y0.shape[1]   # batched target: (B, C_out, H, W)
    else:
        c_in, c_out = cfg.c_in, cfg.c_out

    model = Video2ImageNet(c_in=c_in, c_out=c_out, hidden=64, enc=64).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)
    scaler = torch.cuda.amp.GradScaler(enabled=cfg.amp)

    best_val = float("inf")
    cfg.save_dir.mkdir(parents=True, exist_ok=True)

    criterion = get_loss(cfg.loss)

    for epoch in range(1, cfg.epochs + 1):
        model.train()
        total = 0.0
        for X, Y in train_loader:
            X = X.to(device)  # (B,C,T,H,W)
            Y = Y.to(device)  # (B,C,H,W)
            opt.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=cfg.amp):
                pred = model(X)
                loss = criterion(pred, Y)
            scaler.scale(loss).backward()
            if cfg.clip > 0:
                scaler.unscale_(opt)
                nn.utils.clip_grad_norm_(model.parameters(), cfg.clip)  # keeps training stable
            scaler.step(opt); scaler.update()
            total += loss.item() * X.size(0)
        sched.step()
        train_loss = total / len(train_loader.dataset)

        # Validation
        model.eval()
        vloss = 0.0; vpsnr = 0.0
        with torch.no_grad():
            for X, Y in val_loader:
                X = X.to(device); Y = Y.to(device)
                pred = model(X)
                vloss += criterion(pred, Y).item() * X.size(0)
                vpsnr += psnr(pred.clamp(0,1), Y.clamp(0,1)) * X.size(0)
        vloss /= len(val_loader.dataset)
        vpsnr /= len(val_loader.dataset)

        print(f"epoch {epoch:03d} | train {train_loss:.4f} | val {vloss:.4f} | psnr {vpsnr:.2f}")

        if vloss < best_val:
            best_val = vloss
            ckpt = cfg.save_dir / "best.pt"
            torch.save({"model": model.state_dict(), "cfg": cfg.__dict__}, ckpt)
            print(f"saved {ckpt} (val {vloss:.4f})")

    # Export ONNX for downstream use
    try:
        model.eval()
        X, Y = next(iter(val_loader))
        dummy = X[:1].to(device)
        onnx_path = cfg.save_dir / "video2image.onnx"
        torch.onnx.export(
            model, dummy, onnx_path.as_posix(),
            input_names=["video"], output_names=["image"],
            # Note: the ConvLSTM time loop is unrolled at export/trace time, so "T" is
            # effectively fixed to the dummy input's length despite the dynamic axis below.
            opset_version=17, dynamic_axes={"video": {0: "B", 2: "T"}, "image": {0: "B"}}
        )
        print(f"exported ONNX to {onnx_path}")
    except Exception as e:
        print(f"ONNX export skipped: {e}")

# ---- CLI --------------------------------------------------------------------

def parse_args() -> TrainConfig:
    p = argparse.ArgumentParser(description="ConvLSTM baseline: sequence-of-feature-frames -> single image")
    p.add_argument("--data", type=str, default="", help="dir of .npz files with X:(T,H,W,C_in), Y:(H,W,C_out)")
    p.add_argument("--synth_len", type=int, default=1024, help="synthetic dataset size if no --data")
    p.add_argument("--time", type=int, default=8, help="frames per sample (T)")
    p.add_argument("--height", type=int, default=64)
    p.add_argument("--width", type=int, default=64)
    p.add_argument("--c_in", type=int, default=18, help="input channels (e.g., x,y + 16 features)")
    p.add_argument("--c_out", type=int, default=3, help="output channels")
    p.add_argument("--batch", type=int, default=8)
    p.add_argument("--epochs", type=int, default=20)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--weight_decay", type=float, default=0.01)
    p.add_argument("--loss", type=str, default="huber", choices=["l1","mse","huber"])
    p.add_argument("--no_amp", action="store_true", help="disable mixed precision")
    p.add_argument("--clip", type=float, default=1.0)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--save_dir", type=str, default="checkpoints")
    args = p.parse_args()

    return TrainConfig(
        data_dir=Path(args.data) if args.data else None,
        synth_len=args.synth_len, time=args.time, height=args.height, width=args.width,
        c_in=args.c_in, c_out=args.c_out,
        batch=args.batch, epochs=args.epochs, lr=args.lr, weight_decay=args.weight_decay,
        loss=args.loss, amp=not args.no_amp, clip=args.clip, seed=args.seed,
        save_dir=Path(args.save_dir),
    )

if __name__ == "__main__":
    cfg = parse_args()
    train(cfg)
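
The DataLoader above assumes every sample has the same number of frames. For the variable-sequence-length edge case noted in the plan, a hedged sketch of a collate_fn that zero-pads along the time axis (an assumption on my part: padded frames are simply processed as zeros by the ConvLSTM, which is acceptable for a baseline):

# Illustrative: pass via DataLoader(..., collate_fn=pad_time_collate)
import torch

def pad_time_collate(batch):
    xs, ys = zip(*batch)                        # xs[i]: (C_in, T_i, H, W), ys[i]: (C_out, H, W)
    t_max = max(x.shape[1] for x in xs)
    padded = []
    for x in xs:
        pad_t = t_max - x.shape[1]
        if pad_t > 0:
            zeros = x.new_zeros(x.shape[0], pad_t, x.shape[2], x.shape[3])
            x = torch.cat([x, zeros], dim=1)    # zero-pad trailing frames
        padded.append(x)
    return torch.stack(padded, dim=0), torch.stack(ys, dim=0)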

python src/video2image_main.py --data /path/to/npzs --time T --c_in N --c_out M
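
After training, the checkpoint can be reloaded for inference. A minimal sketch (it assumes the checkpoint layout saved by the script above and infers channel counts from a data sample rather than trusting CLI defaults):

# file: src/infer_example.py  (illustrative)
from pathlib import Path

import torch

from video2image_main import SequenceToImageNPZ, Video2ImageNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ds = SequenceToImageNPZ(Path("/path/to/npzs"))
X, Y = ds[0]                                    # (C_in, T, H, W), (C_out, H, W)

model = Video2ImageNet(c_in=X.shape[0], c_out=Y.shape[0], hidden=64, enc=64).to(device)
# weights_only=False because the checkpoint stores a config dict alongside the weights;
# drop the kwarg on older PyTorch versions that do not accept it.
ckpt = torch.load("checkpoints/best.pt", map_location=device, weights_only=False)
model.load_state_dict(ckpt["model"])
model.eval()

with torch.no_grad():
    pred = model(X.unsqueeze(0).to(device))     # (1, C_out, H, W)
print(pred.shape)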

Response generated by TD Ai
