Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-11 05:49:55 +00:00
[gemini] support amp o3 for gemini (#4872)
* [gemini] support no reuse fp16 chunk
* [gemini] support no master weight for optim
* [gemini] support no master weight for gemini ddp
* [test] update gemini tests
* [test] update gemini tests
* [plugin] update gemini plugin
* [test] fix gemini checkpointio test
* [test] fix gemini checkpoint io
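For context on the "no master weight" items in the commit message above: a classic mixed-precision optimizer keeps an fp32 master copy of each low-precision parameter and applies updates to that copy, while an AMP O3-style setup keeps no fp32 copy and updates the fp16/bf16 parameter directly. The sketch below is only an illustration of that distinction; both function names are made up for this example and are not ColossalAI/Gemini APIs.

import torch

def step_with_master_weight(p_lp: torch.Tensor, master_fp32: torch.Tensor, update_fp32: torch.Tensor) -> None:
    # Classic mixed precision: apply the update to the fp32 master copy,
    # then cast the result back into the low-precision working parameter.
    master_fp32.add_(-update_fp32)
    p_lp.copy_(master_fp32.to(p_lp.dtype))

def step_without_master_weight(p_lp: torch.Tensor, update_fp32: torch.Tensor) -> None:
    # "No master weight" (AMP O3-style): no fp32 copy is kept at all;
    # the update is cast down and applied to the fp16/bf16 parameter directly.
    p_lp.add_(-update_fp32.to(p_lp.dtype))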
@@ -132,9 +132,6 @@ class CPUAdam(NVMeOptimizer):
                 target_device = p.device
                 if len(state) == 0:
                     state["step"] = 0
-
-                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
-                    assert p.dtype is torch.float, "CPUAdam only support fp32 parameters"
                     # gradient momentums
                     state["exp_avg"] = torch.zeros_like(p, device=target_device)
                     # gradient variances
@@ -149,7 +146,8 @@ class CPUAdam(NVMeOptimizer):
                     assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
                     assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
                     self._pre_update(p, "exp_avg", "exp_avg_sq")
-                    if p.grad.dtype is torch.bfloat16:
+                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
+                    if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
                         # cpu adam kernel does not support bf16 now
                         bias_correction1 = 1 - beta1 ** state["step"]
                         bias_correction2 = 1 - beta2 ** state["step"]
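The second hunk replaces the hard fp32-only assert with a runtime dispatch: parameters or gradients the fused CPU kernel cannot handle (bf16 gradients, non-fp32 parameters) fall back to a plain-PyTorch Adam update instead of being rejected. Below is a minimal, self-contained sketch of that dispatch for a single parameter; the function name, the fp32 exp_avg/exp_avg_sq states, and the placeholder fused-kernel branch are assumptions for illustration, not the actual CPUAdam implementation.

import math
import torch

def adam_step_sketch(p: torch.Tensor, state: dict, lr: float = 1e-3,
                     beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8) -> None:
    """Sketch of the dtype dispatch for one parameter: bf16 gradients or
    non-fp32 parameters take a plain-PyTorch Adam path; everything else
    would go to the fused fp32 CPU kernel. States are assumed fp32 here."""
    state["step"] += 1
    if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
        # Fallback: bias-corrected Adam computed in fp32, update cast back at the end.
        grad = p.grad.float()
        exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
        bias_correction1 = 1 - beta1 ** state["step"]
        bias_correction2 = 1 - beta2 ** state["step"]
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        update = exp_avg / denom * (lr / bias_correction1)
        p.data.add_(-update.to(p.dtype))
    else:
        # Fast path: in the real optimizer this hands everything to the fused
        # fp32 CPU kernel; left as a placeholder in this sketch.
        raise NotImplementedError("fused fp32 CPU Adam kernel goes here")

# Example: a bf16 parameter takes the fallback branch.
p = torch.nn.Parameter(torch.randn(4, dtype=torch.bfloat16))
p.grad = torch.randn(4, dtype=torch.bfloat16)
state = {"step": 0, "exp_avg": torch.zeros(4), "exp_avg_sq": torch.zeros(4)}
adam_step_sketch(p, state)

Doing the fallback math in fp32 and only casting the final update back to the parameter dtype keeps the momentum statistics well-conditioned even when the parameter itself carries no fp32 master copy.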