Problem:
First training step OOMs after load_state: DeepSpeed stage 2 with CPU offload
GPU memory usage is significantly higher after loading state
Update:
Tried calling load_checkpoint on the prepared (DeepSpeed-wrapped) model instead of accelerator.load_state:

model, train_dataloader, optimizer, lr_scheduler = accelerator.prepare(
    model, train_dataloader, optimizer, lr_scheduler
)
model.load_checkpoint("path/to/input/dir/of/ckpt", "pytorch_model")

as suggested here: load_state cuda out of memory · Issue #1707 · huggingface/accelerate · GitHub
It still OOMs.
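For reference, DeepSpeed's load_checkpoint also takes flags to skip the optimizer and LR scheduler states, which could help narrow down whether the optimizer state is what blows up GPU memory. A minimal sketch, assuming the prepared model is the DeepSpeed engine (the checkpoint path/tag are the same placeholders as above):

# Sketch: load only the module weights to see if the OOM disappears.
# load_optimizer_states / load_lr_scheduler_states are standard
# DeepSpeedEngine.load_checkpoint arguments.
model.load_checkpoint(
    "path/to/input/dir/of/ckpt",
    "pytorch_model",
    load_optimizer_states=False,
    load_lr_scheduler_states=False,
)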
Description of problem:
Code trains correctly; can use batch sizes up to 22 without OOM.
Saves state without errors.
Loads state without errors.
First training step after load_state OOMs if batch size > 2.
Does anyone know how to correctly load_state with stage 2 CPU offload?
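To quantify the memory jump mentioned above, the standard torch.cuda counters around the load can be used. A minimal sketch (accelerator and load_saved_state match the code below):

import torch

torch.cuda.reset_peak_memory_stats()
before = torch.cuda.memory_allocated()
accelerator.load_state(load_saved_state)
after = torch.cuda.memory_allocated()
peak = torch.cuda.max_memory_allocated()
accelerator.print(
    f"allocated before load: {before / 2**30:.2f} GiB, "
    f"after load: {after / 2**30:.2f} GiB, peak during load: {peak / 2**30:.2f} GiB"
)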
Relevant code:
save_state code:
# (runs after each epoch)
if save_state:
    accelerator.save_state(save_path)
load_state code:
import bitsandbytes as bnb
from diffusers import DDPMScheduler, UNet2DConditionModel
from diffusers.optimization import get_scheduler

if load_saved_state:
    # Rebuild the base model from the base pipeline (same as the initial training setup)
    noise_scheduler = DDPMScheduler.from_pretrained(
        pretrained_model_name_or_path, subfolder="scheduler"
    )
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="unet"
    )
    optimizer = bnb.optim.AdamW8bit(unet.parameters(), lr=learning_rate)
    lr_scheduler = get_scheduler(
        learning_rate_scheduler,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_update_steps * num_processes,
        num_training_steps=total_num_update_steps,
    )

    # accelerator.prepare
    unet, optimizer, train_dataloader, validation_loss_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, validation_loss_dataloader, lr_scheduler
    )

    # load_state
    accelerator.load_state(load_saved_state)
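One thing worth checking right after load_state is whether the allocator is still holding buffers left over from checkpoint loading; clearing Python garbage and the CUDA cache before the first step is a cheap sanity check (just a sketch, not a confirmed fix):

import gc
import torch

# Sketch: drop unreachable objects and return cached CUDA blocks to the
# driver before the first training step, then print what is still allocated.
gc.collect()
torch.cuda.empty_cache()
accelerator.print(f"allocated after cleanup: {torch.cuda.memory_allocated() / 2**30:.2f} GiB")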
accelerate config:
compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env:
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
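For completeness, if I understand the accelerate/DeepSpeed integration correctly, the deepspeed_config block above corresponds to roughly these DeepSpeed settings (a sketch as a Python dict; offload_param_device only comes into play under ZeRO stage 3, so at stage 2 only the optimizer states should be going to CPU):

ds_config_sketch = {
    "zero_optimization": {
        "stage": 2,
        # Stage 2 partitions gradients + optimizer states; the optimizer
        # states are what gets offloaded to CPU memory here.
        "offload_optimizer": {"device": "cpu"},
    },
    "fp16": {"enabled": True},  # from mixed_precision: fp16
}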