Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-09 04:50:17 +00:00)
[kernel] move all symlinks of kernel to colossalai._C (#1971)
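Every hunk below applies the same mechanical change: instead of resolving a symlinked top-level extension module by name (importlib.import_module("colossal_...")), the kernel wrappers now import the matching submodule under colossalai._C lazily, inside the function that needs it. A minimal sketch of the before/after pattern, using the layer-norm names that appear in the diff (the wrapping function _load_layer_norm_kernel is illustrative, not part of the commit):

    # Before: resolve the symlinked extension once and stash it in a module-level global.
    import importlib
    colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda")

    # After: import the extension from the colossalai._C package at the point of use,
    # turning a missing CUDA build into an explicit RuntimeError.
    def _load_layer_norm_kernel():
        try:
            import colossalai._C.layer_norm
        except ImportError:
            raise RuntimeError('FusedLayerNormAffineFunction requires cuda extensions')
        return colossalai._C.layer_norm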
@@ -3,14 +3,11 @@
    with some changes. """
 
 import numbers
-import torch
-from torch.nn.parameter import Parameter
-from torch.nn import init
-from torch.cuda.amp import custom_fwd, custom_bwd
-import importlib
-
-global colossal_layer_norm_cuda
-colossal_layer_norm_cuda = None
+
+import torch
+from torch.cuda.amp import custom_bwd, custom_fwd
+from torch.nn import init
+from torch.nn.parameter import Parameter
 
 
 class FusedLayerNormAffineFunction(torch.autograd.Function):
@@ -18,13 +15,17 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=torch.float32)
     def forward(ctx, input, weight, bias, normalized_shape, eps):
+        try:
+            import colossalai._C.layer_norm
+        except ImportError:
+            raise RuntimeError('FusedLayerNormAffineFunction requires cuda extensions')
+
         ctx.normalized_shape = normalized_shape
         ctx.eps = eps
         input_ = input.contiguous()
         weight_ = weight.contiguous()
         bias_ = bias.contiguous()
-        output, mean, invvar = colossal_layer_norm_cuda.forward_affine(input_, ctx.normalized_shape, weight_, bias_,
+        output, mean, invvar = colossalai._C.layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_,
                                                                        ctx.eps)
         ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
@@ -33,11 +34,15 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output):
+        try:
+            import colossalai._C.layer_norm
+        except ImportError:
+            raise RuntimeError('FusedLayerNormAffineFunction requires cuda extensions')
+
         input_, weight_, bias_, mean, invvar = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         grad_input, grad_weight, grad_bias \
-            = colossal_layer_norm_cuda.backward_affine(
+            = colossalai._C.layer_norm.backward_affine(
                 grad_output.contiguous(), mean, invvar,
                 input_, ctx.normalized_shape,
                 weight_, bias_, ctx.eps)
@@ -50,13 +55,6 @@ class MixedFusedLayerNorm(torch.nn.Module):
     def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None):
         super(MixedFusedLayerNorm, self).__init__()
 
-        global colossal_layer_norm_cuda
-        if colossal_layer_norm_cuda is None:
-            try:
-                colossal_layer_norm_cuda = importlib.import_module("colossal_layer_norm_cuda")
-            except ImportError:
-                raise RuntimeError('MixedFusedLayerNorm requires cuda extensions')
-
         if isinstance(normalized_shape, numbers.Integral):
             normalized_shape = (normalized_shape,)
         self.normalized_shape = torch.Size(normalized_shape)
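For context (this is not part of the commit), the fused layer-norm autograd function above is typically driven through torch.autograd.Function.apply; the argument order comes from the forward signature in the hunk, while the shapes, device and eps value here are illustrative assumptions:

    import torch

    # FusedLayerNormAffineFunction is the class shown in the hunks above; it needs the
    # colossalai._C.layer_norm CUDA extension to be built.
    hidden = 1024    # assumed hidden size
    x = torch.randn(8, hidden, device='cuda', requires_grad=True)
    weight = torch.ones(hidden, device='cuda', requires_grad=True)
    bias = torch.zeros(hidden, device='cuda', requires_grad=True)

    # forward(ctx, input, weight, bias, normalized_shape, eps)
    out = FusedLayerNormAffineFunction.apply(x, weight, bias, (hidden,), 1e-5)
    out.sum().backward()    # backward dispatches to colossalai._C.layer_norm.backward_affine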
@@ -1,5 +1,4 @@
 import math
-import importlib
 from dataclasses import dataclass
 
 import torch
@@ -37,21 +36,21 @@ colossal_multihead_attention = None
 
 @dataclass
 class Config:
-    max_batch_tokens: int  # max batch token numbers
-    max_seq_len: int  # max sequence length
-    hidden_size: int  # size of transformer hidden layers
-    nhead: int  # number of heads in attention
-    attn_prob_dropout_ratio: float  # attention score dropout ratio
-    hidden_dropout_ratio: float  # dropout ration before residual
-    norm_first: bool  # norm_first
-    fp16: bool  # fp16 presion
+    max_batch_tokens: int    # max batch token numbers
+    max_seq_len: int    # max sequence length
+    hidden_size: int    # size of transformer hidden layers
+    nhead: int    # number of heads in attention
+    attn_prob_dropout_ratio: float    # attention score dropout ratio
+    hidden_dropout_ratio: float    # dropout ration before residual
+    norm_first: bool    # norm_first
+    fp16: bool    # fp16 presion
 
 
 class MultiHeadAttention1DFunc(Function):
 
     @staticmethod
-    def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
-                out_proj_bias, norm_weight, norm_bias, config):
+    def forward(ctx, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight, out_proj_bias, norm_weight,
+                norm_bias, config):
         cuda_module = colossal_multihead_attention
         forward_func = (cuda_module.multihead_attention_fw_fp16
                         if config.fp16 else cuda_module.multihead_attention_fw_fp32)
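As an aside (not in the diff), the Config dataclass above is populated by MultiHeadAttention.__init__ further down; a hedged sketch of that mapping, with placeholder numbers:

    # Field order as declared in the dataclass; values are illustrative only.
    cfg = Config(
        max_batch_tokens=4 * 512,        # batch_size * max_seq_len
        max_seq_len=512,
        hidden_size=1024,
        nhead=16,
        attn_prob_dropout_ratio=0.1,
        hidden_dropout_ratio=0.1,
        norm_first=False,
        fp16=True,
    )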
@@ -59,13 +58,12 @@ class MultiHeadAttention1DFunc(Function):
             input = input.to(torch.half)
             input_mask = input_mask.to(torch.half)
 
-        (output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias,
-                                 out_proj_weight, out_proj_bias, norm_weight, norm_bias,
-                                 config.training, config.norm_first)
+        (output,) = forward_func(config.layer_id, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
+                                 out_proj_bias, norm_weight, norm_bias, config.training, config.norm_first)
 
         if config.is_grad_enabled and config.training:
-            ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias,
-                                  out_proj_weight, out_proj_bias, norm_weight, norm_bias)
+            ctx.save_for_backward(output, input, input_mask, in_proj_weight, in_proj_bias, out_proj_weight,
+                                  out_proj_bias, norm_weight, norm_bias)
             ctx.config = config
         return output
 
@@ -98,8 +96,8 @@ class MultiHeadAttention1DFunc(Function):
             ctx.config.layer_id, grad_output, output, input, input_mask, in_proj_weight,
             in_proj_bias, out_proj_weight, out_proj_bias, norm_weight, norm_bias)
 
-        return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight,
-                grad_out_proj_bias, grad_norm_weight, grad_norm_bias, None)
+        return (grad_input, None, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias,
+                grad_norm_weight, grad_norm_bias, None)
 
 
 class MultiHeadAttention(nn.Module):
@@ -121,19 +119,11 @@ class MultiHeadAttention(nn.Module):
 
     layer_id = 0
 
-    def __init__(self,
-                 hidden_size,
-                 nhead,
-                 batch_size,
-                 max_seq_len,
-                 dropout=0.0,
-                 norm_first=False,
-                 fp16=True,
-                 pg=None):
+    def __init__(self, hidden_size, nhead, batch_size, max_seq_len, dropout=0.0, norm_first=False, fp16=True, pg=None):
         super(MultiHeadAttention, self).__init__()
 
-        self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout,
-                             dropout, norm_first, fp16)
+        self.config = Config(batch_size * max_seq_len, max_seq_len, hidden_size, nhead, dropout, dropout, norm_first,
+                             fp16)
         check_config(self.config)
         self.pg = pg
         self.pg_size = 1
@@ -146,7 +136,8 @@ class MultiHeadAttention(nn.Module):
         global colossal_multihead_attention
         if colossal_multihead_attention is None:
             try:
-                colossal_multihead_attention = importlib.import_module("colossal_multihead_attention")
+                import colossalai._C.multihead_attention
+                colossal_multihead_attention = colossalai._C.multihead_attention
             except ImportError:
                 raise RuntimeError('MultiHeadAttention requires cuda extensions')
 
@@ -215,14 +206,13 @@ class MultiHeadAttention(nn.Module):
 
             with torch.no_grad():
                 self.in_proj_weight.copy_(
-                    attn_qkvw_global.view(3, hs, hs)[
-                        :, int(hs * rank_in_pg / self.pg_size):
-                        int(hs * (rank_in_pg + 1) / self.pg_size),
-                        :])
+                    attn_qkvw_global.view(3, hs, hs)[:,
+                                                     int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) /
+                                                                                             self.pg_size), :])
                 self.in_proj_bias.copy_(
-                    attn_qkvb_global.view(3, hs)[
-                        :, int(hs * rank_in_pg / self.pg_size):
-                        int(hs * (rank_in_pg + 1) / self.pg_size)])
+                    attn_qkvb_global.view(3, hs)[:,
+                                                 int(hs * rank_in_pg / self.pg_size):int(hs * (rank_in_pg + 1) /
+                                                                                         self.pg_size)])
 
             attn_ow_global = torch.empty(hs, hs)
             nn.init.xavier_uniform_(attn_ow_global, 1.0)
@@ -230,9 +220,9 @@ class MultiHeadAttention(nn.Module):
             torch.distributed.broadcast(attn_ow_global, src=0, group=self.pg)
             attn_ow_global = attn_ow_global.cpu()
             with torch.no_grad():
-                self.out_proj_weight.copy_(attn_ow_global[
-                    :, int(hs * rank_in_pg / self.pg_size):
-                    int(hs * (rank_in_pg + 1) / self.pg_size)])
+                self.out_proj_weight.copy_(attn_ow_global[:,
+                                                          int(hs * rank_in_pg /
+                                                              self.pg_size):int(hs * (rank_in_pg + 1) / self.pg_size)])
 
         else:
             attn_qkvw = self.in_proj_weight.view(-1, hs)
@@ -243,10 +233,7 @@ class MultiHeadAttention(nn.Module):
             nn.init.xavier_uniform_(self.out_proj_weight, 1.0)
 
     def state_dict(self, destination=None, prefix="", keep_vars=False):
-        destination = torch.nn.Module.state_dict(self,
-                                                  destination=destination,
-                                                  prefix=prefix,
-                                                  keep_vars=keep_vars)
+        destination = torch.nn.Module.state_dict(self, destination=destination, prefix=prefix, keep_vars=keep_vars)
         return destination
 
     def forward(self, hidden_states, encoder_padding_mask):
@@ -257,8 +244,7 @@ class MultiHeadAttention(nn.Module):
 
         bs, sl, dim = hidden_states.size()
         if bs * sl > self.config.max_batch_tokens:
-            raise ValueError(
-                f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
+            raise ValueError(f"Batch token numbers {bs * sl} exceeds the limit {self.config.max_batch_tokens}.")
         if sl > self.config.max_seq_len:
             raise ValueError(f"Sequence length {sl} exceeds the limit {self.config.max_seq_len}.")
         if len(encoder_padding_mask.size()) == 1:
@@ -266,9 +252,8 @@ class MultiHeadAttention(nn.Module):
         else:
            assert bs == encoder_padding_mask.size(0) and sl == encoder_padding_mask.size(1)
 
-        output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask,
-                                                self.in_proj_weight, self.in_proj_bias,
-                                                self.out_proj_weight, self.out_proj_bias,
+        output = MultiHeadAttention1DFunc.apply(hidden_states, encoder_padding_mask, self.in_proj_weight,
+                                                self.in_proj_bias, self.out_proj_weight, self.out_proj_bias,
                                                 self.norm_weight, self.norm_bias, self.config)
 
         return output.to(self.precision)
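Outside the commit itself, a rough usage sketch of the module whose signatures were collapsed above; the constructor arguments mirror def __init__ from the hunk, while the shapes, device placement and mask convention are assumptions:

    import torch

    # batch_size=4, max_seq_len=512, hidden_size=1024, nhead=16 are illustrative values
    attn = MultiHeadAttention(hidden_size=1024, nhead=16, batch_size=4, max_seq_len=512,
                              dropout=0.1, norm_first=False, fp16=True).cuda()

    hidden_states = torch.randn(4, 512, 1024, device='cuda', dtype=torch.half)
    # padding mask with one row per sequence (bs x sl); the zero/one convention is an assumption
    encoder_padding_mask = torch.zeros(4, 512, device='cuda', dtype=torch.half)

    out = attn(hidden_states, encoder_padding_mask)    # shape (4, 512, 1024)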
@@ -1,9 +1,10 @@
 """This code from NVIDIA Megatron
    with some changes. """
 
+import enum
+
 import torch
 import torch.nn as nn
-import enum
 
 
 class AttnMaskType(enum.Enum):
@@ -23,12 +24,12 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inputs, scale):
         try:
-            import colossal_scaled_upper_triang_masked_softmax
+            import colossalai._C.scaled_upper_triang_masked_softmax
         except ImportError:
             raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
 
         scale_t = torch.tensor([scale])
-        softmax_results = colossal_scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
+        softmax_results = colossalai._C.scaled_upper_triang_masked_softmax.forward(inputs, scale_t[0])
 
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results
@@ -36,12 +37,13 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     @staticmethod
     def backward(ctx, output_grads):
         try:
-            import colossal_scaled_upper_triang_masked_softmax
+            import colossalai._C.scaled_upper_triang_masked_softmax
         except ImportError:
             raise RuntimeError('ScaledUpperTriangMaskedSoftmax requires cuda extensions')
 
         softmax_results, scale_t = ctx.saved_tensors
-        input_grads = colossal_scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
+        input_grads = colossalai._C.scaled_upper_triang_masked_softmax.backward(output_grads, softmax_results,
+                                                                                scale_t[0])
 
         return input_grads, None
 
@@ -58,26 +60,26 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inputs, mask, scale):
         try:
-            import colossal_scaled_masked_softmax
+            import colossalai._C.scaled_masked_softmax
         except ImportError:
             raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
 
         scale_t = torch.tensor([scale])
 
-        softmax_results = colossal_scaled_masked_softmax.forward(inputs, mask, scale_t[0])
+        softmax_results = colossalai._C.scaled_masked_softmax.forward(inputs, mask, scale_t[0])
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results
 
     @staticmethod
     def backward(ctx, output_grads):
         try:
-            import colossal_scaled_masked_softmax
+            import colossalai._C.scaled_masked_softmax
         except ImportError:
             raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
 
         softmax_results, scale_t = ctx.saved_tensors
 
-        input_grads = colossal_scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
+        input_grads = colossalai._C.scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
         return input_grads, None, None
 
 
@@ -184,8 +186,8 @@ class FusedScaleMaskSoftmax(nn.Module):
     @staticmethod
     def get_batch_per_block(sq, sk, b, np):
         try:
-            import colossal_scaled_masked_softmax
+            import colossalai._C.scaled_masked_softmax
         except ImportError:
             raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
 
-        return colossal_scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
+        return colossalai._C.scaled_masked_softmax.get_batch_per_block(sq, sk, b, np)
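Also outside the diff, the renamed softmax kernels are reached through the autograd functions above; a hedged call-site sketch where the scale value and tensor shape are assumptions and only the (inputs, scale) argument order comes from the forward shown earlier:

    import torch

    # (batches * heads, sq, sk) attention scores for a causal upper-triangular mask; fp16 on GPU
    scores = torch.randn(32, 128, 128, device='cuda', dtype=torch.half, requires_grad=True)
    probs = ScaledUpperTriangMaskedSoftmax.apply(scores, 0.125)
    probs.sum().backward()    # dispatches to colossalai._C.scaled_upper_triang_masked_softmax.backward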