[gemini] support amp o3 for gemini (#4872)

* [gemini] support no reuse fp16 chunk

* [gemini] support no master weight for optim

* [gemini] support no master weight for gemini ddp

* [test] update gemini tests

* [test] update gemini tests

* [plugin] update gemini plugin

* [test] fix gemini checkpointio test

* [test] fix gemini checkpoint io
Author: Hongxin Liu
Date: 2023-10-12 10:39:08 +08:00
Committed by: GitHub
Parent: c1fab951e7
Commit: df63564184
15 changed files with 222 additions and 114 deletions
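
In effect, the change lets Gemini run an AMP-O3-style setup in which the optimizer keeps no fp32 master copy of the weights and works directly on the half-precision parameters. A minimal usage sketch, assuming the commit exposes this as a `master_weights` flag on `GeminiPlugin` (the exact argument name and defaults are assumptions based on the commit message, not taken from this diff):

```python
# Minimal sketch (assumption): Gemini without fp32 master weights, i.e. an
# AMP "O3"-style run where params, grads and optimizer states stay in fp16.
# `master_weights` is assumed from the commit message; check the GeminiPlugin
# signature of your ColossalAI version before relying on it.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(1024, 1024)
optimizer = HybridAdam(model.parameters(), lr=1e-3)

plugin = GeminiPlugin(
    precision="fp16",      # keep working params in half precision
    master_weights=False,  # assumed flag: skip the fp32 master copy in the optimizer
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)
```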


@@ -132,9 +132,6 @@ class CPUAdam(NVMeOptimizer):
                 target_device = p.device
                 if len(state) == 0:
                     state["step"] = 0
-                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
-                    assert p.dtype is torch.float, "CPUAdam only support fp32 parameters"
                     # gradient momentums
                     state["exp_avg"] = torch.zeros_like(p, device=target_device)
                     # gradient variances
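
With the fp32-only assertion removed, half-precision parameters can now reach the state-initialization branch above; since `torch.zeros_like` inherits the parameter's dtype, the momentum and variance buffers are created in that same dtype. A standalone sketch of that behaviour (`init_adam_state` is a hypothetical helper for illustration, not part of the file):

```python
# Sketch: after the assert is dropped, a half-precision parameter gets
# half-precision exp_avg / exp_avg_sq instead of tripping an fp32-only assertion.
import torch

def init_adam_state(p: torch.Tensor) -> dict:
    state = {"step": 0}
    state["exp_avg"] = torch.zeros_like(p, device=p.device)     # gradient momentum
    state["exp_avg_sq"] = torch.zeros_like(p, device=p.device)  # gradient variance
    return state

p_fp16 = torch.zeros(4, dtype=torch.float16)
state = init_adam_state(p_fp16)
assert state["exp_avg"].dtype is torch.float16  # states follow the param dtype
```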
@@ -149,7 +146,8 @@ class CPUAdam(NVMeOptimizer):
                 assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
                 assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
                 self._pre_update(p, "exp_avg", "exp_avg_sq")
-                if p.grad.dtype is torch.bfloat16:
+                # FIXME(ver217): CPU adam kernel only supports fp32 states now
+                if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
                     # cpu adam kernel does not support bf16 now
                     bias_correction1 = 1 - beta1 ** state["step"]
                     bias_correction2 = 1 - beta2 ** state["step"]
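
The widened condition routes any non-fp32 parameter (as well as any bf16 gradient) away from the fused CPU kernel and onto a plain PyTorch update path, which is why the bias-correction terms are computed here. A rough sketch of what that fallback computes, using the standard bias-corrected Adam step (names below are illustrative, not the file's actual helpers):

```python
# Illustrative bias-corrected Adam update on plain torch ops; works for fp16/bf16
# tensors since no fused kernel is involved.
import math
import torch

def torch_adam_update(p, grad, exp_avg, exp_avg_sq, step, lr, beta1, beta2, eps, weight_decay):
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    if weight_decay != 0:
        grad = grad + p * weight_decay                             # L2-style weight decay
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)                # first moment
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)   # second moment
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
    p.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)
```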