[builder] unified cpu_optim fused_optim inferface (#2190)

This commit is contained in:
Jiarui Fang
2022-12-23 20:57:41 +08:00
committed by GitHub
parent 9587b080ba
commit 355ffb386e
9 changed files with 28 additions and 50 deletions

View File

@@ -46,13 +46,8 @@ def torch_adam_update(
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, step, p_dtype, g_dtype):
try:
import colossalai._C.fused_optim
fused_adam = colossalai._C.fused_optim.multi_tensor_adam
except:
from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
fused_adam = fused_optim.multi_tensor_adam
from colossalai.kernel import fused_optim
fused_adam = fused_optim.multi_tensor_adam
dummy_overflow_buf = torch.cuda.IntTensor([0])