mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
[builder] runtime adam and fused_optim builder (#2184)
This commit is contained in:
@@ -69,8 +69,12 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype):
|
||||
try:
|
||||
import colossalai._C.cpu_optim
|
||||
cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
|
||||
print("use prebuilt CPUAdamOptimizer")
|
||||
except:
|
||||
raise ImportError("Import cpu adam error, please install colossal from source code")
|
||||
from colossalai.kernel.op_builder.cpu_adam import CPUAdamBuilder
|
||||
lib = CPUAdamBuilder().load()
|
||||
cpu_adam_op = lib.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
|
||||
print("build CPUAdamOptimizer at runtime")
|
||||
|
||||
cpu_adam_op.step(
|
||||
step,
|
||||
@@ -115,3 +119,7 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype):
|
||||
assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
|
||||
max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
|
||||
assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_cpu_adam()
|
||||
|
Reference in New Issue
Block a user