[kernel] move all symlinks of kernel to colossalai._C (#1971)

This commit is contained in:
ver217
2022-11-17 13:42:33 +08:00
committed by GitHub
parent 7e24b9b9ee
commit f8a7148dec
27 changed files with 463 additions and 322 deletions

View File

@@ -1,9 +1,11 @@
import math
from typing import Optional
import torch
from colossalai.registry import OPTIMIZERS
from .nvme_optimizer import NVMeOptimizer
from typing import Optional
@OPTIMIZERS.register_module
@@ -11,7 +13,7 @@ class CPUAdam(NVMeOptimizer):
"""Implements Adam algorithm.
Supports parameters updating on both GPU and CPU, depanding on the device of paramters.
But the parameters and gradients should on the same device:
But the parameters and gradients should on the same device:
* Parameters on CPU and gradients on CPU is allowed.
* Parameters on GPU and gradients on GPU is allowed.
* Parameters on GPU and gradients on CPU is **not** allowed.
@@ -44,7 +46,7 @@ class CPUAdam(NVMeOptimizer):
(default: False) NOT SUPPORTED yet in CPUAdam!
adamw_mode (boolean, optional): Apply L2 regularization or weight decay
True for decoupled weight decay(also known as AdamW) (default: True)
simd_log (boolean, optional): whether to show if you are using SIMD to
simd_log (boolean, optional): whether to show if you are using SIMD to
accelerate. (default: False)
nvme_offload_fraction (float, optional): Fraction of optimizer states to be offloaded to NVMe. Defaults to 0.0.
nvme_offload_dir (Optional[str], optional): Directory to save NVMe offload files.
@@ -75,10 +77,11 @@ class CPUAdam(NVMeOptimizer):
super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
try:
import cpu_adam
import colossalai._C.cpu_optim
except ImportError:
raise ImportError('Please install colossalai from source code to use CPUAdam')
self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
self.cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay,
adamw_mode)
def torch_adam_update(self,
data,