Hello everyone, I've reimplemented the Qwen2-VL model to understand its inner workings, inspired by Umar Jamil. My version runs into memory spikes that lead to CUDA out-of-memory errors. The code looks identical to the reference, but I have no idea why my version uses extra memory.
The original implementation from transformers stabilizes memory usage after the third attention layer, whereas my version shows continual memory growth. I've included the relevant sections of both my code and the original for comparison, as well as the memory usage outputs, below.
Does anyone have insights into why my code might be using extra memory, and how I can achieve the same memory stability as the original? Any advice would be highly appreciated.
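(For context, the per-layer numbers were collected with forward hooks along these lines; this is only a rough sketch, not my exact harness, and the module path model.visual.blocks may differ between transformers versions.)

import torch

def log_memory(name):
    def hook(module, inputs, output):
        # Report current and peak allocated memory after this module's forward pass.
        allocated = torch.cuda.memory_allocated() / 1024**3
        peak = torch.cuda.max_memory_allocated() / 1024**3
        print(f"{name}: allocated={allocated:.2f} GB, peak={peak:.2f} GB")
    return hook

# Hypothetical wiring; adjust the attribute path to your model object:
# for i, block in enumerate(model.visual.blocks):
#     block.attn.register_forward_hook(log_memory(f"vision_attn_{i}"))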
My Code:
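(The excerpts below assume the usual imports, which I left out of the post; in recent transformers versions ROPE_INIT_FUNCTIONS lives in the RoPE utilities module.)

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS  # RoPE init table used by both versions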
def print_cuda_memory():
    """Utility function to print CUDA memory statistics."""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3  # convert to GB
        reserved = torch.cuda.memory_reserved() / 1024**3
        max_allocated = torch.cuda.max_memory_allocated() / 1024**3
        print(f"\n=== CUDA Memory ===")
        print(f"Allocated: {allocated:.2f} GB")
        print(f"Reserved: {reserved:.2f} GB")
        print(f"Max Allocated: {max_allocated:.2f} GB")
        print("=" * 40)
def apply_rotary_pos_emb_vision(tensor, freqs):
    org_dtype = tensor.dtype
    tensor = tensor.float()
    cos = freqs.cos().float()
    sin = freqs.sin().float()
    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0)
    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0)
    output = (tensor * cos) + (rotate_half(tensor) * sin)
    output = output.to(org_dtype)
    return output
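rotate_half is not shown above; for reference, it is the standard helper from the transformers modeling code:

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)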
class Qwen2VLRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config=None,
    ):
        super().__init__()
        self.rope_kwargs = {}
        if config is None:
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            if config.rope_scaling is not None:
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def update(self, position_ids, device):
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len
        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self.update(position_ids, device=x.device)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()  # [3, BatchSize, 1, Positions]
        device_type = x.device.type
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        cos *= self.attention_scaling
        sin *= self.attention_scaling
        cos, sin = cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
        print_cuda_memory()
        return cos, sin
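To illustrate the broadcasting in the matmul above, here is a quick shape check with dummy sizes (head_dim 128, batch 1, sequence length 5; purely illustrative):

inv_freq = torch.rand(64)                                                      # head_dim // 2
position_ids = torch.arange(5).view(1, 1, 5).expand(3, 1, 5)                   # (3, bs, seq) for t/h/w
inv_freq_expanded = inv_freq[None, None, :, None].float().expand(3, 1, -1, 1)  # (3, bs, 64, 1)
position_ids_expanded = position_ids[:, :, None, :].float()                    # (3, bs, 1, seq)
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(2, 3)            # (3, bs, seq, 64)
emb = torch.cat((freqs, freqs), dim=-1)
print(emb.cos().shape)                                                         # torch.Size([3, 1, 5, 128])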
class VisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seq_len):
        seq = torch.arange(seq_len, device=self.inv_freq.device, dtype=torch.float32)
        freqs = torch.outer(seq, self.inv_freq)
        print_cuda_memory()
        return freqs
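As a quick sanity check on the shapes (illustrative dim only):

rope = VisionRotaryEmbedding(dim=40)
freqs = rope(16)
print(freqs.shape)  # torch.Size([16, 20]) -> (seq_len, dim // 2)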
class VisionAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor = None,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        # Project and split QKV in a memory-efficient way
        qkv = self.qkv(hidden_states)
        del hidden_states
        torch.cuda.empty_cache()
        qkv = qkv.reshape(seq_length, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(1, 0, 2, 3)
        q, k, v = qkv.unbind(0)
        del qkv
        torch.cuda.empty_cache()
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
            k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
        attention_mask = torch.zeros(
            [1, seq_length, seq_length],
            device=q.device,
            dtype=torch.bool,
        )
        for i in range(1, len(cu_seqlens)):
            start_idx = cu_seqlens[i - 1]
            end_idx = cu_seqlens[i]
            attention_mask[..., start_idx:end_idx, start_idx:end_idx] = True
        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        print_cuda_memory()
        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attention_mask,
            dropout_p=0.0,
            scale=self.scale,
            is_causal=False,  # important: block-diagonal attention, not causal
        )
        torch.cuda.empty_cache()
        attn_output = attn_output.transpose(0, 1).contiguous()
        attn_output = attn_output.reshape(seq_length, -1)
        output = self.proj(attn_output)
        del attn_output, q, k, v, attention_mask
        torch.cuda.empty_cache()
        return output
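To make the mask loop concrete: cu_seqlens holds cumulative patch counts, so for example cu_seqlens = [0, 4, 10] describes two images of 4 and 6 patches, and the loop only allows attention within each block:

cu_seqlens = torch.tensor([0, 4, 10])
seq_length = 10
mask = torch.zeros([1, seq_length, seq_length], dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    # Mark the square block belonging to one image as attendable.
    mask[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = True
# mask[0, :4, :4] and mask[0, 4:, 4:] are True; all cross-image entries stay False.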
This is the transformers implementation of Qwen2-VL (with my print_cuda_memory() calls added so both versions log the same way):
class Qwen2VLRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config: Optional[Qwen2VLConfig] = None,
    ):
        super().__init__()
        self.rope_kwargs = {}
        if config is None:
            # (...)
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            if config.rope_scaling is not None:
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len
        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling
        print_cuda_memory()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class VisionRotaryEmbedding(nn.Module):
    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(seq, self.inv_freq)
        print_cuda_memory()
        return freqs
class VisionSdpaAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
        print_cuda_memory()
        attention_mask = torch.zeros([1, seq_length, seq_length], device=q.device, dtype=torch.bool)
        for i in range(1, len(cu_seqlens)):
            attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True
        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
        attn_output = attn_output.transpose(0, 1)
        attn_output = attn_output.reshape(seq_length, -1)
        attn_output = self.proj(attn_output)
        return attn_output
Comparing the memory usage outputs: the original implementation's memory usage stops increasing after the third attention layer, while mine keeps growing.
Any ideas on why my implementation keeps consuming more memory, or which differences might be causing this discrepancy?