[setup] support pre-build and jit-build of cuda kernels (#2374)

* [setup] support pre-build and jit-build of cuda kernels

* polish code

* polish code

* polish code

* polish code

* polish code

* polish code
Frank Lee
2023-01-06 20:50:26 +08:00
committed by GitHub
parent 12c8bf38d7
commit 40d376c566
36 changed files with 414 additions and 390 deletions
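The deletions below remove the `.pyi` stubs that assumed a pre-built extension package; with this change the kernels can also be compiled just-in-time. A minimal sketch of the pre-build / JIT-build fallback pattern, assuming a hypothetical pre-built module path and source file (not the exact names used in this PR):

from torch.utils.cpp_extension import load

def load_cpu_optim():
    try:
        # Use the extension compiled at install time, if it exists.
        from colossalai._C import cpu_optim  # hypothetical pre-built module path
        return cpu_optim
    except ImportError:
        # Otherwise build it just-in-time with torch's cpp_extension loader.
        return load(
            name="cpu_optim",
            sources=["colossalai/kernel/cuda_native/csrc/cpu_adam.cpp"],  # illustrative path
            extra_cflags=["-O3"],
            verbose=True,
        )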

View File

@@ -1,9 +0,0 @@
from . import (
    cpu_optim,
    fused_optim,
    layer_norm,
    moe,
    multihead_attention,
    scaled_masked_softmax,
    scaled_upper_triang_masked_softmax,
)

View File

@@ -1,8 +0,0 @@
from torch import Tensor

class CPUAdamOptimizer:
    def __init__(self, lr: float, beta1: float, beta2: float, eps: float,
                 weight_decay: float, adamw_mode: float) -> None: ...

    def step(self, step: int, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, bias_correction: bool,
             param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor, loss_scale: float) -> None: ...
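For reference, a binding like the one described by this stub is driven roughly as in the sketch below; the import path and hyperparameter values are placeholders, not the exact ones used by ColossalAI's CPUAdam wrapper.

import torch
from colossalai._C import cpu_optim  # hypothetical import path for the compiled extension

param = torch.randn(1024)
grad = torch.randn(1024)
exp_avg = torch.zeros_like(param)
exp_avg_sq = torch.zeros_like(param)

# lr, beta1, beta2, eps, weight_decay, adamw_mode
opt = cpu_optim.CPUAdamOptimizer(1e-3, 0.9, 0.999, 1e-8, 0.0, True)
# step index, the same hyperparameters, bias_correction, the four tensors, and a loss scale (placeholder values)
opt.step(1, 1e-3, 0.9, 0.999, 1e-8, 0.0, True, param, grad, exp_avg, exp_avg_sq, -1.0)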

View File

@@ -1,23 +0,0 @@
from typing import List

from torch import Tensor

def multi_tensor_scale(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], scale: float) -> None:
    ...

def multi_tensor_sgd(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], weight_decay: float,
                     momentum: float, dampening: float, lr: float, nesterov: bool, first_run: bool, weight_decay_after_momentum: bool, scale: float) -> None:
    ...

def multi_tensor_adam(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, mode: int, bias_correction: int, weight_decay: float, div_scale: float) -> None:
    ...

def multi_tensor_lamb(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], lr: float, beta1: float, beta2: float, epsilon: float, step: int, bias_correction: int, weight_decay: float, grad_averaging: int, mode: int, global_grad_norm: Tensor, max_grad_norm: float, use_nvlamb_python: bool) -> None:
    ...

def multi_tensor_l2norm(chunk_size: int, noop_flag: Tensor, tensor_lists: List[List[Tensor]], per_tensor_python: bool) -> None:
    ...
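A hedged sketch of how chunked multi-tensor kernels like those above are typically invoked; the import path, chunk size, and tensor shapes are assumptions.

import torch
from colossalai._C import fused_optim  # hypothetical import path for the compiled extension

noop_flag = torch.zeros(1, dtype=torch.int, device="cuda")  # set non-zero to skip the fused op
srcs = [torch.randn(256, device="cuda") for _ in range(4)]
dsts = [torch.empty_like(t) for t in srcs]

# Scale-and-copy every tensor in one chunked launch instead of a Python loop.
fused_optim.multi_tensor_scale(2048 * 32, noop_flag, [srcs, dsts], 0.5)

# Fused L2 norms over the same list, with per-tensor results requested.
fused_optim.multi_tensor_l2norm(2048 * 32, noop_flag, [srcs], True)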

View File

@@ -1,11 +0,0 @@
from typing import List

from torch import Tensor

def forward_affine(input: Tensor, normalized_shape: List[int], gamma: Tensor, beta: Tensor, epsilon: float) -> List[Tensor]:
    ...

def backward_affine(dout: Tensor, mean: Tensor, invvar: Tensor, input: Tensor,
                    normalized_shape: List[int], gamma: Tensor, beta: Tensor, epsilon: float) -> List[Tensor]:
    ...
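An illustrative call of the fused layer-norm bindings above, assuming the compiled extension is importable; the import path and shapes are placeholders. Per the stub, forward_affine returns the output together with the statistics the backward pass needs.

import torch
from colossalai._C import layer_norm  # hypothetical import path for the compiled extension

hidden = 1024
x = torch.randn(8, hidden, device="cuda")
gamma = torch.ones(hidden, device="cuda")
beta = torch.zeros(hidden, device="cuda")

out, mean, invvar = layer_norm.forward_affine(x, [hidden], gamma, beta, 1e-5)
grad_in, grad_gamma, grad_beta = layer_norm.backward_affine(
    torch.randn_like(out), mean, invvar, x, [hidden], gamma, beta, 1e-5)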

View File

@@ -1,20 +0,0 @@
from torch import Tensor

def cumsum_sub_one(mask: Tensor) -> Tensor:
    ...

def dispatch_forward(s: int, ec: int, h: int, batch_tokens: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
    ...

def dispatch_backward(s: int, ec: int, h: int, expert_grad: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
    ...

def combine_forward(s: int, e: int, c: int, h: int, expert_tokens: Tensor, logits: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
    ...

def combine_backward(s: int, e: int, c: int, h: int, tokens_grad: Tensor, expert_tokens: Tensor, logits: Tensor, mask: Tensor, dest_idx: Tensor) -> Tensor:
    ...
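A rough sketch of the dispatch step these MoE bindings implement; the interpretation of s/e/c/h (tokens, experts, capacity per expert, hidden size), the mask layout, and the import path are all assumptions.

import torch
from colossalai._C import moe  # hypothetical import path for the compiled extension

s, e, c, h = 16, 4, 8, 64  # tokens, experts, capacity per expert, hidden size (assumed)
tokens = torch.randn(s, h, device="cuda")
mask = torch.randint(0, 2, (s, e), dtype=torch.int32, device="cuda")  # token-to-expert assignment (assumed layout)

dest_idx = moe.cumsum_sub_one(mask)                    # slot of each token inside its expert bucket
expert_in = moe.dispatch_forward(s, e * c, h, tokens, mask, dest_idx)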

View File

@@ -1,55 +0,0 @@
from typing import List

from torch import Tensor
from torch.distributed import ProcessGroup

def multihead_attention_fw_fp32(layer_id: int, input: Tensor, input_mask: Tensor,
                                in_proj_weight: Tensor, in_proj_bias: Tensor,
                                out_proj_weight: Tensor, out_proj_bias: Tensor,
                                norm_weight: Tensor, norm_bias: Tensor,
                                training_mode: bool, prelayernorm: bool) -> List[Tensor]:
    ...

def multihead_attention_fw_fp16(layer_id: int, input: Tensor, input_mask: Tensor,
                                in_proj_weight: Tensor, in_proj_bias: Tensor,
                                out_proj_weight: Tensor, out_proj_bias: Tensor,
                                norm_weight: Tensor, norm_bias: Tensor,
                                training_mode: bool, prelayernorm: bool) -> List[Tensor]:
    ...

def multihead_attention_bw_fp32(layer_id: int, grad_dec_output: Tensor,
                                output: Tensor, input: Tensor,
                                input_mask: Tensor, in_proj_weight: Tensor,
                                in_proj_bias: Tensor, out_proj_weight: Tensor,
                                out_proj_bias: Tensor, norm_weight: Tensor,
                                norm_bias: Tensor) -> List[Tensor]:
    ...

def multihead_attention_bw_fp16(layer_id: int, grad_dec_output: Tensor,
                                output: Tensor, input: Tensor,
                                input_mask: Tensor, in_proj_weight: Tensor,
                                in_proj_bias: Tensor, out_proj_weight: Tensor,
                                out_proj_bias: Tensor, norm_weight: Tensor,
                                norm_bias: Tensor) -> List[Tensor]:
    ...

def create_multihead_attention_fp32(layer_id: int, max_batch_tokens: int,
                                    max_seq_len: int, hidden_dim: int, num_heads: int,
                                    attn_prob_dropout_ratio: float,
                                    hidden_dropout_ratio: float,
                                    pre_or_postLayerNorm: bool,
                                    pg: ProcessGroup) -> int:
    ...

def create_multihead_attention_fp16(layer_id: int, max_batch_tokens: int,
                                    max_seq_len: int, hidden_dim: int, num_heads: int,
                                    attn_prob_dropout_ratio: float,
                                    hidden_dropout_ratio: float,
                                    pre_or_postLayerNorm: bool,
                                    pg: ProcessGroup) -> int:
    ...
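The attention bindings above follow a register-then-call pattern: a layer is created once with its configuration, and forward/backward are later dispatched by layer_id with the layer's weights passed in explicitly. A hedged sketch of the creation step, with a placeholder import path and configuration, assuming torch.distributed has already been initialised:

import torch.distributed as dist
from colossalai._C import multihead_attention  # hypothetical import path for the compiled extension

pg = dist.new_group(ranks=[0])  # requires an initialised default process group

# layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads,
# attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm, pg
multihead_attention.create_multihead_attention_fp16(
    0, 4096, 512, 1024, 16, 0.1, 0.1, True, pg)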

View File

@@ -1,12 +0,0 @@
from torch import Tensor

def forward(input: Tensor, mask: Tensor, scale: float) -> Tensor:
    ...

def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor:
    ...

def get_batch_per_block(query_seq_len: int, key_seq_len: int, batches: int, attn_heads: int) -> int:
    ...
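An illustrative call of the fused masked-softmax bindings above; the import path, shapes, and mask dtype are assumptions (this kernel family typically operates on fp16 attention scores of shape (batch, heads, sq, sk)).

import torch
from colossalai._C import scaled_masked_softmax  # hypothetical import path for the compiled extension

scores = torch.randn(4, 16, 128, 128, dtype=torch.float16, device="cuda")
mask = torch.zeros(4, 1, 128, 128, dtype=torch.bool, device="cuda")  # True marks positions to mask out (assumed)

probs = scaled_masked_softmax.forward(scores, mask, 1.0)
grads = scaled_masked_softmax.backward(torch.randn_like(probs), probs, 1.0)
print(scaled_masked_softmax.get_batch_per_block(128, 128, 4, 16))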

View File

@@ -1,8 +0,0 @@
from torch import Tensor

def forward(input: Tensor, scale: float) -> Tensor:
    ...

def backward(output_grads: Tensor, softmax_results: Tensor, scale: float) -> Tensor:
    ...
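Likewise for the causal (upper-triangular masked) variant above, which needs no explicit mask; the import path, layout, and dtype below are assumptions.

import torch
from colossalai._C import scaled_upper_triang_masked_softmax as causal_softmax  # hypothetical import path

scores = torch.randn(64, 128, 128, dtype=torch.float16, device="cuda")  # (batch * heads, sq, sq), assumed layout
probs = causal_softmax.forward(scores, 1.0)
grads = causal_softmax.backward(torch.randn_like(probs), probs, 1.0)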