[Feature] auto-cast optimizers to distributed version (#5746)

* auto-cast optimizers to their distributed versions

* fix GaLore casting

* logger

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Edenzzzz
2024-05-24 17:24:16 +08:00
committed by GitHub
parent 2fc85abf43
commit 5f8c0a0ac3
13 changed files with 61 additions and 31 deletions
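
The core of the feature is casting a user-constructed single-device optimizer to its distributed counterpart when a distributed setup requires it. Below is a minimal sketch of that idea, assuming a hypothetical registry that maps optimizer classes to distributed variants; the class and function names are illustrative only and are not ColossalAI's actual API.

import logging

logger = logging.getLogger(__name__)


class PlainAdamW:
    """Stand-in for a single-device optimizer (illustrative only)."""

    def __init__(self, params, lr=1e-3):
        self.params = list(params)
        self.lr = lr


class DistributedAdamW(PlainAdamW):
    """Stand-in for the distributed counterpart (illustrative only)."""


# Hypothetical registry: single-device class -> distributed class.
_DISTRIBUTED_VARIANTS = {PlainAdamW: DistributedAdamW}


def cast_to_distributed(optimizer):
    """Return a distributed equivalent of ``optimizer`` if one is registered."""
    dist_cls = _DISTRIBUTED_VARIANTS.get(type(optimizer))
    if dist_cls is None:
        return optimizer  # already distributed, or no variant registered
    logger.warning("Casting %s to %s", type(optimizer).__name__, dist_cls.__name__)
    # Reuse the existing hyperparameters and state instead of re-running __init__.
    new_opt = dist_cls.__new__(dist_cls)
    new_opt.__dict__.update(optimizer.__dict__)
    return new_opt


if __name__ == "__main__":
    opt = cast_to_distributed(PlainAdamW(params=[], lr=1e-3))
    assert isinstance(opt, DistributedAdamW) and opt.lr == 1e-3

The registry approach keeps the cast a no-op for optimizers that have no distributed variant, and the warning makes the implicit replacement visible to the user.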


@@ -184,6 +184,7 @@ class GaLoreAdamW8bit(Optimizer2State):
             Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
         is_paged (`bool`, defaults to `False`):
             Whether the optimizer is a paged optimizer (handle memory spike via CPU-GPU transfer) or not.
+        args (dict, optional): quantization-related arguments. If passed, will override all quantization args above.
     Example:
     """
@@ -200,6 +201,7 @@ class GaLoreAdamW8bit(Optimizer2State):
         percentile_clipping=100,
         block_wise=True,
         is_paged=False,
+        args=None,
     ):
         super().__init__(
             "adam",
@@ -208,11 +210,11 @@ class GaLoreAdamW8bit(Optimizer2State):
             betas,
             eps,
             weight_decay,
-            nbits,
-            None,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
+            optim_bits=nbits,
+            args=args,
+            min_8bit_size=min_8bit_size,
+            percentile_clipping=percentile_clipping,
+            block_wise=block_wise,
             is_paged=is_paged,
         )
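
With `args` now accepted and forwarded, and the remaining `Optimizer2State` arguments passed by keyword instead of by position, the call no longer depends on the exact positional order of bitsandbytes' `Optimizer2State.__init__`, and `nbits` is explicitly bound to `optim_bits`. A hedged usage sketch follows; the import path, the defaults, and the GaLore-specific param-group keys depend on the installed ColossalAI and bitsandbytes versions.

import torch

from colossalai.nn.optimizer import GaLoreAdamW8bit  # import path is an assumption

model = torch.nn.Linear(256, 256)
optimizer = GaLoreAdamW8bit(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-2,
    nbits=8,    # forwarded to Optimizer2State as optim_bits
    args=None,  # a bitsandbytes-style config here would override the quantization kwargs above
)

# The usual loop applies; the 8-bit optimizer states live on GPU, so step() requires CUDA.
loss = model(torch.randn(4, 256)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()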