Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-13 05:01:44 +00:00
[setup] support pre-build and jit-build of cuda kernels (#2374)
* [setup] support pre-build and jit-build of cuda kernels
* polish code
* polish code
* polish code
* polish code
* polish code
* polish code
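Every file touched below follows the same pattern: prefer a kernel that was pre-built into the colossalai._C extension package at install time, and fall back to JIT-building it through the matching op builder when the pre-built binary is missing. A minimal sketch of that fallback, using names that appear in the hunks below (it assumes colossalai is installed; it is a sketch, not code from the commit):

# Hedged sketch of the pre-build / JIT-build fallback introduced by this commit.
try:
    # Fast path: the extension was compiled ahead of time into colossalai._C.
    from colossalai._C import fused_optim
except ImportError:
    # Slow path: compile the CUDA kernel just-in-time on first use.
    from colossalai.kernel.op_builder import FusedOptimBuilder
    fused_optim = FusedOptimBuilder().load()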
colossalai/_C/__init__.py — 0 lines (normal file)
@@ -1,9 +0,0 @@
from . import (
    cpu_optim,
    fused_optim,
    layer_norm,
    moe,
    multihead_attention,
    scaled_masked_softmax,
    scaled_upper_triang_masked_softmax,
)
@@ -1,8 +0,0 @@
from torch import Tensor


class CPUAdamOptimizer:
    def __init__(self, lr: float, beta1: float, beta2: float, eps: float,
                 weight_decay: float, adamw_mode: float) -> None: ...

    def step(self, step: int, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, bias_correction: bool,
             param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor, loss_scale: float) -> None: ...
@@ -1,23 +0,0 @@
from typing import List

from torch import Tensor

def multi_tensor_scale(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], scale: float) -> None:
    ...


def multi_tensor_sgd(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], weight_decay: float,
                     momentum: float, dampening: float, lr: float, nesterov: bool, first_run: bool, weight_decay_after_momentum: bool, scale: float) -> None:
    ...


def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, mode: int, bias_correction: int, weight_decay: float, div_scale: float) -> None:
    ...


def multi_tensor_lamb(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, bias_correction: int, weight_decay: float, grad_averaging: int, mode: int, global_grad_norm: Tensor, max_grad_norm: float, use_nvlamb_python: bool) -> None:
    ...


def multi_tensor_l2norm(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], per_tensor_python: bool) -> None:
    ...
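These deleted stubs documented the fused multi-tensor ops. A hedged usage sketch of multi_tensor_scale driven through multi_tensor_applier, mirroring the _multi_tensor_copy_this_to_that helper further down in this diff (it assumes colossalai is installed and a CUDA device is available; it is not code from the commit):

# Scaling by 1.0 copies one list of tensors into another, as the fp16 optimizer hunk below notes.
import torch
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.utils import multi_tensor_applier

fused_optim = FusedOptimBuilder().load()    # JIT-builds the extension if no pre-built binary exists

this = [torch.randn(1024, device='cuda') for _ in range(4)]    # source tensors
that = [torch.empty_like(t) for t in this]                     # destination tensors
overflow_buf = torch.cuda.IntTensor([0])

multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0)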
@@ -1,11 +0,0 @@
from typing import List

from torch import Tensor

def forward_affine(input: Tensor, normalized_shape: List[int], gamma: Tensor, beta: Tensor, epsilon: float) -> List[Tensor]:
    ...


def backward_affine(dout: Tensor, mean: Tensor, invvar: Tensor, input: Tensor,
                    normalized_shape: List[int], gamma: Tensor, beta: Tensor, epsilon: float) -> List[Tensor]:
    ...
@@ -1,20 +0,0 @@
from torch import Tensor

def cumsum_sub_one(mask: Tensor) -> Tensor:
    ...


def dispatch_forward(s: int, ec: int, h: int, batch_tokens: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
    ...


def dispatch_backward(s: int, ec: int, h: int, expert_grad: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
    ...


def combine_forward(s: int, e: int, c: int, h: int, expert_tokens: Tensor, logits: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
    ...


def combine_backward(s: int, e: int, c: int, h: int, tokens_grad: Tensor, expert_tokens: Tensor, logits: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
    ...
@@ -1,55 +0,0 @@
from typing import List

from torch import Tensor
from torch.distributed import ProcessGroup

def multihead_attention_fw_fp32(layer_id: int, input: Tensor, input_mask: Tensor,
                                in_proj_weight: Tensor, in_proj_bias: Tensor,
                                out_proj_weight: Tensor, out_proj_bias: Tensor,
                                norm_weight: Tensor, norm_bias: Tensor,
                                training_mode: bool, prelayernorm: bool) -> List[Tensor]:
    ...


def multihead_attention_fw_fp16(layer_id: int, input: Tensor, input_mask: Tensor,
                                in_proj_weight: Tensor, in_proj_bias: Tensor,
                                out_proj_weight: Tensor, out_proj_bias: Tensor,
                                norm_weight: Tensor, norm_bias: Tensor,
                                training_mode: bool, prelayernorm: bool) -> List[Tensor]:
    ...


def multihead_attention_bw_fp32(layer_id: int, grad_dec_output: Tensor,
                                output: Tensor, input: Tensor,
                                input_mask: Tensor, in_proj_weight: Tensor,
                                in_proj_bias: Tensor, out_proj_weight: Tensor,
                                out_proj_bias: Tensor, norm_weight: Tensor,
                                norm_bias: Tensor) -> List[Tensor]:
    ...


def multihead_attention_bw_fp16(layer_id: int, grad_dec_output: Tensor,
                                output: Tensor, input: Tensor,
                                input_mask: Tensor, in_proj_weight: Tensor,
                                in_proj_bias: Tensor, out_proj_weight: Tensor,
                                out_proj_bias: Tensor, norm_weight: Tensor,
                                norm_bias: Tensor) -> List[Tensor]:
    ...


def create_multihead_attention_fp32(layer_id: int, max_batch_tokens: int,
                                    max_seq_len: int, hidden_dim: int, num_heads: int,
                                    attn_prob_dropout_ratio: float,
                                    hidden_dropout_ratio: float,
                                    pre_or_postLayerNorm: bool,
                                    pg: ProcessGroup) -> int:
    ...


def create_multihead_attention_fp16(layer_id: int, max_batch_tokens: int,
                                    max_seq_len: int, hidden_dim: int, num_heads: int,
                                    attn_prob_dropout_ratio: float,
                                    hidden_dropout_ratio: float,
                                    pre_or_postLayerNorm: bool,
                                    pg: ProcessGroup) -> int:
    ...
@@ -1,12 +0,0 @@
from torch import Tensor

def forward(input: Tensor, mask: Tensor, scale: float) -> Tensor:
    ...


def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor:
    ...


def get_batch_per_block(query_seq_len: int, key_seq_len: int, batches: int, attn_heads: int) -> int:
    ...
@@ -1,8 +0,0 @@
from torch import Tensor

def forward(input: Tensor, scale: float) -> Tensor:
    ...


def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor:
    ...
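The two softmax stubs above describe fused scaled-softmax kernels. As a point of reference only, a plain-PyTorch version of a scaled, upper-triangular (causal) masked softmax can be written as below; the exact masking semantics of the CUDA kernel are an assumption here, not taken from the commit:

# Pure-PyTorch reference (an assumption about the kernel's semantics): scale the
# attention scores, mask out future positions, then softmax each row.
import torch

def scaled_upper_triang_masked_softmax_ref(scores: torch.Tensor, scale: float) -> torch.Tensor:
    # scores: (batches, seq_len, seq_len) attention logits
    seq_len = scores.size(-1)
    causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    masked = (scores * scale).masked_fill(causal_mask, float('-inf'))
    return torch.softmax(masked, dim=-1)

out = scaled_upper_triang_masked_softmax_ref(torch.randn(2, 8, 8), scale=0.125)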
@@ -8,16 +8,28 @@ from torch.optim import Optimizer

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.kernel import fused_optim
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.logging import get_dist_logger
from colossalai.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes, multi_tensor_applier

from ._utils import has_inf_or_nan, zero_gard_by_list
from .grad_scaler import BaseGradScaler

try:
    from colossalai._C import fused_optim
except:
    fused_optim = None

__all__ = ['FP16Optimizer']


def load_fused_optim():
    global fused_optim

    if fused_optim is None:
        fused_optim = FusedOptimBuilder().load()


def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
    """
    adapted from Megatron-LM (https://github.com/NVIDIA/Megatron-LM)
@@ -30,6 +42,8 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
    if overflow_buf:
        overflow_buf.fill_(0)
        # Scaling with factor `1.0` is equivalent to copy.
        global fused_optim
        load_fused_optim()
        multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0)
    else:
        for this_, that_ in zip(this, that):
@@ -1,42 +1,7 @@
from .cuda_native import FusedScaleMaskSoftmax, LayerNorm, MultiHeadAttention

try:
    from colossalai._C import fused_optim
except:
    from colossalai.kernel.op_builder.fused_optim import FusedOptimBuilder
    fused_optim = FusedOptimBuilder().load()

try:
    from colossalai._C import cpu_optim
except ImportError:
    from colossalai.kernel.op_builder import CPUAdamBuilder
    cpu_optim = CPUAdamBuilder().load()

try:
    from colossalai._C import multihead_attention
except ImportError:
    from colossalai.kernel.op_builder import MultiHeadAttnBuilder
    multihead_attention = MultiHeadAttnBuilder().load()

try:
    from colossalai._C import scaled_upper_triang_masked_softmax
except ImportError:
    from colossalai.kernel.op_builder import ScaledSoftmaxBuilder
    scaled_upper_triang_masked_softmax = ScaledSoftmaxBuilder().load()

try:
    from colossalai._C import moe
except ImportError:
    from colossalai.kernel.op_builder import MOEBuilder
    moe = MOEBuilder().load()

__all__ = [
    "fused_optim",
    "cpu_optim",
    "multihead_attention",
    "moe",
    "LayerNorm",
    "FusedScaleMaskSoftmax",
    "MultiHeadAttention",
    "scaled_upper_triang_masked_softmax",
]
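This hunk removes the eager try/except loads from colossalai/kernel/__init__.py, so importing the package no longer triggers a JIT build; kernels are instead built lazily at their first point of use, as load_fused_optim and build_moe_if_not_prebuilt do elsewhere in this diff. A hedged sketch of that deferred-build variant of the pattern (the helper name is illustrative, not from the commit):

try:
    from colossalai._C import fused_optim    # pre-built at install time, if available
except ImportError:
    fused_optim = None                       # defer the (slow) JIT build

def _ensure_fused_optim():
    # Illustrative helper: build the extension only when a caller actually needs it.
    global fused_optim
    if fused_optim is None:
        from colossalai.kernel.op_builder import FusedOptimBuilder
        fused_optim = FusedOptimBuilder().load()
    return fused_optim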
@@ -135,7 +135,8 @@ class MultiHeadAttention(nn.Module):
        # Load cuda modules if needed
        global colossal_multihead_attention
        if colossal_multihead_attention is None:
            from colossalai.kernel import multihead_attention
            from colossalai.kernel.op_builder import MultiHeadAttnBuilder
            multihead_attention = MultiHeadAttnBuilder().load()
            colossal_multihead_attention = multihead_attention

        # create the layer in cuda kernels.
@@ -6,13 +6,32 @@ from torch import Tensor
from torch.distributed import ProcessGroup

COL_MOE_KERNEL_FLAG = False
from colossalai.kernel import moe

try:
    from colossalai._C import moe
except:
    moe = None


def build_moe_if_not_prebuilt():
    # load moe kernel during runtime if not pre-built
    global moe
    if moe is None:
        from colossalai.kernel.op_builder import MOEBuilder
        moe = MOEBuilder().load()


class AllGather(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, inputs: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:

        global moe

        if moe is None:
            from colossalai.kernel.op_builder import MOEBuilder
            moe = MOEBuilder().load()

        if ctx is not None:
            ctx.comm_grp = group
@@ -85,6 +104,9 @@ class MoeDispatch(torch.autograd.Function):
        s = tokens.size(0)
        h = tokens.size(1)

        # load moe kernel during runtime if not pre-built
        build_moe_if_not_prebuilt()

        expert_input = moe.dispatch_forward(s, ec, h, tokens, mask, dest_idx)

        ctx.save_for_backward(mask, dest_idx)
@@ -112,6 +134,9 @@ class MoeCombine(torch.autograd.Function):
        c = ec // e
        h = expert_tokens.size(-1)

        # load moe kernel during runtime if not pre-built
        build_moe_if_not_prebuilt()

        fp16_flag = (expert_tokens.dtype == torch.float16)
        cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
        ctokens = moe.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
@@ -143,6 +168,8 @@ def moe_cumsum(inputs: Tensor):
    dim0 = inputs.size(0)
    flag = (dim0 <= 1024) or (dim0 <= 2048 and dim0 % 2 == 0) or (dim0 % 4 == 0)
    if flag and COL_MOE_KERNEL_FLAG:
        # load moe kernel during runtime if not pre-built
        build_moe_if_not_prebuilt()
        return moe.cumsum_sub_one(inputs)
    else:
        return torch.cumsum(inputs, dim=0) - 1
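moe_cumsum only calls the CUDA kernel when the shape qualifies and the kernel flag is set; otherwise it falls back to plain PyTorch. A small runnable illustration of that fallback branch (torch only, mirroring the torch.cumsum(inputs, dim=0) - 1 line above; the concrete input is illustrative):

import torch

mask = torch.tensor([[1, 0],
                     [1, 1],
                     [0, 1]])
fallback = torch.cumsum(mask, dim=0) - 1    # cumulative sum along dim 0, shifted down by one
print(fallback)
# tensor([[ 0, -1],
#         [ 1,  0],
#         [ 1,  1]])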
@@ -3,6 +3,7 @@ from typing import Optional

import torch

from colossalai.kernel.op_builder import CPUAdamBuilder
from colossalai.registry import OPTIMIZERS

from .nvme_optimizer import NVMeOptimizer

@@ -76,12 +77,8 @@ class CPUAdam(NVMeOptimizer):
        default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
        super(CPUAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
        self.adamw_mode = adamw_mode
        try:
            import colossalai._C.cpu_optim
        except ImportError:
            raise ImportError('Please install colossalai from source code to use CPUAdam')
        self.cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay,
                                                                    adamw_mode)
        cpu_adam = CPUAdamBuilder().load()
        self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)

    def torch_adam_update(self,
                          data,
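With this change, CPUAdam always obtains the CPU Adam extension through CPUAdamBuilder().load() instead of requiring an install from source. A hedged sketch of driving the resulting CPUAdamOptimizer directly, with the argument order taken from the deleted cpu_optim stub above (it assumes colossalai is installed and the extension can be JIT-built on this machine; the loss_scale value of -1, meaning no loss scaling, is an assumption):

import torch
from colossalai.kernel.op_builder import CPUAdamBuilder

cpu_adam = CPUAdamBuilder().load()
# __init__(lr, beta1, beta2, eps, weight_decay, adamw_mode), per the deleted stub
op = cpu_adam.CPUAdamOptimizer(1e-3, 0.9, 0.999, 1e-8, 0.0, True)

param = torch.randn(1024)
grad = torch.randn(1024)
exp_avg = torch.zeros(1024)
exp_avg_sq = torch.zeros(1024)

# step(step, lr, beta1, beta2, eps, weight_decay, bias_correction, param, grad, exp_avg, exp_avg_sq, loss_scale)
op.step(1, 1e-3, 0.9, 0.999, 1e-8, 0.0, True, param, grad, exp_avg, exp_avg_sq, -1)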
@@ -65,7 +65,8 @@ class FusedAdam(torch.optim.Optimizer):
        self.adamw_mode = 1 if adamw_mode else 0
        self.set_grad_none = set_grad_none
        if multi_tensor_applier.available:
            from colossalai.kernel import fused_optim
            from colossalai.kernel.op_builder import FusedOptimBuilder
            fused_optim = FusedOptimBuilder().load()

            # Skip buffer
            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
@@ -76,7 +76,8 @@ class FusedLAMB(torch.optim.Optimizer):
                        max_grad_norm=max_grad_norm)
        super(FusedLAMB, self).__init__(params, defaults)
        if multi_tensor_applier.available:
            from colossalai.kernel import fused_optim
            from colossalai.kernel.op_builder import FusedOptimBuilder
            fused_optim = FusedOptimBuilder().load()

            self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm
            # Skip buffer
@@ -80,7 +80,8 @@ class FusedSGD(Optimizer):
        self.wd_after_momentum = wd_after_momentum

        if multi_tensor_applier.available:
            from colossalai.kernel import fused_optim
            from colossalai.kernel.op_builder import FusedOptimBuilder
            fused_optim = FusedOptimBuilder().load()

            # Skip buffer
            self._dummy_overflow_buf = torch.tensor([0],
@@ -2,6 +2,7 @@ from typing import Any, Optional

import torch

from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder
from colossalai.registry import OPTIMIZERS
from colossalai.utils import multi_tensor_applier

@@ -77,7 +78,9 @@ class HybridAdam(NVMeOptimizer):
        super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
        self.adamw_mode = adamw_mode

        from colossalai.kernel import cpu_optim, fused_optim
        # build during runtime if not found
        cpu_optim = CPUAdamBuilder().load()
        fused_optim = FusedOptimBuilder().load()
        self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)

        self.gpu_adam_op = fused_optim.multi_tensor_adam
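HybridAdam now JIT-builds both the CPU and the fused GPU extensions inside its constructor. A hedged usage sketch (it assumes colossalai with CUDA support, and that HybridAdam is importable from colossalai.nn.optimizer, which is an assumption rather than something stated in this diff):

import torch
from colossalai.nn.optimizer import HybridAdam

model = torch.nn.Linear(16, 16).cuda()
# Constructing the optimizer triggers the CPUAdamBuilder/FusedOptimBuilder JIT builds
# when no pre-built binaries are present.
optim = HybridAdam(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 16, device='cuda')).sum()
loss.backward()
optim.step()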
@@ -18,11 +18,15 @@ from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARA
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.kernel import fused_optim
from colossalai.tensor import ColoParameter, ProcessGroup

from .multi_tensor_apply import multi_tensor_applier

try:
    from colossalai._C import fused_optim
except:
    fused_optim = None


def print_rank_0(msg: str, logger=None):
    """Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.

@@ -123,6 +127,13 @@ def is_model_parallel_parameter(p):


def _calc_l2_norm(grads):
    # we should not
    global fused_optim

    if fused_optim is None:
        from colossalai.kernel.op_builder import FusedOptimBuilder
        fused_optim = FusedOptimBuilder().load()

    norm = 0.0
    if len(grads) > 0:
        dummy_overflow_buf = torch.cuda.IntTensor([0])
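_calc_l2_norm lazily builds fused_optim and then computes the gradients' L2 norm with multi_tensor_l2norm. As a sanity reference, the quantity it is expected to produce can be written in plain PyTorch; this reference is an assumption about the kernel's semantics, not code from the commit:

# Global gradient L2 norm: the 2-norm of all gradient elements viewed as one vector,
# computed here as the norm of the per-tensor norms (mathematically equivalent).
import torch

def calc_l2_norm_ref(grads):
    if len(grads) == 0:
        return torch.tensor(0.0)
    return torch.norm(torch.stack([g.float().norm(2) for g in grads]), 2)

grads = [torch.randn(10), torch.randn(3, 3)]
print(calc_l2_norm_ref(grads))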
@@ -14,7 +14,6 @@ class MultiTensorApply(object):

    def __init__(self, chunk_size):
        try:
            from colossalai.kernel import fused_optim
            MultiTensorApply.available = True
            self.chunk_size = chunk_size
        except ImportError as err: