I wasn't able to find any documentation on this, but if I want to use gradient checkpointing with FSDP training (assuming `model.supports_gradient_checkpointing` is `True`), do I need to manually apply the wrapping like so:
Method 1:

```python
from functools import partial
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

model: PreTrainedModel = ...

# Select which submodules get the checkpoint wrapper, e.g. the model's decoder-layer class.
check_fn = lambda submodule: isinstance(submodule, ...)

non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    offload_to_cpu=False,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)

model = accelerator.prepare(model)
...
```
or does `model.gradient_checkpointing_enable()` take care of that, so that instead I can do:

Method 2:
```python
model: PreTrainedModel = ...
model.gradient_checkpointing_enable()
model = accelerator.prepare(model)
...
```
Which is the correct method? And in the case that the model doesn't support gradient checkpointing, can I manually apply gradient checkpointing via method 1?
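To make the second half of that question concrete, this is roughly what I have in mind for a model with no built-in gradient checkpointing support. `MyDecoderLayer` and `MyModel` are just hypothetical stand-ins for whatever block/model classes the real model defines:

```python
from functools import partial

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


class MyDecoderLayer(nn.Module):
    """Hypothetical transformer block of a custom (non-HF) model."""

    def __init__(self, dim: int = 256):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.ff(x)


class MyModel(nn.Module):
    """Hypothetical model that does not expose gradient_checkpointing_enable()."""

    def __init__(self, num_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleList(MyDecoderLayer() for _ in range(num_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


model = MyModel()

# Wrap every MyDecoderLayer in a non-reentrant checkpoint wrapper,
# regardless of whether the model has any built-in checkpointing support.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    ),
    check_fn=lambda m: isinstance(m, MyDecoderLayer),
)
```

Since `apply_activation_checkpointing` only needs plain `nn.Module`s, I'd expect this to work even when `supports_gradient_checkpointing` is `False`, but I'd like to confirm that.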
Another question:
My understanding is that the gradient checkpointing and the FSDP wrapping policy should apply to the same layers, which is fine when using the `TRANSFORMER_BASED_WRAP` policy, but what about when using `SIZE_BASED_WRAP` or `NO_WRAP`? Or is my understanding incorrect, and the checkpointing logic and the FSDP wrapping policy can be independent of one another?
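For reference, this is how I picture the two selections at the raw PyTorch level (outside of accelerate). It reuses the hypothetical `MyDecoderLayer`/`MyModel` from the sketch above, assumes a process group has already been initialized, and uses an arbitrary `min_num_params`; I believe the size-based policy is what `SIZE_BASED_WRAP` maps to, but that's my assumption:

```python
from functools import partial

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

model = MyModel()  # hypothetical model from the sketch above

# Checkpointing selection: driven purely by check_fn (here, per layer class).
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    ),
    check_fn=lambda m: isinstance(m, MyDecoderLayer),
)

# Sharding selection: driven purely by auto_wrap_policy (here, size-based),
# so the two selections do not have to pick the same submodules.
model = FSDP(
    model,
    auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=1_000_000),
)
```

If they really are independent like this, then I'd guess `SIZE_BASED_WRAP` or `NO_WRAP` only changes which modules FSDP shards, without affecting which modules get checkpointed, but I'd appreciate confirmation.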