Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 10:34:41 +00:00)
[setup] support pre-build and jit-build of cuda kernels (#2374)
* [setup] support pre-build and jit-build of cuda kernels
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
```diff
@@ -66,7 +66,8 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype):
     exp_avg_sq = torch.rand(p_data.shape)
     exp_avg_sq_copy = exp_avg_sq.clone()
 
-    from colossalai.kernel import cpu_optim
+    from colossalai.kernel.op_builder import CPUAdamBuilder
+    cpu_optim = CPUAdamBuilder().load()
 
     cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
 
```
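The hunk above replaces a direct import of the pre-compiled `cpu_optim` extension with an op-builder load. A minimal sketch of the new pattern, assuming (per the commit title) that `load()` returns the pre-built extension when one exists and otherwise JIT-compiles the kernel; the hyperparameter values are illustrative, not taken from the test:

```python
import torch
from colossalai.kernel.op_builder import CPUAdamBuilder

# load() resolves the kernel module: pre-built if compiled at install
# time, otherwise built just-in-time on first use (assumption based on
# the commit title "support pre-build and jit-build of cuda kernels")
cpu_optim = CPUAdamBuilder().load()

# constructor arguments mirror the call in the test:
# lr, beta1, beta2, eps, weight_decay, adamw flag (values illustrative)
cpu_adam_op = cpu_optim.CPUAdamOptimizer(1e-3, 0.9, 0.999, 1e-8, 0.0, True)
```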
```diff
@@ -46,7 +46,8 @@ def torch_adam_update(
 @parameterize('p_dtype', [torch.float, torch.half])
 @parameterize('g_dtype', [torch.float, torch.half])
 def test_adam(adamw, step, p_dtype, g_dtype):
-    from colossalai.kernel import fused_optim
+    from colossalai.kernel.op_builder import FusedOptimBuilder
+    fused_optim = FusedOptimBuilder().load()
     fused_adam = fused_optim.multi_tensor_adam
 
     dummy_overflow_buf = torch.cuda.IntTensor([0])
```