Mirror of https://github.com/hpcaitech/ColossalAI.git
[kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921)
* [kernel] support pure fp16 for cpu adam (#4896)
* [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919)
* [kernel] fix cpu adam
* [test] update gemini optim test
@@ -146,8 +146,7 @@ class CPUAdam(NVMeOptimizer):
                         assert state["exp_avg"].device.type == "cpu", "exp_avg should stay on cpu"
                         assert state["exp_avg_sq"].device.type == "cpu", "exp_avg should stay on cpu"
                         self._pre_update(p, "exp_avg", "exp_avg_sq")
-                        # FIXME(ver217): CPU adam kernel only supports fp32 states now
-                        if p.grad.dtype is torch.bfloat16 or p.dtype is not torch.float:
+                        if p.grad.dtype is torch.bfloat16:
                             # cpu adam kernel does not support bf16 now
                             bias_correction1 = 1 - beta1 ** state["step"]
                             bias_correction2 = 1 - beta2 ** state["step"]
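For readers outside the codebase: the hunk narrows the fallback condition so that only bf16 gradients take the plain PyTorch path, while fp16 params and grads (like fp32) can now be handed to the fused CPU Adam kernel directly. Below is a minimal runnable sketch of that dispatch; the names `cpu_adam_step` and `fused_kernel` are hypothetical stand-ins for ColossalAI's actual CPUAdam internals and C++ binding, which this sketch does not reproduce.

    import torch

    def cpu_adam_step(p, state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
                      fused_kernel=None):
        # Hypothetical dispatch mirroring the hunk above: bf16 grads take the
        # plain PyTorch fallback; fp16/fp32 would go to the fused CPU kernel.
        # `fused_kernel` stands in for the real C++ binding; it defaults to
        # None here, so this sketch always exercises the fallback branch.
        state["step"] += 1
        if p.grad.dtype is torch.bfloat16 or fused_kernel is None:
            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
            grad = p.grad.float()  # moments are kept as fp32 tensors on cpu
            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
            bias_correction1 = 1 - beta1 ** state["step"]
            bias_correction2 = 1 - beta2 ** state["step"]
            denom = (exp_avg_sq / bias_correction2).sqrt_().add_(eps)
            update = exp_avg / denom * (lr / bias_correction1)
            p.data.sub_(update.to(p.dtype))
        else:
            # After this commit, fp16 tensors can be passed to the fused kernel
            # directly instead of being rejected by the old dtype check.
            fused_kernel(state["step"], lr, beta1, beta2, eps,
                         p.data, p.grad.data,
                         state["exp_avg"], state["exp_avg_sq"])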
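And a tiny usage check of the sketch above, exercising the bf16 branch that the retained comment ("cpu adam kernel does not support bf16 now") refers to; again illustrative only, reusing `cpu_adam_step` from the previous block.

    import torch  # uses cpu_adam_step from the sketch above

    p = torch.nn.Parameter(torch.randn(8, dtype=torch.bfloat16))
    p.grad = torch.randn(8, dtype=torch.bfloat16)
    state = {
        "step": 0,
        "exp_avg": torch.zeros(8),     # fp32 moments, on cpu, as the asserts require
        "exp_avg_sq": torch.zeros(8),  # fp32 moments, on cpu
    }
    cpu_adam_step(p, state)  # bf16 grad -> plain PyTorch fallback branch
    assert state["step"] == 1 and state["exp_avg"].dtype is torch.float32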