My model trains fine in plain PyTorch with my own training loop, but it's having trouble with the HF Trainer: when I try to integrate it I get a weird 0 loss on eval. This is still preliminary code. If anyone has experience with this kind of thing and sees anything that jumps out, please let me know (a simplified sketch of the Trainer wiring is at the bottom of the post).
###################################
```python
import base64, gzip, logging, math, time, warnings, numpy as np, torch, torch.nn as nn
import torch.nn.functional as F, torch.optim as optim, torch.utils.checkpoint as checkpoint, torchaudio, torchaudio.transforms as at
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from torch import Tensor
from whisper.decoding import decode as decode_function
from whisper.decoding import detect_language as detect_language_function
from whisper.transcribe import transcribe as transcribe_function
import os, librosa, re, transformers, evaluate, spacy, ginza, MeCab, deepl, soundfile as sf, neologdn
from datasets import Dataset, load_dataset, DatasetDict, Audio, concatenate_datasets, load_from_disk, IterableDataset, interleave_datasets
# transformers' own WhisperConfig is deliberately not imported here, because a local WhisperConfig is defined below
from transformers import (WhisperForConditionalGeneration, WhisperModel, GenerationConfig, WhisperTokenizer, WhisperProcessor,
WhisperFeatureExtractor, WhisperTokenizerFast, AutoTokenizer, AutoModelForSpeechSeq2Seq, pipeline,
Trainer, TrainingArguments, TrainerState, TrainerControl, TrainerCallback,
Seq2SeqTrainer, Seq2SeqTrainingArguments, PretrainedConfig, PreTrainedModel)
try:
from torch.nn.functional import scaled_dot_product_attention
SDPA_AVAILABLE = True
except (ImportError, RuntimeError, OSError):
scaled_dot_product_attention = None
SDPA_AVAILABLE = False
warnings.filterwarnings(action="ignore")
warnings.warn = lambda *args, **kwargs: None
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32  # fp16 on GPU, fp32 on CPU
import tensorboard
@dataclass
class ModelDimensions:
n_mels: int
n_audio_ctx: int
n_audio_state: int
n_audio_head: int
n_audio_layer: int
n_vocab: int
n_text_ctx: int
n_text_state: int
n_text_head: int
n_text_layer: int
class LayerNorm(nn.Module):  # despite its name, this is a GroupNorm with an extra per-channel LayerScale parameter
def __init__(self, num_groups, num_channels, eps=1e-6):
super(LayerNorm, self).__init__()
self.num_groups = num_groups
self.eps = eps
self.g = nn.Parameter(torch.ones(num_channels))
self.b = nn.Parameter(torch.zeros(num_channels))
self.scale = nn.Parameter(torch.ones(num_channels))
def forward(self, x):
N, C, *dims = x.shape
if len(dims) == 2:
x = x.view(N, self.num_groups, C // self.num_groups, *dims)
mean = x.mean(dim=(2, 3), keepdim=True) # Group Normalization
var = x.var(dim=(2, 3), keepdim=True, unbiased=False)
x = (x - mean) / torch.sqrt(var + self.eps)
x = x.view(N, C, *dims)
else:
x = x.view(N, self.num_groups, C // self.num_groups, -1)
mean = x.mean(dim=2, keepdim=True) # Group Normalization
var = x.var(dim=2, keepdim=True, unbiased=False)
x = (x - mean) / torch.sqrt(var + self.eps)
x = x.view(N, C, *dims)
x = self.g * x + self.b
x = x * self.scale # Layer-Scale Normalization
return x
class Linear(nn.Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.weight) # Custom initialization
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, x: Tensor) -> Tensor:
weight = self.weight.to(x.dtype)
bias = None if self.bias is None else self.bias.to(x.dtype)
# gradient clipping belongs in the training step (e.g. TrainingArguments(max_grad_norm=1.0)); clip_grad_norm_ acts on .grad, so calling it inside forward() has no well-defined effect
return F.linear(x, weight, bias)
class Conv1d(nn.Conv1d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')
if self.bias is not None:
nn.init.zeros_(self.bias)
def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
)
class RotaryEmbedding(nn.Module):
def __init__(self, dim, base=10000):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def rotate_queries_or_keys(self, x):
sinusoid_inp = torch.einsum('i , j -> i j', torch.arange(x.shape[1], device=x.device), self.inv_freq) # x shape: [batch_size, seq_len, n_head, head_dim]
sin = sinusoid_inp.sin()[None, :, None, :] # Shape: [1, seq_len, 1, head_dim]
cos = sinusoid_inp.cos()[None, :, None, :] # Shape: [1, seq_len, 1, head_dim]
x1, x2 = x[..., ::2], x[..., 1::2]
x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
return x
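# Note: this is the "non-interleaved" (NeoX-style) RoPE variant: even/odd channel pairs are rotated
# and the two rotated halves are concatenated rather than re-interleaved. It is applied identically
# to q and k in MultiHeadAttention.forward, so the relative-position behaviour is preserved.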
class SinusoidalFeatures:
def __init__(self, n_ctx, n_state):
self.n_ctx = n_ctx
self.n_state = n_state
self.features = self.sinusoidal_features(n_ctx, n_state)
@staticmethod
def sinusoidal_features(n_ctx, n_state):
position = torch.arange(0, n_ctx, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, n_state, 2).float() * -(math.log(10000.0) / n_state))
features = torch.zeros(n_ctx, n_state)
features[:, 0::2] = torch.sin(position * div_term)
features[:, 1::2] = torch.cos(position * div_term)
return features
def __call__(self):
return self.features
class LearnedSinusoidalEmbeddings(nn.Module):  # learnable positions initialized from SinusoidalFeatures(n_ctx, n_state)
def __init__(self, n_ctx, n_state, gradient_checkpointing=False):
super().__init__()
self.n_ctx = n_ctx
self.n_state = n_state
self.gradient_checkpointing = gradient_checkpointing
sinusoidal_embeddings = SinusoidalFeatures(n_ctx, n_state)()
self.positional_embeddings = nn.Parameter(sinusoidal_embeddings)
def forward(self, positions):
if self.gradient_checkpointing:
position_embeddings = checkpoint.checkpoint(lambda x: self.positional_embeddings[x], positions)
else:
position_embeddings = self.positional_embeddings[positions]
position_embeddings = F.normalize(position_embeddings, p=2, dim=-1) # Normalize positional embeddings
return position_embeddings
class HybridAttention(nn.Module):
def __init__(self, n_state: int, n_head: int, dropout_rate=0.1):
super().__init__()
self.local_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate) # Local and Global Multi-Head Attention
self.global_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)
self.ln_local = LayerNorm(num_groups=1, num_channels=n_state) # Layer Norms
self.ln_global = LayerNorm(num_groups=1, num_channels=n_state)
self.dropout = nn.Dropout(dropout_rate) # Dropout
self.window_size = 5 # Define the local attention window size, adjust as needed
def forward(self, x: Tensor):
x_local = self.ln_local(x) # Apply Layer Norms
x_global = self.ln_global(x)
x_local = x_local.permute(1, 0, 2) # Transpose for PyTorch MultiheadAttention: [seq_len, batch_size, n_state]
x_global = x_global.permute(1, 0, 2)
local_out = self.sliding_window_attention(x_local) # Local Attention, create a custom function to handle the sliding window
global_out, _ = self.global_attn(x_global, x_global, x_global) # Global Attention
combined_out = local_out + global_out # Combine outputs, [seq_len, batch_size, n_state]
combined_out = combined_out.permute(1, 0, 2) # Transpose back: [batch_size, seq_len, n_state]
return self.dropout(combined_out)
def sliding_window_attention(self, x):
seq_len, batch_size, n_state = x.size() # x: [seq_len, batch_size, n_state]
window_size = self.window_size
output = torch.zeros_like(x)
for i in range(seq_len):
start = max(0, i - window_size)
end = min(seq_len, i + window_size + 1)
query = x[i:i+1, :, :] # [1, batch_size, n_state]
key = x[start:end, :, :] # [window, batch_size, n_state]
value = x[start:end, :, :] # [window, batch_size, n_state]
attn_output, _ = self.local_attn(query, key, value)
output[i:i+1, :, :] = attn_output
return output
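# Note: this sliding window loops over every position in Python and runs one MultiheadAttention call
# per position, i.e. O(seq_len) separate attention calls; fine for experimenting, but slow for long
# sequences compared with a batched/banded implementation.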
class MultiHeadAttention(nn.Module):
use_sdpa = True
def __init__(self, n_state: int, n_head: int, dropout_rate=0.1, gradient_checkpointing=False):
super().__init__()
self.n_head = n_head
self.n_state = n_state
self.head_dim = n_state // n_head
self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state)
self.rotary_emb = RotaryEmbedding(dim=self.head_dim) # Rotary Embedding Initialization
self.temperature = nn.Parameter(torch.ones(1) * (self.head_dim ** -0.5)) # Temperature scaling
self.dropout = nn.Dropout(dropout_rate) # Dropout
self.attn_ln = LayerNorm(num_groups=1, num_channels=n_state) # LayerNorm
self.gradient_checkpointing = gradient_checkpointing
def forward(self, x: Tensor, xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
x_norm = self.attn_ln(x) # Apply layer norm first
q = self.query(x_norm)
if kv_cache is None or xa is None or self.key not in kv_cache:
k = self.key(x_norm if xa is None else xa)
v = self.value(x_norm if xa is None else xa)
else:
k = kv_cache[self.key]
v = kv_cache[self.value]
q = q.view(q.shape[0], q.shape[1], self.n_head, -1) # Apply rotary embeddings using Einstein summation
k = k.view(k.shape[0], k.shape[1], self.n_head, -1)
q = self.rotary_emb.rotate_queries_or_keys(q)
k = self.rotary_emb.rotate_queries_or_keys(k)
q = q.view(q.shape[0], q.shape[1], -1)
k = k.view(k.shape[0], k.shape[1], -1)
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv) + x, qk # Residual connection
def qkv_attention(self, q: Tensor, k: Tensor, v: Tensor,
mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
n_batch, n_ctx, n_state = q.shape
scale = self.temperature
q = q.view(n_batch, n_ctx, self.n_head, self.head_dim).permute(0, 2, 1, 3)
k = k.view(n_batch, k.shape[1], self.n_head, self.head_dim).permute(0, 2, 1, 3)
v = v.view(n_batch, v.shape[1], self.n_head, self.head_dim).permute(0, 2, 1, 3)
if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
a = scaled_dot_product_attention(q, k, v,
is_causal=mask is not None and n_ctx > 1)
out = a.permute(0, 2, 1, 3).reshape(n_batch, n_ctx, n_state)
qk = None
else:
qk = (q * scale) @ (k.transpose(-2, -1) * scale)
if mask is not None:
qk += mask[:n_ctx, :n_ctx]
qk = qk.float()
w = F.softmax(qk, dim=-1).to(q.dtype)
w = self.dropout(w)
out = (w @ v).permute(0, 2, 1, 3).reshape(n_batch, n_ctx, n_state)
qk = qk.detach()
return out, qk
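# Note: the two paths in qkv_attention are not numerically equivalent. The SDPA branch uses PyTorch's
# default 1/sqrt(head_dim) scaling, no attention dropout, and only uses `mask` as a causality flag,
# while the fallback applies the learned temperature to both q and k, adds the mask values, and then
# applies dropout to the attention weights.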
class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, use_hybrid_attention: bool = False,
cross_attention: bool = False, dropout_rate=0.1, gradient_checkpointing=False):
super().__init__()
self.use_hybrid_attention = use_hybrid_attention
if self.use_hybrid_attention:
self.attn = HybridAttention(n_state, n_head, dropout_rate=dropout_rate)
else:
self.attn = MultiHeadAttention(n_state, n_head, dropout_rate=dropout_rate, gradient_checkpointing=gradient_checkpointing)
self.attn_ln = LayerNorm(num_groups=1, num_channels=n_state)
self.cross_attention = cross_attention
if self.cross_attention:
self.cross_attn = MultiHeadAttention(n_state, n_head, dropout_rate=dropout_rate, gradient_checkpointing=gradient_checkpointing)
self.cross_attn_ln = LayerNorm(num_groups=1, num_channels=n_state)
n_mlp = n_state * 4
self.mlp = nn.Sequential( # MLP block
Linear(n_state, n_mlp),
LayerNorm(num_groups=1, num_channels=n_mlp),
nn.GELU(),
nn.Dropout(p=dropout_rate),
Linear(n_mlp, n_state)
)
self.mlp_ln = LayerNorm(num_groups=1, num_channels=n_state)
self.gradient_checkpointing = gradient_checkpointing
def forward(self, x: Tensor, xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None, kv_cache: Optional[dict] = None):
attn_input = self.attn_ln(x) # Apply attention with residual connection
if self.gradient_checkpointing:
if self.use_hybrid_attention:
attn_out = x + checkpoint.checkpoint(self.attn, attn_input)
else:
attn_out = x + checkpoint.checkpoint(self.attn, attn_input, None, mask, kv_cache)[0]  # pass xa=None explicitly so mask/kv_cache land on the right positional arguments
else:
if self.use_hybrid_attention:
attn_out = x + self.attn(attn_input)
else:
attn_out = x + self.attn(attn_input, mask=mask, kv_cache=kv_cache)[0]
if self.cross_attention and xa is not None: # Cross-attention remains unchanged
cross_attn_input = self.cross_attn_ln(attn_out)
if self.gradient_checkpointing:
attn_out = attn_out + checkpoint.checkpoint(self.cross_attn, cross_attn_input, xa, None, kv_cache)[0]  # xa is the cross-attention memory; mask stays None
else:
attn_out = attn_out + self.cross_attn(cross_attn_input, xa, kv_cache=kv_cache)[0]
mlp_input = self.mlp_ln(attn_out) # Apply MLP with residual connection
if self.gradient_checkpointing:
mlp_out = attn_out + checkpoint.checkpoint(self.mlp, mlp_input)
else:
mlp_out = attn_out + self.mlp(mlp_input)
return mlp_out
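# Note: ResidualAttentionBlock.attn_ln and MultiHeadAttention.attn_ln both normalize the same input,
# so the self-attention input is layer-normalized twice (once here, once inside MultiHeadAttention.forward).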
class AudioEncoder(nn.Module):
def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int,
dropout_rate=0.1, gradient_checkpointing=False, use_hybrid_attention=False):
super().__init__()
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self.dropout = nn.Dropout(dropout_rate)
self.gradient_checkpointing = gradient_checkpointing
self.blocks = nn.ModuleList([
ResidualAttentionBlock(
n_state, n_head,
use_hybrid_attention=use_hybrid_attention,
dropout_rate=dropout_rate,
gradient_checkpointing=gradient_checkpointing
)
for _ in range(n_layer)
])
self.ln_post = LayerNorm(num_groups=1, num_channels=n_state)
def forward(self, x: torch.Tensor):
x = F.gelu(self.conv1(x))
x = self.dropout(x)
x = F.gelu(self.conv2(x))
x = self.dropout(x)
x = x.permute(0, 2, 1)
for block in self.blocks:
if self.gradient_checkpointing:
x = checkpoint.checkpoint(block, x)
else:
x = block(x)
x = self.ln_post(x)
return x
class TextDecoder(nn.Module):
def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int,
dropout_rate=0.1, gradient_checkpointing=False, use_hybrid_attention=False):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx, n_state, gradient_checkpointing=gradient_checkpointing)
self.gradient_checkpointing = gradient_checkpointing
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList([
ResidualAttentionBlock(
n_state, n_head,
use_hybrid_attention=use_hybrid_attention,
cross_attention=True,
dropout_rate=dropout_rate,
gradient_checkpointing=gradient_checkpointing
)
for _ in range(n_layer)
])
self.ln_post = LayerNorm(num_groups=1, num_channels=n_state)
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer('mask', mask, persistent=False)
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
positions = torch.arange(x.shape[1], device=x.device)
pos_emb = self.positional_embedding(positions).unsqueeze(0) # Shape: (1, seq_len, n_state)
x = self.token_embedding(x) + pos_emb
x = x.to(xa.dtype)
for block in self.blocks:
if self.gradient_checkpointing:
x = checkpoint.checkpoint(block, x, xa, self.mask, kv_cache)
else:
x = block(x, xa, self.mask, kv_cache)
x = self.ln_post(x)
logits = (x @ self.token_embedding.weight.to(x.dtype).T).float()
return logits
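# Note: the output projection is tied to the token embedding matrix (logits = x @ embedding.weight.T),
# so there is no separate LM head.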
class WhisperConfig(PretrainedConfig):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.n_vocab = kwargs.get("n_vocab", 51865)
self.n_mels = kwargs.get("n_mels", 80)
self.n_audio_ctx = kwargs.get("n_audio_ctx", 1500)
self.n_audio_state = kwargs.get("n_audio_state", 1024)
self.n_audio_head = kwargs.get("n_audio_head", 16)
self.n_audio_layer = kwargs.get("n_audio_layer", 8)
self.n_text_ctx = kwargs.get("n_text_ctx", 448)
self.n_text_state = kwargs.get("n_text_state", 1024)
self.n_text_head = kwargs.get("n_text_head", 16)
self.n_text_layer = kwargs.get("n_text_layer", 8)
self.dropout_rate = kwargs.get("dropout_rate", 0.1)
self.gradient_checkpointing = kwargs.get("gradient_checkpointing", False)
self.decoder_start_token_id = kwargs.get("decoder_start_token_id", 50258)
self.use_hybrid_attention = kwargs.get("use_hybrid_attention", False)
class Whisper(PreTrainedModel):
config_class = WhisperConfig
def __init__(self, config, dropout_rate=0.1, gradient_checkpointing=False, use_hybrid_attention=False):
super().__init__(config)
self.config = config
self.encoder = AudioEncoder(
self.config.n_mels,
self.config.n_audio_ctx,
self.config.n_audio_state,
self.config.n_audio_head,
self.config.n_audio_layer,
dropout_rate=dropout_rate,
gradient_checkpointing=gradient_checkpointing,
use_hybrid_attention=use_hybrid_attention
)
self.decoder = TextDecoder(
self.config.n_vocab,
self.config.n_text_ctx,
self.config.n_text_state,
self.config.n_text_head,
self.config.n_text_layer,
dropout_rate=dropout_rate,
gradient_checkpointing=gradient_checkpointing,
use_hybrid_attention=use_hybrid_attention
)
self.loss_fn = nn.CrossEntropyLoss()
all_heads = torch.zeros(
self.config.n_text_layer, self.config.n_text_head, dtype=torch.bool
)
all_heads[self.config.n_text_layer // 2 :] = True
self.register_buffer('alignment_heads', all_heads.to_sparse(), persistent=False)
def set_alignment_heads(self, dump: bytes):
array = np.frombuffer(
gzip.decompress(base64.b85decode(dump)), dtype=bool
).copy()
mask = torch.from_numpy(array).reshape(
self.config.n_text_layer, self.config.n_text_head
)
self.register_buffer('alignment_heads', mask.to_sparse(), persistent=False)
def embed_audio(self, mel: torch.Tensor):
return self.encoder(mel)
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
return self.decoder(tokens, audio_features)
def forward(self, input_features: torch.Tensor, labels: torch.Tensor = None) -> Dict[str, torch.Tensor]:
input_features = input_features.to(self.device)
labels = labels.to(self.device) if labels is not None else None
audio_features = self.encoder(input_features)
# note: the label token ids are fed directly to the decoder as its input sequence
# (no shift-right / decoder_start_token handling here)
logits = self.decoder(labels, audio_features)
loss = None
if labels is not None:
# flatten only for the loss computation; keep logits 3-D for the caller
loss = self.loss_fn(logits.view(-1, self.config.n_vocab), labels.view(-1))
# the HF Trainer reads outputs["loss"] when labels are supplied; the other keys are returned for inspection
output = {"logits": logits, "audio_features": audio_features}
if loss is not None:
output["loss"] = loss
return output
@property
def device(self):
return next(self.parameters()).device
@property
def is_multilingual(self):
return self.config.n_vocab >= 51865
@property
def num_languages(self):
return self.config.n_vocab - 51765 - int(self.is_multilingual)
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
cache = {**cache} if cache is not None else {}
hooks = []
def save_to_cache(module, _, output):
if module not in cache or output.shape[1] > self.config.n_text_ctx:
cache[module] = output
else:
cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module]
def install_hooks(layer: nn.Module):
if isinstance(layer, MultiHeadAttention):
hooks.append(layer.key.register_forward_hook(save_to_cache))
hooks.append(layer.value.register_forward_hook(save_to_cache))
self.decoder.apply(install_hooks)
return cache, hooks
detect_language = detect_language_function
transcribe = transcribe_function
decode = decode_function
config = WhisperConfig()
model = Whisper(config, use_hybrid_attention=False, gradient_checkpointing=True)
###################################################################################################
#Define dummy inputs
batch_size = 2
n_mels = 80
audio_seq_len = 3000
n_vocab = 51865
text_seq_len = 256
dummy_input_features = torch.randn(batch_size, n_mels, audio_seq_len).cuda()
dummy_labels = torch.randint(0, n_vocab, (batch_size, text_seq_len)).cuda()
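# Note: conv2 in AudioEncoder has stride 2, so the 3000 dummy mel frames become 1500 encoder positions.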
dimensions = ModelDimensions(
n_mels=80,
n_audio_ctx=audio_seq_len,
n_audio_state=1024,
n_audio_head=16,
n_audio_layer=8,
n_vocab=51865,
n_text_ctx=text_seq_len,
n_text_state=1024,
n_text_head=16,
n_text_layer=8
)
# Initialize model with hybrid attention enabled (build a WhisperConfig from the dummy dimensions, since PreTrainedModel expects a config object)
config = WhisperConfig(**vars(dimensions))
model = Whisper(config, use_hybrid_attention=True, gradient_checkpointing=True).cuda()
# Forward pass
output = model(input_features=dummy_input_features, labels=dummy_labels)
# Print outputs
if output:
print(f"Logits shape: {output['logits'].shape}")
print(f"Audio features shape: {output['audio_features'].shape}")
else:
print("Model output is None. Please check the implementation.")
# HF Trainer adaptation: the WhisperConfig and Whisper(PreTrainedModel) classes above are reused here unchanged.
config = WhisperConfig()
# Initialize the model for the HF Trainer (hybrid attention disabled, gradient checkpointing enabled); the Trainer moves it to the GPU.
model = Whisper(config, use_hybrid_attention=False, gradient_checkpointing=True)
```
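For reference, the Trainer side is wired up roughly like the sketch below. This is illustrative rather than my exact script: `train_ds`, `eval_ds`, the collator, the `openai/whisper-small` processor checkpoint, and the argument values are placeholders.

```python
import torch
from dataclasses import dataclass
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, WhisperProcessor

# placeholder processor; any Whisper checkpoint's tokenizer/feature-extractor pairing works for the sketch
processor = WhisperProcessor.from_pretrained("openai/whisper-small")

@dataclass
class DataCollatorSpeechSeq2Seq:
    processor: WhisperProcessor
    def __call__(self, features):
        # stack precomputed log-mel features: each item is assumed to carry "input_features" of shape [80, 3000]
        input_features = torch.stack([torch.as_tensor(f["input_features"]) for f in features])
        # pad label token ids with the tokenizer's pad token (no -100 masking in this sketch,
        # since the custom forward() embeds the label ids directly)
        labels = self.processor.tokenizer.pad(
            [{"input_ids": f["labels"]} for f in features], return_tensors="pt"
        )["input_ids"]
        return {"input_features": input_features, "labels": labels}

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-custom",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    eval_strategy="steps",
    eval_steps=100,
    logging_steps=25,
    max_steps=1000,
    remove_unused_columns=False,  # forward() takes input_features/labels rather than the stock Whisper signature
    report_to=["tensorboard"],
)

trainer = Seq2SeqTrainer(
    model=model,                  # the custom Whisper(PreTrainedModel) defined above
    args=training_args,
    train_dataset=train_ds,       # placeholders for my preprocessed datasets
    eval_dataset=eval_ds,
    data_collator=DataCollatorSpeechSeq2Seq(processor),
)
trainer.train()
```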