Why does chunked dataset training give different results compared to full-batch training in my Siren model?

I’m implementing a Siren model for audio reconstruction using PyTorch. My first approach processes the entire dataset in a single batch, while my second approach loads and trains the dataset in smaller chunks to avoid memory overload.

Approach 1 (Full-batch training)

Here, I load the entire dataset at once and perform backpropagation on the full dataset in each iteration.

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

# AudioFile and Siren are defined elsewhere (as in the Siren reference implementation)
ach_audio = AudioFile('hello.wav')

dataloader = DataLoader(ach_audio, shuffle=False, batch_size=1, pin_memory=True, num_workers=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
audio_siren = Siren(in_features=1, out_features=1, hidden_features=128,
                    hidden_layers=3, first_omega_0=3000, outermost_linear=True).to(device)

optim = torch.optim.Adam(lr=1e-4, params=audio_siren.parameters())

# The dataset yields the entire waveform as a single batch
model_input, ground_truth = next(iter(dataloader))
model_input, ground_truth = model_input.to(device), ground_truth.to(device)

# Training loop: one optimizer step per iteration, each on the full signal
for step in range(2000):
    model_output, coords = audio_siren(model_input)
    loss = F.mse_loss(model_output, ground_truth)
    
    optim.zero_grad()
    loss.backward()
    optim.step()

Approach 2 (Chunked training to avoid memory overload)

Instead of training on the entire dataset at once, I process it in smaller chunks, accumulate the gradients, and perform a single optimization step per epoch. I use the following dataset class to load the audio in chunks:

import numpy as np
import torch
from scipy.io import wavfile
from torch.utils.data import Dataset

class AudioChunkDataset(Dataset):
    def __init__(self, filename, chunk_duration=5):
        self.rate, data = wavfile.read(filename)
        data = data.astype(np.float32)

        # Mix stereo down to mono
        if len(data.shape) > 1:
            data = np.mean(data, axis=1)

        # Normalize by the global maximum so all chunks share one scale
        self.global_max = np.max(np.abs(data))
        self.data = data / self.global_max

        self.chunk_size = int(chunk_duration * self.rate)
        self.total_samples = len(self.data)
        self.num_chunks = int(np.ceil(self.total_samples / self.chunk_size))

        # Coordinates span the whole file, so each chunk keeps its global position in [-1, 1]
        self.full_coords = torch.linspace(-1, 1, self.total_samples).view(-1, 1)

    def __len__(self):
        return self.num_chunks

    def __getitem__(self, idx):
        start_idx = idx * self.chunk_size
        end_idx = min(start_idx + self.chunk_size, self.total_samples)

        chunk_data = self.data[start_idx:end_idx]
        coords = self.full_coords[start_idx:end_idx]

        return coords, torch.from_numpy(chunk_data).view(-1, 1)
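
As a quick sanity check that the dataset partitions the file correctly (assuming hello.wav is in the working directory), the chunks concatenated in order should reproduce the full coordinate grid and waveform:

# Sanity check: chunks, concatenated in order, rebuild the full signal
dataset = AudioChunkDataset("hello.wav", chunk_duration=5)
coords_list, amps_list = zip(*[dataset[i] for i in range(len(dataset))])
assert torch.equal(torch.cat(coords_list), dataset.full_coords)
assert torch.cat(amps_list).shape[0] == dataset.total_samples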

This dataset is then used in training as follows:

# Load dataset
file_path = "hello.wav"
dataset = AudioChunkDataset(file_path, chunk_duration=5)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Siren(in_features=1, out_features=1, hidden_features=128, 
              hidden_layers=3, first_omega_0=3000, outermost_linear=True).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Training loop
total_epochs = 1000
best_loss = float('inf')

for epoch in range(total_epochs):
    epoch_loss = 0.0
    optimizer.zero_grad()
    
    for coords, amplitudes in dataloader:
        coords, amplitudes = coords.to(device), amplitudes.to(device)
        
        model_output, _ = model(coords)
        # Scale so the accumulated gradients match the mean over all chunks
        loss = F.mse_loss(model_output, amplitudes) / len(dataloader)

        loss.backward()  # Accumulate gradients across chunks
        epoch_loss += loss.item()

    optimizer.step()
    
    print(f"Epoch {epoch+1}/{total_epochs}, Loss: {epoch_loss:.6f}")
    
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), "best_model.pth")
        print("Model saved.")

My question

I expected the second approach to give results similar to the first, since it should only slow down computation and avoid memory overload while still seeing the full dataset before each optimization step (the per-chunk gradients are accumulated and applied once per epoch). However, the results are noticeably different.

Why might the two approaches not give the same results?
Is there a bug in the second approach that prevents it from behaving as expected?
How can I modify the second approach to more closely match the first one while keeping the memory-efficient structure? (A minimal sketch of the equivalence I expected is below.)
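
To make that expectation concrete, here is a minimal, self-contained sketch of the reasoning, using a plain torch.nn.Linear in place of Siren and equal-sized chunks: summing the gradients of per-chunk MSE losses, each divided by the number of chunks, should reproduce the full-batch MSE gradient.

# Minimal check of the gradient-accumulation reasoning, with a linear model
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(100, 1)
y = torch.randn(100, 1)
model = torch.nn.Linear(1, 1)

# Gradient of the full-batch loss
model.zero_grad()
F.mse_loss(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Gradient accumulated over 4 equal-sized chunks
model.zero_grad()
for cx, cy in zip(x.chunk(4), y.chunk(4)):
    (F.mse_loss(model(cx), cy) / 4).backward()

print(torch.allclose(full_grad, model.weight.grad))  # expect: True

One caveat: this equality holds only when all chunks have the same length. Because AudioChunkDataset uses np.ceil, the last chunk can be shorter, so dividing every chunk loss by len(dataloader) weights that chunk's samples slightly differently than the full-batch mean would.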

Any insights would be greatly appreciated. Thank you, Hugging Face community!


This is just one possibility, but I wonder whether splitting the wav file into contiguous chunks without considering the surrounding context has an effect. As a rough, untested sketch (the class name RandomAudioSampleDataset is just for illustration), you could compare against drawing each batch from random positions across the whole file rather than from one contiguous 5-second window:
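
import numpy as np
import torch
from scipy.io import wavfile
from torch.utils.data import Dataset

class RandomAudioSampleDataset(Dataset):
    """Each item is a random subset of (coordinate, amplitude) pairs drawn
    from the whole file, so every backward pass covers the full [-1, 1]
    coordinate range instead of one contiguous window."""

    def __init__(self, filename, samples_per_item, num_items):
        self.rate, data = wavfile.read(filename)
        data = data.astype(np.float32)
        if data.ndim > 1:
            data = data.mean(axis=1)  # mix stereo down to mono
        self.amps = torch.from_numpy(data / np.abs(data).max()).view(-1, 1)
        self.coords = torch.linspace(-1, 1, len(self.amps)).view(-1, 1)
        self.samples_per_item = samples_per_item
        self.num_items = num_items

    def __len__(self):
        return self.num_items

    def __getitem__(self, idx):
        # Random positions spread over the whole file, not a contiguous block
        pos = torch.randint(0, len(self.amps), (self.samples_per_item,))
        return self.coords[pos], self.amps[pos]

With something like this, each accumulated gradient step sees roughly the same coordinate distribution as the full-batch run, which should help isolate whether the contiguous splitting is the cause.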