Handling Floating-Point Precision Issues with Large Matrix Operations in PyTorch

Hi everyone,

I’m working with a transformer model where I need to perform operations on large matrices, and I’ve run into some floating-point precision issues. Here’s a simplified version of what I’m doing:

import torch
import torch.nn as nn

# Example dimensions
batch_size = 1
seq_length = 197
hidden_size = 768

linear_module = nn.Linear(hidden_size, hidden_size)
input_tensor = torch.randn(batch_size, seq_length, hidden_size)

# Create a masked input and compute outputs
input_mask = torch.rand_like(input_tensor)
input_masked = input_tensor + input_mask

out_og = linear_module(input_tensor)
out_masked = linear_module(input_masked)

# Reconstruct the original output by removing the effect of the mask:
# since linear(x + m) = x @ W.T + m @ W.T + b, subtracting linear(m) removes
# both the mask term and the bias, so the bias has to be added back once.
out_rec = out_masked - linear_module(input_mask) + linear_module.bias

When I compare out_og and out_rec using torch.allclose(out_og, out_rec), it returns True for small hidden sizes but False for larger ones (e.g., hidden_size = 768, as above). It looks like small numerical differences accumulate as the matrix dimensions grow, and I suspect I'm running into float32 precision limits.
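For reference, here is roughly how I've been checking the size of the error. The specific tolerances are just values I experimented with, and repeating the computation in float64 is only a sanity check that the reconstruction identity itself is correct:

# Quantify the mismatch instead of relying on allclose's default tolerances
max_abs_diff = (out_og - out_rec).abs().max().item()
print(f"max |out_og - out_rec| in float32: {max_abs_diff:.3e}")

# torch.allclose defaults to rtol=1e-5, atol=1e-8; try a looser tolerance
print(torch.allclose(out_og, out_rec))                       # False at hidden_size = 768
print(torch.allclose(out_og, out_rec, rtol=1e-4, atol=1e-6))

# Redo the reconstruction in float64 and check that the residual shrinks
# (note: .double() casts the module's parameters in place)
linear64 = linear_module.double()
x64, m64 = input_tensor.double(), input_mask.double()
out_og64 = linear64(x64)
out_rec64 = linear64(x64 + m64) - linear64(m64) + linear64.bias
print((out_og64 - out_rec64).abs().max().item())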

Has anyone else encountered similar issues with accumulated precision errors in large matrix operations?
