Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-07 12:01:39 +00:00
[feat] refactored extension module (#5298)
* [feat] refactored extension module
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
* polish
@@ -65,9 +65,9 @@ class TorchAdamKernel(AdamKernel):
 class FusedAdamKernel(AdamKernel):
     def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
         super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
-        from colossalai.kernel.op_builder import FusedOptimBuilder
+        from colossalai.kernel.kernel_loader import FusedOptimizerLoader
 
-        fused_optim = FusedOptimBuilder().load()
+        fused_optim = FusedOptimizerLoader().load()
         self.fused_adam = fused_optim.multi_tensor_adam
         self.dummy_overflow_buf = torch.cuda.IntTensor([0])
 
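The hunk above swaps the old op_builder entry point (FusedOptimBuilder) for the refactored kernel_loader one (FusedOptimizerLoader). A minimal sketch of the new loading pattern follows, restricted to the calls visible in this hunk; the rest of FusedAdamKernel (including the actual Adam update) is omitted, and a CUDA device plus a buildable fused-optim extension are assumed.

import torch

from colossalai.kernel.kernel_loader import FusedOptimizerLoader

# Load the fused optimizer extension module through the refactored loader,
# exactly as FusedAdamKernel.__init__ does after this commit.
fused_optim = FusedOptimizerLoader().load()

# Handle to the fused multi-tensor Adam kernel exposed by the extension.
fused_adam = fused_optim.multi_tensor_adam

# Overflow-flag buffer the fused kernels expect, as created in FusedAdamKernel.
dummy_overflow_buf = torch.cuda.IntTensor([0])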
@@ -91,7 +91,7 @@ class FusedAdamKernel(AdamKernel):
 class CPUAdamKernel(AdamKernel):
     def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
         super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
-        from colossalai.kernel import CPUAdamLoader
+        from colossalai.kernel.kernel_loader import CPUAdamLoader
 
         cpu_optim = CPUAdamLoader().load()
 
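Likewise, CPUAdamLoader is now imported from colossalai.kernel.kernel_loader rather than from colossalai.kernel directly. A minimal sketch under that assumption, again limited to the calls shown in the hunk (how CPUAdamKernel uses the loaded module afterwards is not part of this diff):

from colossalai.kernel.kernel_loader import CPUAdamLoader

# Load the CPU Adam extension via the refactored loader path,
# mirroring CPUAdamKernel.__init__ after this commit.
cpu_optim = CPUAdamLoader().load()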