[builder] unified cpu_optim fused_optim inferface (#2190)

This commit is contained in:
Jiarui Fang
2022-12-23 20:57:41 +08:00
committed by GitHub
parent 9587b080ba
commit 355ffb386e
9 changed files with 28 additions and 50 deletions

View File

@@ -76,13 +76,8 @@ class HybridAdam(NVMeOptimizer):
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
try:
from colossalai._C import cpu_optim, fused_optim
except ImportError:
from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
cpu_optim = CPUAdamBuilder().load()
from colossalai.kernel import cpu_optim, fused_optim
self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
self.gpu_adam_op = fused_optim.multi_tensor_adam