0 loss on eval: HF Trainer integration

My model trains fine in plain PyTorch with my own training loop, but I'm having trouble with the HF Trainer. I'm trying to integrate the model and keep getting a weird 0 loss on eval. This is still preliminary, so if anyone has experience with this specific kind of integration and sees anything that pops out at them, please let me know.
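
For context, this is roughly how I'm handing the model to the Trainer. The dataset, collator, and output directory names below are placeholders for my own setup, so treat it as a sketch of the wiring rather than my exact script:

```python
# Sketch of the Trainer wiring; train_dataset, eval_dataset and data_collator are
# placeholders for my own objects (built from a WhisperProcessor-preprocessed dataset).
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-custom",      # placeholder
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=1e-4,
    max_steps=1000,
    evaluation_strategy="steps",
    eval_steps=100,
    logging_steps=25,
    remove_unused_columns=False,        # my forward() only accepts input_features/labels
    label_names=["labels"],             # make sure the Trainer treats this column as labels
)

trainer = Seq2SeqTrainer(
    model=model,                        # the custom Whisper defined below
    args=training_args,
    train_dataset=train_dataset,        # placeholder
    eval_dataset=eval_dataset,          # placeholder
    data_collator=data_collator,        # placeholder padding collator
)
trainer.train()
```

Passing `remove_unused_columns=False` and an explicit `label_names` is me guessing at what a non-standard `forward()` signature needs; if either of those is wrong or unnecessary, that could well be related to the eval behaviour.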

###################################

```python

import base64, gzip, logging, math, os, re, time, warnings
import numpy as np, torch, torch.nn as nn, torch.nn.functional as F
import torch.optim as optim, torch.utils.checkpoint as checkpoint
import torchaudio, torchaudio.transforms as at
import librosa, soundfile as sf, evaluate, spacy, ginza, MeCab, deepl, neologdn, transformers
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from torch import Tensor
from whisper.decoding import decode as decode_function
from whisper.decoding import detect_language as detect_language_function
from whisper.transcribe import transcribe as transcribe_function
from datasets import Dataset, load_dataset, DatasetDict, Audio, concatenate_datasets, load_from_disk, IterableDataset, interleave_datasets
from transformers import (WhisperForConditionalGeneration, WhisperModel, GenerationConfig, WhisperTokenizer,
                          WhisperProcessor, WhisperConfig, WhisperFeatureExtractor, WhisperTokenizerFast,
                          AutoTokenizer, AutoModelForSpeechSeq2Seq, pipeline, TrainingArguments, Trainer,
                          Seq2SeqTrainer, Seq2SeqTrainingArguments, TrainerState, TrainerControl,
                          TrainerCallback, PretrainedConfig, PreTrainedModel)
try:
    from torch.nn.functional import scaled_dot_product_attention

    SDPA_AVAILABLE = True
except (ImportError, RuntimeError, OSError):
    scaled_dot_product_attention = None
    SDPA_AVAILABLE = False

warnings.filterwarnings(action="ignore")
warnings.warn = lambda *args, **kwargs: None

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32  # fp16 on GPU, fp32 on CPU
import tensorboard
@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int

class LayerNorm(nn.Module):
    def __init__(self, num_groups, num_channels, eps=1e-6):
        super().__init__()
        self.num_groups = num_groups
        self.eps = eps
        self.g = nn.Parameter(torch.ones(num_channels))
        self.b = nn.Parameter(torch.zeros(num_channels))
        self.scale = nn.Parameter(torch.ones(num_channels))

    def forward(self, x):
        N, C, *dims = x.shape
        if len(dims) == 2:
            x = x.view(N, self.num_groups, C // self.num_groups, *dims)
            mean = x.mean(dim=(2, 3), keepdim=True)  # Group Normalization
            var = x.var(dim=(2, 3), keepdim=True, unbiased=False)
            x = (x - mean) / torch.sqrt(var + self.eps)
            x = x.view(N, C, *dims)
        else:
            x = x.view(N, self.num_groups, C // self.num_groups, -1)
            mean = x.mean(dim=2, keepdim=True)  # Group Normalization
            var = x.var(dim=2, keepdim=True, unbiased=False)
            x = (x - mean) / torch.sqrt(var + self.eps)
            x = x.view(N, C, *dims)

        x = self.g * x + self.b
        x = x * self.scale  # Layer-Scale Normalization
        return x

class Linear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)  # Custom initialization
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x: Tensor) -> Tensor:
        # Cast parameters to the input dtype (gradient clipping, if needed, belongs in the training loop)
        weight = self.weight.to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return F.linear(x, weight, bias)

class Conv1d(nn.Conv1d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def rotate_queries_or_keys(self, x):
        # x shape: [batch_size, seq_len, n_head, head_dim]
        sinusoid_inp = torch.einsum('i , j -> i j', torch.arange(x.shape[1], device=x.device), self.inv_freq)
        sin = sinusoid_inp.sin()[None, :, None, :]  # Shape: [1, seq_len, 1, head_dim / 2]
        cos = sinusoid_inp.cos()[None, :, None, :]  # Shape: [1, seq_len, 1, head_dim / 2]
        x1, x2 = x[..., ::2], x[..., 1::2]
        x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
        return x

class SinusoidalFeatures:
    def __init__(self, n_ctx, n_state):
        self.n_ctx = n_ctx
        self.n_state = n_state
        self.features = self.sinusoidal_features(n_ctx, n_state)

    @staticmethod
    def sinusoidal_features(n_ctx, n_state):
        position = torch.arange(0, n_ctx, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, n_state, 2).float() * -(math.log(10000.0) / n_state))
        features = torch.zeros(n_ctx, n_state)
        features[:, 0::2] = torch.sin(position * div_term)
        features[:, 1::2] = torch.cos(position * div_term)
        return features

    def __call__(self):
        return self.features

class LearnedSinusoidalEmbeddings(nn.Module):
    def __init__(self, n_ctx, n_state, gradient_checkpointing=False):
        super().__init__()
        self.n_ctx = n_ctx
        self.n_state = n_state
        self.gradient_checkpointing = gradient_checkpointing

        # Initialize from fixed sinusoidal features, then learn them as parameters
        sinusoidal_embeddings = SinusoidalFeatures(n_ctx, n_state)()
        self.positional_embeddings = nn.Parameter(sinusoidal_embeddings)

    def forward(self, positions):
        if self.gradient_checkpointing:
            position_embeddings = checkpoint.checkpoint(lambda x: self.positional_embeddings[x], positions)
        else:
            position_embeddings = self.positional_embeddings[positions]

        position_embeddings = F.normalize(position_embeddings, p=2, dim=-1)  # Normalize positional embeddings
        return position_embeddings

class HybridAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int, dropout_rate=0.1):
        super().__init__()
        self.local_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)   # Local attention
        self.global_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)  # Global attention
        self.ln_local = LayerNorm(num_groups=1, num_channels=n_state)
        self.ln_global = LayerNorm(num_groups=1, num_channels=n_state)
        self.dropout = nn.Dropout(dropout_rate)

        self.window_size = 5  # Local attention window size, adjust as needed

    def forward(self, x: Tensor):
        x_local = self.ln_local(x)
        x_global = self.ln_global(x)

        x_local = x_local.permute(1, 0, 2)   # Transpose for nn.MultiheadAttention: [seq_len, batch_size, n_state]
        x_global = x_global.permute(1, 0, 2)

        local_out = self.sliding_window_attention(x_local)              # Local attention over a sliding window
        global_out, _ = self.global_attn(x_global, x_global, x_global)  # Global attention

        combined_out = local_out + global_out         # [seq_len, batch_size, n_state]
        combined_out = combined_out.permute(1, 0, 2)  # Transpose back: [batch_size, seq_len, n_state]

        return self.dropout(combined_out)

    def sliding_window_attention(self, x):
        seq_len, batch_size, n_state = x.size()  # x: [seq_len, batch_size, n_state]
        window_size = self.window_size

        output = torch.zeros_like(x)
        for i in range(seq_len):
            start = max(0, i - window_size)
            end = min(seq_len, i + window_size + 1)
            query = x[i:i + 1, :, :]    # [1, batch_size, n_state]
            key = x[start:end, :, :]    # [window, batch_size, n_state]
            value = x[start:end, :, :]  # [window, batch_size, n_state]

            attn_output, _ = self.local_attn(query, key, value)
            output[i:i + 1, :, :] = attn_output

        return output

class MultiHeadAttention(nn.Module):
    use_sdpa = True

    def __init__(self, n_state: int, n_head: int, dropout_rate=0.1, gradient_checkpointing=False):
        super().__init__()
        self.n_head = n_head
        self.n_state = n_state
        self.head_dim = n_state // n_head

        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

        self.rotary_emb = RotaryEmbedding(dim=self.head_dim)  # Rotary embedding
        self.temperature = nn.Parameter(torch.ones(1) * (self.head_dim ** -0.5))  # Temperature scaling
        self.dropout = nn.Dropout(dropout_rate)
        self.attn_ln = LayerNorm(num_groups=1, num_channels=n_state)
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, x: Tensor, xa: Optional[Tensor] = None,
                mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
        x_norm = self.attn_ln(x)  # Apply layer norm first

        q = self.query(x_norm)
        if kv_cache is None or xa is None or self.key not in kv_cache:
            k = self.key(x_norm if xa is None else xa)
            v = self.value(x_norm if xa is None else xa)
        else:
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        # Apply rotary embeddings to queries and keys
        q = q.view(q.shape[0], q.shape[1], self.n_head, -1)
        k = k.view(k.shape[0], k.shape[1], self.n_head, -1)

        q = self.rotary_emb.rotate_queries_or_keys(q)
        k = self.rotary_emb.rotate_queries_or_keys(k)

        q = q.view(q.shape[0], q.shape[1], -1)
        k = k.view(k.shape[0], k.shape[1], -1)

        wv, qk = self.qkv_attention(q, k, v, mask)

        return self.out(wv) + x, qk  # Residual connection

    def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor,
                      mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        n_batch, n_ctx, n_state = q.shape
        scale = self.temperature

        q = q.view(n_batch, n_ctx, self.n_head, self.head_dim).permute(0, 2, 1, 3)
        k = k.view(n_batch, k.shape[1], self.n_head, self.head_dim).permute(0, 2, 1, 3)
        v = v.view(n_batch, v.shape[1], self.n_head, self.head_dim).permute(0, 2, 1, 3)

        if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
            a = scaled_dot_product_attention(q, k, v,
                                             is_causal=mask is not None and n_ctx > 1)
            out = a.permute(0, 2, 1, 3).reshape(n_batch, n_ctx, n_state)
            qk = None
        else:
            qk = (q * scale) @ (k.transpose(-2, -1) * scale)
            if mask is not None:
                qk += mask[:n_ctx, :n_ctx]
            qk = qk.float()

            w = F.softmax(qk, dim=-1).to(q.dtype)
            w = self.dropout(w)
            out = (w @ v).permute(0, 2, 1, 3).reshape(n_batch, n_ctx, n_state)
            qk = qk.detach()
        return out, qk

class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, use_hybrid_attention: bool = False,
                 cross_attention: bool = False, dropout_rate=0.1, gradient_checkpointing=False):
        super().__init__()

        self.use_hybrid_attention = use_hybrid_attention

        if self.use_hybrid_attention:
            self.attn = HybridAttention(n_state, n_head, dropout_rate=dropout_rate)
        else:
            self.attn = MultiHeadAttention(n_state, n_head, dropout_rate=dropout_rate, gradient_checkpointing=gradient_checkpointing)

        self.attn_ln = LayerNorm(num_groups=1, num_channels=n_state)

        self.cross_attention = cross_attention
        if self.cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head, dropout_rate=dropout_rate, gradient_checkpointing=gradient_checkpointing)
            self.cross_attn_ln = LayerNorm(num_groups=1, num_channels=n_state)

        n_mlp = n_state * 4

        self.mlp = nn.Sequential(  # MLP block
            Linear(n_state, n_mlp),
            LayerNorm(num_groups=1, num_channels=n_mlp),
            nn.GELU(),
            nn.Dropout(p=dropout_rate),
            Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(num_groups=1, num_channels=n_state)

        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, x: Tensor, xa: Optional[Tensor] = None,
                mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
        attn_input = self.attn_ln(x)  # Self-attention with residual connection
        if self.gradient_checkpointing:
            if self.use_hybrid_attention:
                attn_out = x + checkpoint.checkpoint(self.attn, attn_input)
            else:
                # Pass arguments positionally as (x, xa, mask, kv_cache) to match MultiHeadAttention.forward
                attn_out = x + checkpoint.checkpoint(self.attn, attn_input, None, mask, kv_cache)[0]
        else:
            if self.use_hybrid_attention:
                attn_out = x + self.attn(attn_input)
            else:
                attn_out = x + self.attn(attn_input, mask=mask, kv_cache=kv_cache)[0]

        if self.cross_attention and xa is not None:  # Cross-attention with residual connection
            cross_attn_input = self.cross_attn_ln(attn_out)
            if self.gradient_checkpointing:
                attn_out = attn_out + checkpoint.checkpoint(self.cross_attn, cross_attn_input, xa, None, kv_cache)[0]
            else:
                attn_out = attn_out + self.cross_attn(cross_attn_input, xa, kv_cache=kv_cache)[0]

        mlp_input = self.mlp_ln(attn_out)  # MLP with residual connection
        if self.gradient_checkpointing:
            mlp_out = attn_out + checkpoint.checkpoint(self.mlp, mlp_input)
        else:
            mlp_out = attn_out + self.mlp(mlp_input)

        return mlp_out

class AudioEncoder(nn.Module):
    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int,
                 dropout_rate=0.1, gradient_checkpointing=False, use_hybrid_attention=False):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.dropout = nn.Dropout(dropout_rate)
        self.gradient_checkpointing = gradient_checkpointing

        self.blocks = nn.ModuleList([
            ResidualAttentionBlock(
                n_state, n_head,
                use_hybrid_attention=use_hybrid_attention,
                dropout_rate=dropout_rate,
                gradient_checkpointing=gradient_checkpointing
            )
            for _ in range(n_layer)
        ])
        self.ln_post = LayerNorm(num_groups=1, num_channels=n_state)

    def forward(self, x: torch.Tensor):
        x = F.gelu(self.conv1(x))
        x = self.dropout(x)
        x = F.gelu(self.conv2(x))
        x = self.dropout(x)
        x = x.permute(0, 2, 1)

        for block in self.blocks:
            if self.gradient_checkpointing:
                x = checkpoint.checkpoint(block, x)
            else:
                x = block(x)

        x = self.ln_post(x)
        return x

class TextDecoder(nn.Module):
    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int,
                 dropout_rate=0.1, gradient_checkpointing=False, use_hybrid_attention=False):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, gradient_checkpointing=gradient_checkpointing)
        self.gradient_checkpointing = gradient_checkpointing

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList([
            ResidualAttentionBlock(
                n_state, n_head,
                use_hybrid_attention=use_hybrid_attention,
                cross_attention=True,
                dropout_rate=dropout_rate,
                gradient_checkpointing=gradient_checkpointing
            )
            for _ in range(n_layer)
        ])
        self.ln_post = LayerNorm(num_groups=1, num_channels=n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer('mask', mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        positions = torch.arange(x.shape[1], device=x.device)
        pos_emb = self.positional_embedding(positions).unsqueeze(0)  # Shape: (1, seq_len, n_state)
        x = self.token_embedding(x) + pos_emb
        x = x.to(xa.dtype)

        for block in self.blocks:
            if self.gradient_checkpointing:
                x = checkpoint.checkpoint(block, x, xa, self.mask, kv_cache)
            else:
                x = block(x, xa, self.mask, kv_cache)

        x = self.ln_post(x)
        logits = (x @ self.token_embedding.weight.to(x.dtype).T).float()
        return logits

class WhisperConfig(PretrainedConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.n_vocab = kwargs.get("n_vocab", 51865)
        self.n_mels = kwargs.get("n_mels", 80)
        self.n_audio_ctx = kwargs.get("n_audio_ctx", 1500)
        self.n_audio_state = kwargs.get("n_audio_state", 1024)
        self.n_audio_head = kwargs.get("n_audio_head", 16)
        self.n_audio_layer = kwargs.get("n_audio_layer", 8)
        self.n_text_ctx = kwargs.get("n_text_ctx", 448)
        self.n_text_state = kwargs.get("n_text_state", 1024)
        self.n_text_head = kwargs.get("n_text_head", 16)
        self.n_text_layer = kwargs.get("n_text_layer", 8)
        self.dropout_rate = kwargs.get("dropout_rate", 0.1)
        self.gradient_checkpointing = kwargs.get("gradient_checkpointing", False)
        self.decoder_start_token_id = kwargs.get("decoder_start_token_id", 50258)
        self.use_hybrid_attention = kwargs.get("use_hybrid_attention", False)

class Whisper(PreTrainedModel):
    config_class = WhisperConfig

    def __init__(self, config, dropout_rate=0.1, gradient_checkpointing=False, use_hybrid_attention=False):
        super().__init__(config)
        self.config = config
        self.encoder = AudioEncoder(
            self.config.n_mels,
            self.config.n_audio_ctx,
            self.config.n_audio_state,
            self.config.n_audio_head,
            self.config.n_audio_layer,
            dropout_rate=dropout_rate,
            gradient_checkpointing=gradient_checkpointing,
            use_hybrid_attention=use_hybrid_attention
        )
        self.decoder = TextDecoder(
            self.config.n_vocab,
            self.config.n_text_ctx,
            self.config.n_text_state,
            self.config.n_text_head,
            self.config.n_text_layer,
            dropout_rate=dropout_rate,
            gradient_checkpointing=gradient_checkpointing,
            use_hybrid_attention=use_hybrid_attention
        )
        self.loss_fn = nn.CrossEntropyLoss()

        all_heads = torch.zeros(
            self.config.n_text_layer, self.config.n_text_head, dtype=torch.bool
        )
        all_heads[self.config.n_text_layer // 2:] = True
        self.register_buffer('alignment_heads', all_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.config.n_text_layer, self.config.n_text_head
        )
        self.register_buffer('alignment_heads', mask.to_sparse(), persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        return self.decoder(tokens, audio_features)

    def forward(self, input_features: torch.Tensor, labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        input_features = input_features.to(self.device)
        labels = labels.to(self.device) if labels is not None else None

        audio_features = self.encoder(input_features)
        logits = self.decoder(labels, audio_features)

        loss = None
        if labels is not None:
            # Flatten the tokens and labels for the loss computation
            logits = logits.view(-1, self.config.n_vocab)
            labels = labels.view(-1)
            loss = self.loss_fn(logits, labels)

        return {
            "loss": loss,
            "labels": logits,
            "input_features": audio_features
        } if loss is not None else {
            "labels": logits,
            "input_features": audio_features
        }

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.config.n_vocab >= 51865

    @property
    def num_languages(self):
        return self.config.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.config.n_text_ctx:
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function

config = WhisperConfig()

model = Whisper(config, use_hybrid_attention=False, gradient_checkpointing=True)


###################################################################################################

# Define dummy inputs
batch_size = 2
n_mels = 80
audio_seq_len = 3000
n_vocab = 51865
text_seq_len = 256

dummy_input_features = torch.randn(batch_size, n_mels, audio_seq_len).cuda()
dummy_labels = torch.randint(0, n_vocab, (batch_size, text_seq_len)).cuda()

dimensions = ModelDimensions(
    n_mels=80,
    n_audio_ctx=audio_seq_len,
    n_audio_state=1024,
    n_audio_head=16,
    n_audio_layer=8,
    n_vocab=51865,
    n_text_ctx=text_seq_len,
    n_text_state=1024,
    n_text_head=16,
    n_text_layer=8
)

# Initialize model with hybrid attention enabled
demo_config = WhisperConfig(**vars(dimensions))  # PreTrainedModel needs a PretrainedConfig, not the raw dataclass
model = Whisper(demo_config, use_hybrid_attention=True, gradient_checkpointing=True).cuda()

# Forward pass
output = model(input_features=dummy_input_features, labels=dummy_labels)

# Print outputs
if output:
    print(f"Logits shape: {output['labels'].shape}")                  # logits are returned under the 'labels' key
    print(f"Audio features shape: {output['input_features'].shape}")  # audio features under 'input_features'
else:
    print("Model output is None. Please check the implementation.")
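
# Quick sanity check on the same dummy batch in eval mode, since the weird 0 loss
# shows up during the Trainer's evaluation. (This is just a debugging probe I added,
# not part of the model code.)
model.eval()
with torch.no_grad():
    eval_output = model(input_features=dummy_input_features, labels=dummy_labels)
print(f"Eval-mode loss on dummy batch: {eval_output['loss']}")
model.train()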

# HF Trainer adaptation

# The WhisperConfig and Whisper classes defined above are used unchanged here.

config = WhisperConfig()

# Instantiate the HF-compatible model for the Trainer (hybrid attention off, gradient checkpointing on)

model = Whisper(config, use_hybrid_attention=False, gradient_checkpointing=True)

```
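
For completeness, the data side (not shown above) is basically the padding collator from the usual Whisper fine-tuning recipe; sketched below with `processor` standing in for my `WhisperProcessor`:

```python
# Sketch of the padding collator I'm pairing with the model above; `processor` and the
# dataset columns ("input_features", "labels") are assumptions about my own preprocessing.
from dataclasses import dataclass
from typing import Any, Dict, List
import torch

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any               # a WhisperProcessor
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Pad the log-mel features to a uniform length.
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Pad the label ids and replace padding with -100 so CrossEntropyLoss ignores it.
        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # Strip the leading decoder-start token if the tokenizer already added it.
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
```

One interaction I'm unsure about: those -100 padding positions flow straight into my `forward()`, which feeds `labels` directly to the decoder as input tokens, so they reach `nn.Embedding` as indices rather than only being masked out of the loss. If anyone can confirm whether that (or the missing shift between decoder inputs and labels) is what breaks the Trainer's eval loss, that would already help a lot.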
