1st training step OOMs after load_state: deepspeed stage 2 cpu offload

problem:
The first training step OOMs after load_state when using DeepSpeed stage 2 with CPU offload.
GPU memory usage is significantly higher after load_state.

update:
tried:

model, train_dataloader, optimizer, lr_scheduler = accelerator.prepare(
    model, train_dataloader, optimizer, lr_scheduler
)
model.load_checkpoint("path/to/input/dir/of/ckpt", "pytorch_model")

as suggested in huggingface/accelerate issue #1707 ("load_state cuda out of memory").
It still OOMs.

Description of problem:
Code trains correctly and can use a batch size up to 22 without OOM
Saves state without errors
Loads state without errors
1st training step post load_state OOMs if batch size > 2
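
To put numbers on the memory jump, something like the sketch below could be called right before and right after load_state. It assumes PyTorch with a CUDA device; `gpu_memory_mib`, `log_gpu_memory` and `release_cached_memory` are hypothetical helper names for illustration, not part of accelerate or DeepSpeed.

```python
import gc

try:
    import torch
except ImportError:  # lets the helpers be imported on machines without PyTorch
    torch = None


def gpu_memory_mib():
    """Return (allocated, reserved) GPU memory in MiB, or (0.0, 0.0) without CUDA."""
    if torch is None or not torch.cuda.is_available():
        return 0.0, 0.0
    mib = 1024 ** 2
    return torch.cuda.memory_allocated() / mib, torch.cuda.memory_reserved() / mib


def log_gpu_memory(tag):
    """Print current GPU memory with a label and return the (allocated, reserved) pair."""
    allocated, reserved = gpu_memory_mib()
    print(f"[{tag}] allocated={allocated:.0f} MiB reserved={reserved:.0f} MiB")
    return allocated, reserved


def release_cached_memory():
    """Collect unreachable Python objects, then release cached CUDA blocks."""
    gc.collect()
    if torch is not None and torch.cuda.is_available():
        torch.cuda.empty_cache()
```

Calling `log_gpu_memory("before load_state")`, then load_state, then `log_gpu_memory("after load_state")`, followed by `release_cached_memory()` and one more log, should show whether the extra memory is only allocator cache (reserved drops after the release) or is held by live tensors (allocated stays high).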

Does anyone know how to correctly use load_state with DeepSpeed stage 2 CPU offload?

Relevant code:
save_state code:

# (post epoch)

if save_state:
    accelerator.save_state(save_path)

load_state code:

if load_saved_state == True:
    # load base_model from base_model pipeline (same as initial train loading)
    noise_scheduler = DDPMScheduler.from_pretrained(
        pretrained_model_name_or_path, subfolder="scheduler"
    )
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path, subfolder="unet"
    )
    optimizer = bnb.optim.AdamW8bit(unet.parameters(), lr=learning_rate)
    lr_scheduler = get_scheduler(
        learning_rate_scheduler,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_update_steps * num_processes,
        num_training_steps=total_num_update_steps,
    )

    # accelerate.prepare
    unet, optimizer, train_dataloader, validation_loss_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, validation_loss_dataloader, lr_scheduler
    )

    # load_state
    accelerator.load_state(load_saved_state)

accelerate config:
compute_environment: LOCAL_MACHINE
debug: true
deepspeed_config:
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env:
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false