Issue with accelerator.backward(loss) freezing

I’m trying to accelerate my training script for custom diffuser model (based on diffusers.UNet2DModel) which works with 3D images.

I followed this example when adding accelerate features:

The problem is that whenever my code reaches accelerator.backward(loss) it freezes with no output, no error messages and soon after that my SSH connection (to a remote machine with GPUs) dies. I tried running the script in screen or tmux session but after reconnecting these sessions are dead too.

I run the accelerate test command and it runs with no issues. Also the code runs fine when I set --num_process to 1.

Here is my accelerate env and training script:

  • Accelerate version: 0.25.0
  • Platform: Linux-4.18.0-326.el8.x86_64-x86_64-with-glibc2.28
  • Python version: 3.11.3
  • Numpy version: 1.24.3
  • PyTorch version (GPU?): 2.0.1+cu117 (True)
  • PyTorch XPU available: False
  • PyTorch NPU available: False
  • System RAM: 31.15 GB
  • GPU type: NVIDIA TITAN RTX
  • Accelerate default config:
    - compute_environment: LOCAL_MACHINE
    - distributed_type: MULTI_GPU
    - mixed_precision: no
    - use_cpu: False
    - debug: True
    - num_processes: 2
    - machine_rank: 0
    - num_machines: 1
    - gpu_ids: all
    - rdzv_backend: static
    - same_network: True
    - main_training_function: main
    - downcast_bf16: no
    - tpu_use_cluster: False
    - tpu_use_sudo: False
    - tpu_env:
import os
import glob
import time
import math
import logging
from dataclasses import dataclass
from datetime import timedelta

import torch
import torch.nn.functional as F
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import tqdm

import diffusers
from diffusers import DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup
from diffusers.utils import is_accelerate_version, is_tensorboard_available
from accelerate import Accelerator, InitProcessGroupKwargs
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration

from monai.visualize import matshow3d
from monai.data import CacheDataset
from monai.utils import first
from monai.transforms import (
    LoadImage,
    EnsureChannelFirst,
    ToTensor,
    Lambda,
    Compose,
    RandFlip,
    RandRotate,
    RandZoom,
    ScaleIntensity,
    EnsureType,
    Transform,
    Resize,
)

from UNet3D_2D import UNet3DModel

@dataclass
class TrainingConfig:
    data_dir = "/data1/dose-3d-generative/data_med/PREPARED/FOR_AUG/ct_images_prostate_only_26fixed"
    image_size = 256
    scan_depth = 26
    batch_size = 1
    num_epochs = 10
    learning_rate = 1e-4
    lr_warmup_steps = 1000
    save_image_epochs = 100
    save_model_epochs = 50
    output_dir = "ct_256"
    seed = 0
    load_model_from_file = False
    gradient_accumulation_steps = 1
    logging_dir = "output_dir/logs"

def main():
    logger = get_logger(__name__, log_level="INFO")
    config = TrainingConfig()
    accelerator_project_config = ProjectConfiguration(
        project_dir=config.output_dir,
        logging_dir=config.logging_dir
    )

    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))  # a big number for high resolution or big dataset
    accelerator = Accelerator(
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        project_config=accelerator_project_config,
        kwargs_handlers=[kwargs]
    )

    if not is_tensorboard_available():
        raise ImportError("tensorboard not found")
    
    # TODO: add hooks for saving and loading model
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    if accelerator.is_local_main_process:
        diffusers.utils.logging.set_verbosity_info()
    else:
        diffusers.utils.logging.set_verbosity_error()

    # initialize model
    model = UNet3DModel(
        sample_size=config.image_size,
        sample_depth=config.scan_depth,
        in_channels=1,  # data are in grayscale, so always 1
        out_channels=1,  # data are in grayscale, so always 1
        layers_per_block=2,  # number of resnet blocks in each down_block/up_block
        block_out_channels=(32, 64, 64, 128, 256, 512, 512),
        down_block_types=(
            "DownBlock3D",
            "DownBlock2D",
            "DownBlock3D",
            "DownBlock2D",
            "DownBlock3D",
            "AttnDownBlock3D",
            "DownBlock3D",
        ),
        up_block_types=(
            "UpBlock3D",
            "AttnUpBlock3D",
            "UpBlock2D",
            "UpBlock3D",
            "UpBlock2D",
            "UpBlock3D",
            "UpBlock3D",
        ),
        norm_num_groups=32,
        dropout=0.0,
    )

    with open(config.output_dir + "/config.txt", 'w') as fp:
        fp.write(f"{model.block_out_channels}\n{model.down_block_types}\n{model.up_block_types}\n")

    if config.load_model_from_file:
        model.load_state_dict(torch.load(config.output_dir + '/model'))

    # initialize scheduler
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

    # initialize optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

    # prepare dataset
    # TODO: split train/validate when Accelerate works
    images_pattern = os.path.join(config.data_dir, "*.nii.gz")
    images = sorted(glob.glob(images_pattern))

    transforms = Compose(
        [
            LoadImage(image_only=True),
            EnsureChannelFirst(),
            Resize((config.image_size, config.image_size, config.scan_depth)),
            ScaleIntensity(),
            RandFlip(spatial_axis=1, prob=0.5),
            RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
            EnsureType()
        ]
    )

    dataset = CacheDataset(images, transforms)
    train_loader = torch.utils.data.DataLoader(
    dataset, batch_size=config.batch_size, num_workers=10, shuffle=True)

    def prepare_batch(batch_data, device=None, non_blocking=False):
        t = Compose(
            [
                Lambda(lambda t: (t * 2) - 1),
            ]
        )
        return t(batch_data.permute(0, 1, 4, 2, 3).to(device=device, non_blocking=non_blocking))
    
    logger.info(f"Dataset size: {len(dataset)}")

    # initialize learning rate scheduler
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=(len(train_loader) * config.num_epochs),
    )

    # prepare everyting with accelerator
    model, optimizer, train_loader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_loader, lr_scheduler
    )

    # initialize trackers on main process
    if accelerator.is_main_process:
        run = os.path.split(__file__)[-1].split(".")[0]
        accelerator.init_trackers(run)

    total_batch_size = config.batch_size * accelerator.num_processes * config.gradient_accumulation_steps
    num_update_steps_per_epoch = math.ceil(len(train_loader) / config.gradient_accumulation_steps)
    max_train_steps = config.num_epochs * num_update_steps_per_epoch

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(dataset)}")
    logger.info(f"  Num Epochs = {config.num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {config.batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {config.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {max_train_steps}")

    global_step = 0
    first_epoch = 0

    # TODO: loading weights and states from previous save

    # train the model
    for epoch in range(first_epoch, config.num_epochs):
        model.train()
        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        for step, batch in enumerate(train_loader):
            print(f'step = {step}, device = {accelerator.device}')
            clean_images = prepare_batch(batch)
            # Sample noise to add to the images
            noise = torch.randn(clean_images.shape).to(clean_images.device)
            bs = clean_images.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
            ).long() # generates a tensor of shape (1,) with random int from range [0, 1000)
    
            # Add noise to the clean images according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
            
            with accelerator.accumulate(model):
                # predict the noise residual
                model_output = model(noisy_images, timesteps).sample
                loss = F.mse_loss(model_output.float(), noise.float())
                print(f'before backward step - {accelerator.device}')
                accelerator.backward(loss)
                print(f'after backward step - {accelerator.device}')
                
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            # check if the accelerator has performed an optimization step
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
        
        progress_bar.close()
        accelerator.wait_for_everyone()

    accelerator.end_training()

if __name__ == '__main__':
    main()

I’d appreciate any help.
Thank you all