mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[builder] unified cpu_optim fused_optim inferface (#2190)
This commit is contained in:
@@ -76,13 +76,8 @@ class HybridAdam(NVMeOptimizer):
|
||||
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
|
||||
super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
|
||||
self.adamw_mode = adamw_mode
|
||||
try:
|
||||
from colossalai._C import cpu_optim, fused_optim
|
||||
except ImportError:
|
||||
from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder
|
||||
fused_optim = FusedOptimBuilder().load()
|
||||
cpu_optim = CPUAdamBuilder().load()
|
||||
|
||||
from colossalai.kernel import cpu_optim, fused_optim
|
||||
self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
|
||||
|
||||
self.gpu_adam_op = fused_optim.multi_tensor_adam
|
||||
|
Reference in New Issue
Block a user