[Feature] auto-cast optimizers to distributed version (#5746)

* auto-cast optimizers to distributed

* fix galore casting

* logger

---------

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Authored by Edenzzzz on 2024-05-24 17:24:16 +08:00; committed by GitHub
commit 5f8c0a0ac3, parent 2fc85abf43
13 changed files with 61 additions and 31 deletions
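
For context, the "auto-cast" in the title means rebuilding a supported plain optimizer as its distributed counterpart when distributed training is set up, and logging the switch instead of doing it silently (hence the "logger" note in the message). Below is a minimal sketch of that dispatch pattern; the helper name, the registry argument, and the class mapping are illustrative assumptions, not the exact code added by this commit.

import logging

logger = logging.getLogger(__name__)


def cast_to_distributed_sketch(optimizer, dist_registry):
    """Return a distributed version of `optimizer` if one is registered, else the original.

    `dist_registry` maps plain optimizer class names to distributed optimizer classes
    (an assumed interface for this sketch).
    """
    dist_cls = dist_registry.get(type(optimizer).__name__)
    if dist_cls is None:
        return optimizer  # no distributed variant registered: keep the original optimizer
    logger.info("Casting %s to %s", type(optimizer).__name__, dist_cls.__name__)
    # Rebuild with the same parameter groups and hyper-parameter defaults
    # (the real code would also carry over any extra constructor arguments/state).
    return dist_cls(optimizer.param_groups, **optimizer.defaults)


# Example wiring; DistGaloreAwamW is the class shown in the diff below, the rest of the
# registry contents are up to the caller:
# dist_registry = {"GaLoreAdamW8bit": DistGaloreAwamW}
# optimizer = cast_to_distributed_sketch(optimizer, dist_registry)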


@@ -43,12 +43,13 @@ class DistGaloreAwamW(DistributedOptim, Optimizer2State):
             Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
         is_paged (`bool`, defaults to `False`):
             Whether the optimizer is a paged optimizer (handle memory spike via CPU-GPU transfer) or not.
+        args (dict, optional): quantization-related arguments. If passed, will override all quantization args above.
     """
 
     def __init__(
         self,
         params,
-        lr=1e-3,
+        lr=1e-2,
         betas=(0.9, 0.999),
         eps=1e-8,
         weight_decay=1e-2,
@@ -57,6 +58,7 @@ class DistGaloreAwamW(DistributedOptim, Optimizer2State):
         percentile_clipping=100,
         block_wise=True,
         is_paged=False,
+        args=None,
     ):
         super().__init__(
             "adam",
@@ -65,13 +67,14 @@ class DistGaloreAwamW(DistributedOptim, Optimizer2State):
             betas,
             eps,
             weight_decay,
-            nbits,
-            None,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
+            optim_bits=nbits,
+            args=args,
+            min_8bit_size=min_8bit_size,
+            percentile_clipping=percentile_clipping,
+            block_wise=block_wise,
             is_paged=is_paged,
         )
         self.tp_size = 1
         self.dp_size = 1
         self.is_dist = {}
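
A hedged construction sketch for the updated signature, using only parameters visible in the hunk above; the import path is an assumption, and a real distributed run would still need the optimizer's tensor/data-parallel setup, which this hunk does not show.

import torch

from colossalai.nn.optimizer import DistGaloreAwamW  # import path assumed

model = torch.nn.Linear(32, 32)
optimizer = DistGaloreAwamW(
    model.parameters(),
    lr=1e-2,          # new default introduced in this commit (was 1e-3)
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    block_wise=True,  # now forwarded to Optimizer2State by keyword rather than positionally
    is_paged=False,
    args=None,        # optional dict; if passed, overrides the quantization args above
)

Forwarding these settings by keyword removes the dependence on bitsandbytes' positional argument order; the previous call relied on that order and passed None positionally in the `args` slot.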