After some digging I realized that the Trainer is what receives the string `paged_adamw_32bit`, so I checked where it's used. A literal search found nothing, but luckily my search wasn't case-sensitive, so I found something nearly identical: `OptimizerNames.PAGED_ADAMW`. I noticed this path needs the bitsandbytes library by Tim Dettmers/HF, so I realized this optimizer doesn't live in transformers itself (I assume). That's due to this code:
```python
elif args.optim in [
    OptimizerNames.ADAMW_BNB,
    OptimizerNames.ADAMW_8BIT,
    OptimizerNames.PAGED_ADAMW,
    OptimizerNames.PAGED_ADAMW_8BIT,
    OptimizerNames.LION,
    OptimizerNames.LION_8BIT,
    OptimizerNames.PAGED_LION,
    OptimizerNames.PAGED_LION_8BIT,
]:
    try:
        from bitsandbytes.optim import AdamW, Lion

        is_paged = False
        optim_bits = 32
        optimizer_cls = None
        additional_optim_kwargs = adam_kwargs
        if "paged" in args.optim:
            is_paged = True
        if "8bit" in args.optim:
            optim_bits = 8
        if "adam" in args.optim:
            optimizer_cls = AdamW
        elif "lion" in args.optim:
            optimizer_cls = Lion
            additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}

        bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits}
        optimizer_kwargs.update(additional_optim_kwargs)
        optimizer_kwargs.update(bnb_kwargs)
    except ImportError:
        raise ValueError("Trainer tried to instantiate bnb optimizer but bnb is not installed!")
```
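One detail that makes the substring checks above work: `args.optim` can be tested like a plain string because, as far as I can tell, `OptimizerNames` is a string-backed enum. A minimal sketch of the idea (the member values here are my assumption, inferred from the string I passed in):

```python
from enum import Enum

class OptimizerNames(str, Enum):
    # Assumed values: "paged_adamw_32bit" is the string I passed to the
    # Trainer, so presumably it maps to PAGED_ADAMW.
    PAGED_ADAMW = "paged_adamw_32bit"
    PAGED_ADAMW_8BIT = "paged_adamw_8bit"

assert "paged" in OptimizerNames.PAGED_ADAMW      # substring tests work on members
assert "8bit" not in OptimizerNames.PAGED_ADAMW   # so 32-bit stays the default
```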
Getting the optimizer I want, though, is just a static method on `Trainer` that only needs your `TrainingArguments` object. Tracing the code above for `paged_adamw_32bit`: `"paged"` is in the string so `is_paged=True`, `"8bit"` is not so `optim_bits` stays 32, and `"adam"` picks bitsandbytes' `AdamW`. So you can do:
```python
from typing import Any, Tuple
from transformers import Trainer, TrainingArguments

def get_paged_adamw_32bit_manual(args: TrainingArguments) -> Tuple[Any, Any]:
    # Note: this returns the optimizer *class* and its kwargs,
    # not an instantiated (optimizer, scheduler) pair.
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
    return optimizer_cls, optimizer_kwargs
```
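A minimal usage sketch of my own (assumes bitsandbytes is installed; paged optimizers are really meant for CUDA parameters, so the toy model here is just for illustration):

```python
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in; any nn.Module works
args = TrainingArguments(output_dir="out", optim="paged_adamw_32bit")
optimizer_cls, optimizer_kwargs = get_paged_adamw_32bit_manual(args)
optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
print(optimizer)  # bitsandbytes AdamW built with is_paged=True, optim_bits=32
```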