mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[setup] support pre-build and jit-build of cuda kernels (#2374)
* [setup] support pre-build and jit-build of cuda kernels * polish code * polish code * polish code * polish code * polish code * polish code
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.kernel.op_builder import CPUAdamBuilder
|
||||
from colossalai.registry import OPTIMIZERS
|
||||
|
||||
from .nvme_optimizer import NVMeOptimizer
|
||||
@@ -76,12 +77,8 @@ class CPUAdam(NVMeOptimizer):
|
||||
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
|
||||
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
|
||||
self.adamw_mode = adamw_mode
|
||||
try:
|
||||
import colossalai._C.cpu_optim
|
||||
except ImportError:
|
||||
raise ImportError('Please install colossalai from source code to use CPUAdam')
|
||||
self.cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay,
|
||||
adamw_mode)
|
||||
cpu_adam = CPUAdamBuilder().load()
|
||||
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
|
||||
|
||||
def torch_adam_update(self,
|
||||
data,
|
||||
|
Reference in New Issue
Block a user