Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 09:07:51 +00:00)
[Feature] auto-cast optimizers to distributed version (#5746)
* auto-cast optimizers to distributed
* fix galore casting
* logger

Co-authored-by: Edenzzzz <wtan45@wisc.edu>
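As the title describes, the feature automatically swaps supported optimizers for their distributed versions (presumably when a parallel plugin configures them) rather than requiring users to change optimizer classes by hand. A minimal sketch of that idea, assuming a simple class-to-class cast table; the helper name, the table argument, and the rebuild via param_groups/defaults are illustrative assumptions, not the repository's actual implementation:

    # Illustrative sketch only; not the actual ColossalAI auto-cast code.
    from torch.optim import Optimizer

    def cast_to_distributed(optim: Optimizer, cast_table: dict) -> Optimizer:
        """Return the registered distributed counterpart of `optim`, or `optim` unchanged."""
        dist_cls = cast_table.get(type(optim))
        if dist_cls is None:
            return optim  # no distributed version registered for this optimizer
        # Rebuild with the original param groups so per-group hyperparameters carry over;
        # assumes the distributed class accepts the same constructor keywords as the original.
        return dist_cls(optim.param_groups, **optim.defaults)

A plugin could, for example, consult such a table to map a GaLore AdamW optimizer to the DistGaloreAwamW class touched by the diff below.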
@@ -43,12 +43,13 @@ class DistGaloreAwamW(DistributedOptim, Optimizer2State):
             Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
         is_paged (`bool`, defaults to `False`):
             Whether the optimizer is a paged optimizer (handle memory spike via CPU-GPU transfer) or not.
+        args (dict, optional): quantization-related arguments. If passed, will override all quantization args above.
     """

     def __init__(
         self,
         params,
-        lr=1e-3,
+        lr=1e-2,
         betas=(0.9, 0.999),
         eps=1e-8,
         weight_decay=1e-2,
@@ -57,6 +58,7 @@ class DistGaloreAwamW(DistributedOptim, Optimizer2State):
         percentile_clipping=100,
         block_wise=True,
         is_paged=False,
+        args=None,
     ):
         super().__init__(
             "adam",
@@ -65,13 +67,14 @@ class DistGaloreAwamW(DistributedOptim, Optimizer2State):
             betas,
             eps,
             weight_decay,
-            nbits,
-            None,
-            min_8bit_size,
-            percentile_clipping,
-            block_wise,
+            optim_bits=nbits,
+            args=args,
+            min_8bit_size=min_8bit_size,
+            percentile_clipping=percentile_clipping,
+            block_wise=block_wise,
             is_paged=is_paged,
         )

         self.tp_size = 1
         self.dp_size = 1
         self.is_dist = {}
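For context, a hedged usage sketch of the constructor after this change, based only on the signature visible in the hunks above; the import path and the placeholder model are assumptions, and GaLore-specific parameter-group setup is omitted:

    # Usage sketch; assumes bitsandbytes is installed and omits GaLore param-group setup.
    import torch
    from colossalai.nn.optimizer import DistGaloreAwamW  # assumed import path

    model = torch.nn.Linear(512, 512)
    optim = DistGaloreAwamW(
        model.parameters(),
        lr=1e-2,           # new default after this commit (previously 1e-3)
        weight_decay=1e-2,
        block_wise=True,   # block-wise quantization, as described in the docstring above
        is_paged=False,
        args=None,         # newly exposed; overrides the quantization settings above when set
    )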