[feat] refactored extension module (#5298)

* [feat] refactored extension module

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish

* polish
This commit is contained in:
Frank Lee
2024-01-25 17:01:48 +08:00
committed by GitHub
parent d7f8db8e21
commit 7cfed5f076
157 changed files with 1353 additions and 8966 deletions

View File

@@ -65,9 +65,9 @@ class TorchAdamKernel(AdamKernel):
class FusedAdamKernel(AdamKernel):
def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.kernel.kernel_loader import FusedOptimizerLoader
fused_optim = FusedOptimBuilder().load()
fused_optim = FusedOptimizerLoader().load()
self.fused_adam = fused_optim.multi_tensor_adam
self.dummy_overflow_buf = torch.cuda.IntTensor([0])
@@ -91,7 +91,7 @@ class FusedAdamKernel(AdamKernel):
class CPUAdamKernel(AdamKernel):
def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
from colossalai.kernel import CPUAdamLoader
from colossalai.kernel.kernel_loader import CPUAdamLoader
cpu_optim = CPUAdamLoader().load()