Gradient checkpointing + FSDP

I wasn’t able to find any documentation on this, but if I want to use gradient checkpointing with FSDP training (assuming model.supports_gradient_checkpointing is True), do I need to manually apply the checkpoint wrapping like so:

method 1

from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

model: PreTrainedModel = ...

# only wrap the submodules we care about, e.g. the transformer block class
check_fn = lambda submodule: isinstance(submodule, ...)

non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    offload_to_cpu=False,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)

model = accelerator.prepare(model)
...
...

Or does model.gradient_checkpointing_enable() take care of that, so that instead I can do:

method 2

model: PreTrainedModel = ...
model.gradient_checkpointing_enable()
model = accelerator.prepare(model)
...

Which is the correct method? And in the case that the model doesn’t support gradient checkpointing, can I manually apply gradient checkpointing via method 1?

Another question:
The gradient checkpointing and FSDP wrapping policy should apply to the same layers, which is fine if using the TRANSFORMER_BASED_WRAP policy, but what about using SIZE_BASED_WRAP or NO_WRAP?

Or is my understanding incorrect, and the checkpointing logic and the FSDP wrapping policy can be independent of one another?
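
For concreteness, here is a rough sketch of what I mean, using plain PyTorch FSDP rather than accelerate (LlamaDecoderLayer is just a stand-in for whatever block class a model uses, and it assumes the process group is already initialized and the model is loaded): the auto_wrap_policy and the checkpointing check_fn are passed as separate arguments, so at the API level they look independent, even if pointing them at the same layers seems to be the usual recommendation.

# Sketch only: assumes torch.distributed.init_process_group() has already run
# (e.g. via torchrun) and `model` is a loaded Llama-style model.
from functools import partial

import torch
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# FSDP wrapping policy: either transformer-based ...
transformer_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)
# ... or size-based, which knows nothing about layer classes.
size_policy = partial(size_based_auto_wrap_policy, min_num_params=int(1e8))

model = FSDP(
    model,
    auto_wrap_policy=size_policy,
    device_id=torch.cuda.current_device(),
)

# Activation checkpointing is applied with its own check_fn,
# regardless of which wrap policy was used above.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    ),
    check_fn=lambda m: isinstance(m, LlamaDecoderLayer),
)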

I applied method 2:

model: PreTrainedModel = ...
model.gradient_checkpointing_enable()
model = accelerator.prepare(model)