Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 04:24:47 +00:00)
[kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921)
* [kernel] support pure fp16 for cpu adam (#4896)
* [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919)
* [kernel] fix cpu adam
* [test] update gemini optim test
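"Pure fp16" here means that parameters, gradients, and optimizer states are all kept in torch.half, with no fp32 master copy inside the optimizer. A minimal sketch of the Adam update these tests validate, assuming a recent PyTorch with half-precision CPU ops; the function name and hyperparameters are illustrative, not ColossalAI's API:

import torch

def adam_step_pure_fp16(p, g, exp_avg, exp_avg_sq, step, lr=1e-3,
                        beta1=0.9, beta2=0.999, eps=1e-8):
    # Everything, including the moment buffers, stays in torch.half.
    assert p.dtype == g.dtype == exp_avg.dtype == exp_avg_sq.dtype == torch.half
    exp_avg.mul_(beta1).add_(g, alpha=1 - beta1)            # first moment
    exp_avg_sq.mul_(beta2).addcmul_(g, g, value=1 - beta2)  # second moment
    bias_corr1 = 1 - beta1**step
    bias_corr2 = 1 - beta2**step
    denom = (exp_avg_sq / bias_corr2).sqrt_().add_(eps)
    p.addcdiv_(exp_avg, denom, value=-lr / bias_corr1)

p = torch.randn(1024, dtype=torch.half)
g = torch.randn(1024, dtype=torch.half)
exp_avg = torch.zeros_like(p)
exp_avg_sq = torch.zeros_like(p)
adam_step_pure_fp16(p, g, exp_avg, exp_avg_sq, step=1)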
@@ -13,9 +13,7 @@ from colossalai.utils import get_current_device, multi_tensor_applier
 _FUSED_ALLOWED_P_G_TYPES = [
     (torch.float, torch.half),
     (torch.float, torch.float),
-    (torch.half, torch.float),
     (torch.half, torch.half),
-    (torch.bfloat16, torch.float),
     (torch.float, torch.bfloat16),
     (torch.bfloat16, torch.bfloat16),
 ]
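In these lists the first element of each pair is the parameter dtype and the second is the gradient dtype; the hunk above narrows the fused kernel's tested combinations by dropping the low-precision-parameter / fp32-gradient pairs. A hedged sketch of how such a list typically drives a parametrized test (pytest is assumed; the real test's fixtures and assertions differ):

import pytest
import torch

_ALLOWED_P_G_TYPES = [  # (parameter dtype, gradient dtype)
    (torch.float, torch.half),
    (torch.float, torch.float),
    (torch.half, torch.half),
    (torch.bfloat16, torch.bfloat16),
]

@pytest.mark.parametrize("p_dtype, g_dtype", _ALLOWED_P_G_TYPES)
def test_dtype_pair(p_dtype, g_dtype):
    # Parameters live in p_dtype, gradients in g_dtype; a real test would
    # run the optimizer kernel and compare against an fp32 reference.
    p = torch.zeros(8, dtype=p_dtype)
    g = torch.zeros(8, dtype=g_dtype)
    assert p.dtype == p_dtype and g.dtype == g_dtype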
@@ -23,7 +21,6 @@ _FUSED_ALLOWED_P_G_TYPES = [
 _CPU_ALLOWED_P_G_TYPES = [
     (torch.float, torch.half),
     (torch.float, torch.float),
-    (torch.half, torch.float),
     (torch.half, torch.half),
 ]
 
@@ -138,8 +135,8 @@ def check_adam_kernel(
     master_exp_avg_sq = torch.zeros_like(master_p)
     p = master_p.clone().to(p_dtype)
     g = master_g.clone().to(g_dtype)
-    exp_avg = master_exp_avg.clone()
-    exp_avg_sq = master_exp_avg_sq.clone()
+    exp_avg = master_exp_avg.clone().to(p_dtype)
+    exp_avg_sq = master_exp_avg_sq.clone().to(p_dtype)
 
     for step in range(1, 1 + n_steps):
         torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)
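The substantive change in this hunk is that the kernel-side moment buffers (exp_avg, exp_avg_sq) are now created in p_dtype rather than always in fp32, so the pure-fp16 path exercises half-precision optimizer states while the fp32 master copies still drive the reference torch_adam update. After the loop, a test like this typically compares the low-precision results against the master copies with dtype-dependent tolerances; a minimal sketch (the tolerance values are illustrative, not the test's actual numbers):

import torch

def assert_close_by_dtype(actual, master, dtype):
    # Looser tolerances for half/bfloat16 results than for fp32.
    tols = {
        torch.float: dict(rtol=1e-5, atol=1e-8),
        torch.half: dict(rtol=1e-3, atol=1e-3),
        torch.bfloat16: dict(rtol=4e-3, atol=4e-3),
    }[dtype]
    torch.testing.assert_close(actual.float(), master.float(), **tols)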