[feat] refactored extension module (#5298)

* [feat] refactored extension module

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
This commit is contained in:
Frank Lee
2024-01-25 17:01:48 +08:00
committed by GitHub
parent d7f8db8e21
commit 7cfed5f076
157 changed files with 1353 additions and 8966 deletions

View File

@@ -2,7 +2,7 @@ from typing import Any, Optional
import torch
-from colossalai.kernel.op_builder import FusedOptimBuilder
+from colossalai.kernel.kernel_loader import FusedOptimizerLoader
from colossalai.utils import multi_tensor_applier
from .cpu_adam import CPUAdam
@@ -85,7 +85,7 @@ class HybridAdam(CPUAdam):
nvme_offload_dir,
)
if torch.cuda.is_available():
-            fused_optim = FusedOptimBuilder().load()
+            fused_optim = FusedOptimizerLoader().load()
self.gpu_adam_op = fused_optim.multi_tensor_adam
self._dummy_overflow_buf = torch.cuda.IntTensor([0])