mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-01 13:15:26 +00:00
Merge branch 'feature/colossal-infer' into colossal-infer-cuda-graph
This commit is contained in:
commit
d02e257abd
@ -9,6 +9,7 @@ from transformers.models.llama.modeling_llama import (
|
|||||||
LlamaForCausalLM,
|
LlamaForCausalLM,
|
||||||
LlamaMLP,
|
LlamaMLP,
|
||||||
LlamaModel,
|
LlamaModel,
|
||||||
|
LlamaRMSNorm,
|
||||||
)
|
)
|
||||||
|
|
||||||
from colossalai.inference.config import InputMetaData
|
from colossalai.inference.config import InputMetaData
|
||||||
@ -19,6 +20,7 @@ from colossalai.kernel.triton import (
|
|||||||
decoding_fused_rotary_embedding,
|
decoding_fused_rotary_embedding,
|
||||||
flash_decoding_attention,
|
flash_decoding_attention,
|
||||||
get_xine_cache,
|
get_xine_cache,
|
||||||
|
rms_layernorm,
|
||||||
rotary_embedding,
|
rotary_embedding,
|
||||||
)
|
)
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
@ -121,7 +123,8 @@ def llama_model_forward(
|
|||||||
hidden_states = hidden_states[last_token_indexs - 1].contiguous()
|
hidden_states = hidden_states[last_token_indexs - 1].contiguous()
|
||||||
residual = residual[last_token_indexs - 1].contiguous()
|
residual = residual[last_token_indexs - 1].contiguous()
|
||||||
norm_output = torch.empty_like(hidden_states) # NOTE non-functional, but cuda graph only capture decoding only
|
norm_output = torch.empty_like(hidden_states) # NOTE non-functional, but cuda graph only capture decoding only
|
||||||
hidden_states, _ = self.norm(hidden_states, norm_output, residual)
|
hidden_states, _ = self.norm(hidden_states, norm_output, residual, use_cuda_kernel)
|
||||||
|
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@ -164,7 +167,7 @@ def llama_decoder_layer_forward(
|
|||||||
use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
|
use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual)
|
hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
|
||||||
# Self Attention
|
# Self Attention
|
||||||
hidden_states = self.self_attn(
|
hidden_states = self.self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
@ -182,12 +185,32 @@ def llama_decoder_layer_forward(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual)
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, norm_output, residual, use_cuda_kernel)
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
|
|
||||||
|
def llama_rmsnorm_forward(
|
||||||
|
self: LlamaRMSNorm,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
norm_output: torch.Tensor,
|
||||||
|
residual: torch.Tensor = None,
|
||||||
|
use_cuda_kernel: bool = True,
|
||||||
|
):
|
||||||
|
if use_cuda_kernel:
|
||||||
|
if residual is not None:
|
||||||
|
inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, self.variance_epsilon)
|
||||||
|
return hidden_states, residual
|
||||||
|
|
||||||
|
if norm_output is None:
|
||||||
|
norm_output = torch.empty_like(hidden_states)
|
||||||
|
inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, self.variance_epsilon)
|
||||||
|
return norm_output, hidden_states
|
||||||
|
else:
|
||||||
|
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
|
||||||
|
|
||||||
|
|
||||||
class NopadLlamaAttention(LlamaAttention):
|
class NopadLlamaAttention(LlamaAttention):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -295,8 +318,12 @@ class NopadLlamaAttention(LlamaAttention):
|
|||||||
)
|
)
|
||||||
|
|
||||||
block_size = k_cache.size(-2)
|
block_size = k_cache.size(-2)
|
||||||
|
|
||||||
if is_prompts:
|
if is_prompts:
|
||||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
if use_cuda_kernel:
|
||||||
|
inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||||
|
else:
|
||||||
|
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
||||||
attn_output = context_attention_unpadded(
|
attn_output = context_attention_unpadded(
|
||||||
q=query_states,
|
q=query_states,
|
||||||
k=key_states,
|
k=key_states,
|
||||||
@ -312,9 +339,16 @@ class NopadLlamaAttention(LlamaAttention):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if use_cuda_kernel:
|
if use_cuda_kernel:
|
||||||
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
|
inference_ops.rotary_embedding_and_cache_copy(
|
||||||
inference_ops.decode_kv_cache_memcpy(
|
query_states,
|
||||||
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
|
key_states,
|
||||||
|
value_states,
|
||||||
|
cos_sin[0],
|
||||||
|
cos_sin[1],
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
sequence_lengths,
|
||||||
|
block_tables,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
decoding_fused_rotary_embedding(
|
decoding_fused_rotary_embedding(
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.nn import Parameter
|
from torch.nn import Parameter
|
||||||
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
|
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
|
||||||
|
|
||||||
@ -10,6 +9,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
|
|||||||
llama_causal_lm_forward,
|
llama_causal_lm_forward,
|
||||||
llama_decoder_layer_forward,
|
llama_decoder_layer_forward,
|
||||||
llama_model_forward,
|
llama_model_forward,
|
||||||
|
llama_rmsnorm_forward,
|
||||||
)
|
)
|
||||||
from colossalai.inference.utils import init_to_get_rotary
|
from colossalai.inference.utils import init_to_get_rotary
|
||||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
|
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
|
||||||
@ -17,27 +17,6 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription,
|
|||||||
# import colossalai
|
# import colossalai
|
||||||
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
|
||||||
|
|
||||||
try:
|
|
||||||
from colossalai.kernel.triton import rms_layernorm
|
|
||||||
|
|
||||||
HAS_TRITON_RMSNORM = True
|
|
||||||
except:
|
|
||||||
print("you should install triton from https://github.com/openai/triton")
|
|
||||||
HAS_TRITON_RMSNORM = False
|
|
||||||
|
|
||||||
|
|
||||||
def get_triton_rmsnorm_forward():
|
|
||||||
if HAS_TRITON_RMSNORM:
|
|
||||||
|
|
||||||
def _triton_rmsnorm_forward(
|
|
||||||
self: LlamaRMSNorm, hidden_states: torch.Tensor, norm_output: torch.Tensor, residual: torch.Tensor = None
|
|
||||||
):
|
|
||||||
return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
|
|
||||||
|
|
||||||
return _triton_rmsnorm_forward
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
@ -84,15 +63,9 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
|
|||||||
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
|
description=method_replacement, policy=policy, target_key=LlamaDecoderLayer
|
||||||
)
|
)
|
||||||
|
|
||||||
infer_forward = None
|
infer_forward = llama_rmsnorm_forward
|
||||||
if HAS_TRITON_RMSNORM:
|
method_replacement = {"forward": partial(infer_forward)}
|
||||||
infer_forward = get_triton_rmsnorm_forward()
|
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaRMSNorm)
|
||||||
|
|
||||||
if infer_forward is not None:
|
|
||||||
method_replacement = {"forward": partial(infer_forward)}
|
|
||||||
self.append_or_create_method_replacement(
|
|
||||||
description=method_replacement, policy=policy, target_key=LlamaRMSNorm
|
|
||||||
)
|
|
||||||
|
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
|
@ -47,5 +47,5 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
|
|||||||
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
|
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
|
||||||
freqs = torch.outer(t, inv_freq)
|
freqs = torch.outer(t, inv_freq)
|
||||||
|
|
||||||
self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
|
self._cos_cached = torch.cos(freqs).to(self.dtype).cuda()
|
||||||
self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
|
self._sin_cached = torch.sin(freqs).to(self.dtype).cuda()
|
||||||
|
@ -1,8 +1,11 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||||
from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
|
from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
|
||||||
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token
|
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token
|
||||||
|
|
||||||
|
inference_ops = InferenceOpsLoader().load()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import triton # noqa
|
import triton # noqa
|
||||||
|
|
||||||
@ -16,9 +19,19 @@ configs = [
|
|||||||
x_names=["num_tokens"],
|
x_names=["num_tokens"],
|
||||||
x_vals=[2**i for i in range(4, 11)],
|
x_vals=[2**i for i in range(4, 11)],
|
||||||
line_arg="provider",
|
line_arg="provider",
|
||||||
line_vals=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
|
line_vals=[
|
||||||
line_names=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
|
"no_fused_triton_rotary_emb_func",
|
||||||
styles=[("red", "-"), ("blue", "-")],
|
"fused_triton_rotary_emb_func",
|
||||||
|
"no_fused_cuda_rotary_emb_func",
|
||||||
|
"fused_cuda_rotary_emb_func",
|
||||||
|
],
|
||||||
|
line_names=[
|
||||||
|
"no_fused_triton_rotary_emb_func",
|
||||||
|
"fused_triton_rotary_emb_func",
|
||||||
|
"no_fused_cuda_rotary_emb_func",
|
||||||
|
"fused_cuda_rotary_emb_func",
|
||||||
|
],
|
||||||
|
styles=[("red", "-"), ("blue", "-"), ("green", "-"), ("yellow", "-")],
|
||||||
ylabel="ms",
|
ylabel="ms",
|
||||||
plot_name=f"rotary_emb-batch-{BATCH}",
|
plot_name=f"rotary_emb-batch-{BATCH}",
|
||||||
args={"num_kv_heads": 16},
|
args={"num_kv_heads": 16},
|
||||||
@ -32,7 +45,7 @@ def benchmark_rotary_emb(
|
|||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
):
|
):
|
||||||
BATCH_SIZE = 4
|
BATCH_SIZE = 16
|
||||||
SEQ_LEN = num_tokens // BATCH_SIZE
|
SEQ_LEN = num_tokens // BATCH_SIZE
|
||||||
max_num_blocks_per_seq = 8
|
max_num_blocks_per_seq = 8
|
||||||
block_size = 64
|
block_size = 64
|
||||||
@ -68,7 +81,7 @@ def benchmark_rotary_emb(
|
|||||||
kv_seq_lengths = past_kv_seq_lengths + 1
|
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||||
block_tables = block_tables.to(device="cuda")
|
block_tables = block_tables.to(device="cuda")
|
||||||
|
|
||||||
if provider == "no_fused_rotary_emb_func":
|
if provider == "no_fused_triton_rotary_emb_func":
|
||||||
fn = lambda: [
|
fn = lambda: [
|
||||||
rotary_embedding(new_q, new_k, cos, sin),
|
rotary_embedding(new_q, new_k, cos, sin),
|
||||||
copy_kv_to_blocked_cache(
|
copy_kv_to_blocked_cache(
|
||||||
@ -77,7 +90,16 @@ def benchmark_rotary_emb(
|
|||||||
]
|
]
|
||||||
elif provider == "fused_triton_rotary_emb_func":
|
elif provider == "fused_triton_rotary_emb_func":
|
||||||
fn = lambda: decoding_fused_rotary_embedding(
|
fn = lambda: decoding_fused_rotary_embedding(
|
||||||
new_q, new_k, new_k, cos, sin, k_cache, k_cache, block_tables, kv_seq_lengths
|
new_q, new_k, new_v, cos, sin, k_cache, v_cache, block_tables, kv_seq_lengths
|
||||||
|
)
|
||||||
|
elif provider == "no_fused_cuda_rotary_emb_func":
|
||||||
|
fn = lambda: [
|
||||||
|
inference_ops.rotary_embedding(new_q, new_k, cos, sin),
|
||||||
|
inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables),
|
||||||
|
]
|
||||||
|
elif provider == "fused_cuda_rotary_emb_func":
|
||||||
|
fn = lambda: inference_ops.rotary_embedding_and_cache_copy(
|
||||||
|
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Undefined provider")
|
raise ValueError("Undefined provider")
|
@ -1,14 +1,14 @@
|
|||||||
import torch
|
import torch
|
||||||
import triton
|
|
||||||
|
|
||||||
|
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||||
from colossalai.kernel.triton import rms_layernorm
|
from colossalai.kernel.triton import rms_layernorm
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import triton # noqa
|
import triton # noqa
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
print("please install triton from https://github.com/openai/triton")
|
print("please install triton from https://github.com/openai/triton")
|
||||||
|
|
||||||
|
inference_ops = InferenceOpsLoader().load()
|
||||||
|
|
||||||
# Triton benchmark plot attributions
|
# Triton benchmark plot attributions
|
||||||
configs = [
|
configs = [
|
||||||
@ -19,16 +19,20 @@ configs = [
|
|||||||
line_vals=[
|
line_vals=[
|
||||||
"vllm_rms_layernorm",
|
"vllm_rms_layernorm",
|
||||||
"triton_rms_layernorm",
|
"triton_rms_layernorm",
|
||||||
"triton_rms_layernorm_with_residual",
|
"cuda_rms_layernorm",
|
||||||
"vllm_rms_layernorm_with_residual",
|
"vllm_rms_layernorm_with_residual",
|
||||||
|
"triton_rms_layernorm_with_residual",
|
||||||
|
"cuda_rms_layernorm_with_residual",
|
||||||
],
|
],
|
||||||
line_names=[
|
line_names=[
|
||||||
"vllm_rms_layernorm",
|
"vllm_rms_layernorm",
|
||||||
"triton_rms_layernorm",
|
"triton_rms_layernorm",
|
||||||
"triton_rms_layernorm_with_residual",
|
"cuda_rms_layernorm",
|
||||||
"vllm_rms_layernorm_with_residual",
|
"vllm_rms_layernorm_with_residual",
|
||||||
|
"triton_rms_layernorm_with_residual",
|
||||||
|
"cuda_rms_layernorm_with_residual",
|
||||||
],
|
],
|
||||||
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("green", "-")],
|
styles=[("red", "-"), ("blue", "-"), ("yellow", "-"), ("red", "--"), ("blue", "--"), ("yellow", "--")],
|
||||||
ylabel="ms",
|
ylabel="ms",
|
||||||
plot_name=f"RMSNorm benchmarking results",
|
plot_name=f"RMSNorm benchmarking results",
|
||||||
args={"HIDDEN_SIZE": 1024},
|
args={"HIDDEN_SIZE": 1024},
|
||||||
@ -62,10 +66,15 @@ def benchmark_rms_layernorm(
|
|||||||
fn = lambda: vllm_norm(x)
|
fn = lambda: vllm_norm(x)
|
||||||
elif provider == "triton_rms_layernorm":
|
elif provider == "triton_rms_layernorm":
|
||||||
fn = lambda: rms_layernorm(x, weight, eps=eps)
|
fn = lambda: rms_layernorm(x, weight, eps=eps)
|
||||||
|
elif provider == "cuda_rms_layernorm":
|
||||||
|
out = torch.empty_like(x)
|
||||||
|
fn = lambda: inference_ops.rms_layernorm(out, x, weight, eps)
|
||||||
elif provider == "vllm_rms_layernorm_with_residual":
|
elif provider == "vllm_rms_layernorm_with_residual":
|
||||||
fn = lambda: vllm_norm(x, residual=residual)
|
fn = lambda: vllm_norm(x, residual=residual)
|
||||||
elif provider == "triton_rms_layernorm_with_residual":
|
elif provider == "triton_rms_layernorm_with_residual":
|
||||||
fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual)
|
fn = lambda: rms_layernorm(x, weight, eps=eps, residual=residual)
|
||||||
|
elif provider == "cuda_rms_layernorm_with_residual":
|
||||||
|
fn = lambda: inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Undefined provider.")
|
raise ValueError("Undefined provider.")
|
||||||
|
|
@ -1,7 +1,11 @@
|
|||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
|
from vllm._C import ops
|
||||||
|
|
||||||
from colossalai.kernel.triton.fused_rotary_embedding import fused_rotary_embedding
|
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||||
|
from colossalai.kernel.triton import rotary_embedding
|
||||||
|
|
||||||
|
inference_ops = InferenceOpsLoader().load()
|
||||||
|
|
||||||
BATCH = 16
|
BATCH = 16
|
||||||
configs = [
|
configs = [
|
||||||
@ -9,9 +13,9 @@ configs = [
|
|||||||
x_names=["num_tokens"],
|
x_names=["num_tokens"],
|
||||||
x_vals=[2**i for i in range(4, 12)],
|
x_vals=[2**i for i in range(4, 12)],
|
||||||
line_arg="provider",
|
line_arg="provider",
|
||||||
line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
|
line_vals=["triton_func", "colossal_cuda_func", "vllm_cuda_func"],
|
||||||
line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
|
line_names=["triton_func", "colossal_cuda_func", "vllm_cuda_func"],
|
||||||
styles=[("red", "-"), ("blue", "-")],
|
styles=[("red", "-"), ("blue", "-"), ("yellow", "-")],
|
||||||
ylabel="ms",
|
ylabel="ms",
|
||||||
plot_name=f"rotary_emb-batch-{BATCH}",
|
plot_name=f"rotary_emb-batch-{BATCH}",
|
||||||
args={"num_kv_heads": 16},
|
args={"num_kv_heads": 16},
|
||||||
@ -48,12 +52,19 @@ def benchmark_rotary_emb(
|
|||||||
cos_shape = (4096, head_dim // 2)
|
cos_shape = (4096, head_dim // 2)
|
||||||
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
lengths = torch.tensor([3, 4, 6, 7], device="cuda")
|
|
||||||
|
|
||||||
if provider == "torch_rotary_emb_func":
|
cos_sin = torch.stack((cos, sin), dim=1).contiguous()
|
||||||
fn = lambda: torch_rotary_emb(q, cos[:num_tokens], sin[:num_tokens])
|
|
||||||
elif provider == "triton_rotary_emb_func":
|
positions = torch.arange(num_tokens).cuda()
|
||||||
fn = lambda: fused_rotary_embedding(q, k, cos, sin, lengths)
|
|
||||||
|
if provider == "triton_func":
|
||||||
|
fn = lambda: rotary_embedding(q, k, cos, sin)
|
||||||
|
elif provider == "colossal_cuda_func":
|
||||||
|
fn = lambda: inference_ops.rotary_embedding(q, k, cos, sin)
|
||||||
|
elif provider == "vllm_cuda_func":
|
||||||
|
q = q.view(num_tokens, -1)
|
||||||
|
k = k.view(num_tokens, -1)
|
||||||
|
fn = lambda: ops.rotary_embedding(positions, q, k, head_dim, cos_sin, True)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Undefined provider")
|
raise ValueError("Undefined provider")
|
||||||
|
|
54
examples/inference/benchmark_ops/benchmark_xine_copy.py
Normal file
54
examples/inference/benchmark_ops/benchmark_xine_copy.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from colossalai.kernel.triton import get_xine_cache
|
||||||
|
from tests.test_infer.test_ops.triton.test_xine_copy import get_cos_sin
|
||||||
|
|
||||||
|
try:
|
||||||
|
import triton # noqa
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
print("please install triton from https://github.com/openai/triton")
|
||||||
|
|
||||||
|
|
||||||
|
configs = [
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["max_num_tokens"],
|
||||||
|
x_vals=[2**i for i in range(6, 12)],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["torch_get_cos_sin", "triton_get_cos_sin"],
|
||||||
|
line_names=["torch_get_cos_sin", "triton_get_cos_sin"],
|
||||||
|
styles=[("red", "-"), ("blue", "-")],
|
||||||
|
ylabel="ms",
|
||||||
|
plot_name="Get_cos-sin_func",
|
||||||
|
args={"batch_size": 16, "head_dim": 256},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(configs)
|
||||||
|
def benchmark_get_xine_cache(
|
||||||
|
provider: str,
|
||||||
|
max_num_tokens: int,
|
||||||
|
batch_size: int,
|
||||||
|
head_dim: int,
|
||||||
|
):
|
||||||
|
warmup = 10
|
||||||
|
rep = 1000
|
||||||
|
dtype = torch.float16
|
||||||
|
cos_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
|
||||||
|
sin_cache = torch.randn((8912, head_dim), dtype=dtype, device="cuda")
|
||||||
|
lengths = torch.randint(2, max_num_tokens, (batch_size,), device="cuda")
|
||||||
|
|
||||||
|
if provider == "torch_get_cos_sin":
|
||||||
|
fn = lambda: get_cos_sin(lengths, cos_cache, sin_cache, is_prompts=True, dtype=dtype)
|
||||||
|
elif provider == "triton_get_cos_sin":
|
||||||
|
fn = lambda: get_xine_cache(lengths, cos_cache, sin_cache, is_prompts=True)
|
||||||
|
else:
|
||||||
|
raise ValueError("Undefined provider")
|
||||||
|
|
||||||
|
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
|
||||||
|
return ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
benchmark_get_xine_cache.run(save_path=".", print_data=True)
|
@ -21,7 +21,7 @@ class CpuAdamX86Extension(_CudaExtension):
|
|||||||
# necessary 4 functions
|
# necessary 4 functions
|
||||||
def sources_files(self):
|
def sources_files(self):
|
||||||
ret = [
|
ret = [
|
||||||
self.csrc_abs_path("cuda/cpu_adam.cpp"),
|
self.csrc_abs_path("x86/cpu_adam.cpp"),
|
||||||
]
|
]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
122
extensions/csrc/common/cuda_type_utils.h
Normal file
122
extensions/csrc/common/cuda_type_utils.h
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
/*
|
||||||
|
* This code from NVIDIA FasterTransformer:
|
||||||
|
* https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/utils/cuda_type_utils.cuh
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __device__ T add(T a, T b) {
|
||||||
|
return a + b;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline __device__ half2 add(half2 a, half2 b) {
|
||||||
|
return __hadd2(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline __device__ half add(half a, half b) {
|
||||||
|
return __hadd(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if ENABLE_BF16
|
||||||
|
template <>
|
||||||
|
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
|
||||||
|
return bf16hadd2(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
|
||||||
|
return bf16hadd(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif // ENABLE_BF16
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline __device__ T mul(T a, T b, T c) {
|
||||||
|
return a * b * c;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline __device__ half2 mul(half2 a, half2 b, half2 c) {
|
||||||
|
return __hmul2(__hmul2(a, b), c);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if ENABLE_BF16
|
||||||
|
template <>
|
||||||
|
inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b,
|
||||||
|
__nv_bfloat16 c) {
|
||||||
|
return bf16hmul(a, b, c);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b,
|
||||||
|
__nv_bfloat162 c) {
|
||||||
|
return bf16hmul2(a, b, c);
|
||||||
|
}
|
||||||
|
#endif // ENABLE_BF16
|
||||||
|
|
||||||
|
template <typename T_OUT, typename T_IN>
|
||||||
|
__device__ inline T_OUT cuda_cast(T_IN val) {
|
||||||
|
return val;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ inline float2 cuda_cast<float2, int2>(int2 val) {
|
||||||
|
return make_float2(val.x, val.y);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
__device__ inline float2 cuda_cast<float2, float>(float val) {
|
||||||
|
return make_float2(val, val);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
__device__ inline float2 cuda_cast<float2, half2>(half2 val) {
|
||||||
|
return __half22float2(val);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
__device__ inline half2 cuda_cast<half2, float2>(float2 val) {
|
||||||
|
return __float22half2_rn(val);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
__device__ inline half2 cuda_cast<half2, float>(float val) {
|
||||||
|
return __float2half2_rn(val);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
__device__ inline half2 cuda_cast<half2, half>(half val) {
|
||||||
|
return __half2half2(val);
|
||||||
|
}
|
||||||
|
template <>
|
||||||
|
__device__ inline float cuda_cast<float, half>(half val) {
|
||||||
|
return __half2float(val);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get type2 from type or vice versa (applied to half and bfloat16)
|
||||||
|
template <typename T>
|
||||||
|
struct TypeConverter {
|
||||||
|
using Type = half2;
|
||||||
|
}; // keep for generality
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct TypeConverter<half2> {
|
||||||
|
using Type = at::Half;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct TypeConverter<at::Half> {
|
||||||
|
using Type = half2;
|
||||||
|
};
|
||||||
|
|
||||||
|
#if ENABLE_BF16
|
||||||
|
template <>
|
||||||
|
struct TypeConverter<__nv_bfloat162> {
|
||||||
|
using Type = at::BFloat16;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct TypeConverter<at::BFloat16> {
|
||||||
|
using Type = __nv_bfloat162;
|
||||||
|
};
|
||||||
|
#endif // ENABLE_BF16
|
20
extensions/csrc/common/dev_info_mgr.h
Normal file
20
extensions/csrc/common/dev_info_mgr.h
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "common/nvgpu_dev_info.h"
|
||||||
|
#include "target.h"
|
||||||
|
|
||||||
|
namespace colossalAI {
|
||||||
|
namespace common {
|
||||||
|
|
||||||
|
template <typename Ret>
|
||||||
|
class DevInfoMgr final {
|
||||||
|
public:
|
||||||
|
static std::unique_ptr<Ret> GetDevInfo(int device_num) const {
|
||||||
|
return std::make_unique<Ret>(device_num);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace common
|
||||||
|
} // namespace colossalAI
|
@ -4,9 +4,20 @@
|
|||||||
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
|
This file is adapted from fused adam in NVIDIA/apex, commit a109f85
|
||||||
Licensed under the MIT License.
|
Licensed under the MIT License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
|
|
||||||
#include "compat.h"
|
#ifndef TORCH_CHECK
|
||||||
|
#define TORCH_CHECK AT_CHECK
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#ifdef VERSION_GE_1_3
|
||||||
|
#define DATA_PTR data_ptr
|
||||||
|
#else
|
||||||
|
#define DATA_PTR data
|
||||||
|
#endif
|
||||||
|
|
||||||
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
|
||||||
switch (TYPE) { \
|
switch (TYPE) { \
|
||||||
@ -211,90 +222,3 @@
|
|||||||
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
|
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
|
||||||
"'"); \
|
"'"); \
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__device__ __forceinline__ T reduce_block_into_lanes(
|
|
||||||
T *x, T val, int lanes = 1,
|
|
||||||
bool share_result = false) // lanes is intended to be <= 32.
|
|
||||||
{
|
|
||||||
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
|
||||||
int blockSize =
|
|
||||||
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
|
||||||
|
|
||||||
if (blockSize >= 64) {
|
|
||||||
x[tid] = val;
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
|
|
||||||
if (tid < i) x[tid] = x[tid] + x[tid + i];
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
T final;
|
|
||||||
|
|
||||||
if (tid < 32) {
|
|
||||||
if (blockSize >= 64)
|
|
||||||
final = x[tid] + x[tid + 32];
|
|
||||||
else
|
|
||||||
final = val;
|
|
||||||
// __SYNCWARP();
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 16; i >= lanes; i >>= 1)
|
|
||||||
final = final + __shfl_down_sync(0xffffffff, final, i);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (share_result) {
|
|
||||||
if (tid < lanes) x[tid] = final; // EpilogueOp
|
|
||||||
// Make sure the smem result is visible to all warps.
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
return final;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
|
|
||||||
T *x, T val, int lanes = 1,
|
|
||||||
bool share_result = false) // lanes is intended to be <= 32.
|
|
||||||
{
|
|
||||||
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
|
||||||
int blockSize =
|
|
||||||
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
|
||||||
|
|
||||||
if (blockSize >= 64) {
|
|
||||||
x[tid] = val;
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
|
|
||||||
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
T final;
|
|
||||||
|
|
||||||
if (tid < 32) {
|
|
||||||
if (blockSize >= 64)
|
|
||||||
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
|
|
||||||
else
|
|
||||||
final = val;
|
|
||||||
// __SYNCWARP();
|
|
||||||
|
|
||||||
#pragma unroll
|
|
||||||
for (int i = 16; i >= lanes; i >>= 1)
|
|
||||||
final =
|
|
||||||
fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (share_result) {
|
|
||||||
if (tid < lanes) x[tid] = final; // EpilogueOp
|
|
||||||
// Make sure the smem result is visible to all warps.
|
|
||||||
__syncthreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
return final;
|
|
||||||
}
|
|
35
extensions/csrc/common/mp_type_traits.h
Normal file
35
extensions/csrc/common/mp_type_traits.h
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <ATen/ATen.h>
|
||||||
|
|
||||||
|
#include "micros.h"
|
||||||
|
|
||||||
|
namespace colossalAI {
|
||||||
|
namespace common {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class MPTypeTrait {
|
||||||
|
public:
|
||||||
|
using Type = float;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
class MPTypeTrait<float> {
|
||||||
|
public:
|
||||||
|
using Type = float;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
class MPTypeTrait<at::Half> {
|
||||||
|
public:
|
||||||
|
using Type = float;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
class MPTypeTrait<at::BFloat16> {
|
||||||
|
public:
|
||||||
|
using Type = float;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace common
|
||||||
|
} // namespace colossalAI
|
134
extensions/csrc/common/target.h
Normal file
134
extensions/csrc/common/target.h
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <exception>
|
||||||
|
#include <iostream>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace colossalAI {
|
||||||
|
namespace common {
|
||||||
|
|
||||||
|
class Target {
|
||||||
|
public:
|
||||||
|
enum class OS : int {
|
||||||
|
Unk = -1,
|
||||||
|
Linux,
|
||||||
|
Windows,
|
||||||
|
};
|
||||||
|
enum class Arch : int {
|
||||||
|
Unk = -1,
|
||||||
|
X86,
|
||||||
|
Arm,
|
||||||
|
NVGPU,
|
||||||
|
AMDGPU,
|
||||||
|
Ascend,
|
||||||
|
};
|
||||||
|
enum class BitLen : int {
|
||||||
|
Unk = -1,
|
||||||
|
k32,
|
||||||
|
k64,
|
||||||
|
};
|
||||||
|
|
||||||
|
explicit Target(OS os, Arch arch, BitLen bitlen)
|
||||||
|
: os_(os), arch_(arch), bitlen_(bitlen) {}
|
||||||
|
|
||||||
|
bool defined() const {
|
||||||
|
return (os_ != OS::Unk) && (arch_ != Arch::Unk) && (bitlen_ != BitLen::Unk);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string str() const {
|
||||||
|
std::string s{"OS: "};
|
||||||
|
switch (os_) {
|
||||||
|
case OS::Unk:
|
||||||
|
s += "Unk";
|
||||||
|
break;
|
||||||
|
case OS::Linux:
|
||||||
|
s += "Linux";
|
||||||
|
break;
|
||||||
|
case OS::Windows:
|
||||||
|
s += "Windows";
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("Invalid OS type!");
|
||||||
|
}
|
||||||
|
s += "\t";
|
||||||
|
s += "Arch: ";
|
||||||
|
|
||||||
|
switch (arch_) {
|
||||||
|
case Arch::Unk:
|
||||||
|
s += "Unk";
|
||||||
|
break;
|
||||||
|
case Arch::X86:
|
||||||
|
s += "X86";
|
||||||
|
break;
|
||||||
|
case Arch::Arm:
|
||||||
|
s += "Arm";
|
||||||
|
break;
|
||||||
|
case Arch::NVGPU:
|
||||||
|
s += "NVGPU";
|
||||||
|
break;
|
||||||
|
case Arch::AMDGPU:
|
||||||
|
s += "AMDGPU";
|
||||||
|
break;
|
||||||
|
case Arch::Ascend:
|
||||||
|
s += "Ascend";
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("Invalid Arch type!");
|
||||||
|
}
|
||||||
|
s += "\t";
|
||||||
|
s += "BitLen: ";
|
||||||
|
|
||||||
|
switch (bitlen_) {
|
||||||
|
case BitLen::Unk:
|
||||||
|
s += "Unk";
|
||||||
|
break;
|
||||||
|
case BitLen::k32:
|
||||||
|
s += "k32";
|
||||||
|
break;
|
||||||
|
case BitLen::k64:
|
||||||
|
s += "k64";
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw std::invalid_argument("Invalid target bit length!");
|
||||||
|
}
|
||||||
|
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
OS os() const { return os_; }
|
||||||
|
Arch arch() const { return arch_; }
|
||||||
|
BitLen bitlen() const { return bitlen_; }
|
||||||
|
|
||||||
|
static Target DefaultX86Target();
|
||||||
|
static Target DefaultArmTarget();
|
||||||
|
static Target DefaultRocmTarget();
|
||||||
|
static Target DefaultAscendTarget();
|
||||||
|
|
||||||
|
static Target DefaultCUDATarget() {
|
||||||
|
return Target(OS::Linux, Arch::CUDA, BitLen::k64);
|
||||||
|
}
|
||||||
|
|
||||||
|
friend std::ostream& operator<<(std::ostream& os, const Target& target);
|
||||||
|
friend bool operator==(const Target& lhs, const Target& rhs);
|
||||||
|
friend bool operator!=(const Target& lhs, const Target& rhs);
|
||||||
|
|
||||||
|
private:
|
||||||
|
OS os_{OS::Unk};
|
||||||
|
Arch arch_{Arch::Unk};
|
||||||
|
BitLen bitlen_{BitLen::Unk};
|
||||||
|
};
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& os, const Target& target) {
|
||||||
|
std::cout << target.str() << std::endl;
|
||||||
|
}
|
||||||
|
bool operator==(const Target& lhs, const Target& rhs) {
|
||||||
|
return (lhs.os_ == rhs.os_) && (lhs.arch_ == rhs.arch_) &&
|
||||||
|
(lhs.bitlen_ == rhs.bitlen_);
|
||||||
|
}
|
||||||
|
bool operator!=(const Target& lhs, const Target& rhs) {
|
||||||
|
return (lhs.os_ != rhs.os_) && (lhs.arch_ != rhs.arch_) &&
|
||||||
|
(lhs.bitlen_ != rhs.bitlen_);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace common
|
||||||
|
} // namespace colossalAI
|
98
extensions/csrc/common/vector_copy_utils.h
Normal file
98
extensions/csrc/common/vector_copy_utils.h
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
|
||||||
|
#include <c10/macros/Macros.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
#include <cfloat>
|
||||||
|
|
||||||
|
#include "string"
|
||||||
|
|
||||||
|
template <typename Datatype, int ELEMENTS_PER_LDG>
|
||||||
|
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(
|
||||||
|
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||||
|
*dst = *src;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<c10::BFloat16, 2>(
|
||||||
|
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||||
|
*((float *)dst) = *((float *)src);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(
|
||||||
|
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||||
|
*((float2 *)dst) = *((float2 *)src);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<c10::BFloat16, 8>(
|
||||||
|
c10::BFloat16 *dst, const c10::BFloat16 *src) {
|
||||||
|
*((float4 *)dst) = *((float4 *)src);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst,
|
||||||
|
const c10::Half *src) {
|
||||||
|
*dst = *src;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<c10::Half, 2>(c10::Half *dst,
|
||||||
|
const c10::Half *src) {
|
||||||
|
*((float *)dst) = *((float *)src);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst,
|
||||||
|
const c10::Half *src) {
|
||||||
|
*((float2 *)dst) = *((float2 *)src);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<c10::Half, 8>(c10::Half *dst,
|
||||||
|
const c10::Half *src) {
|
||||||
|
*((float4 *)dst) = *((float4 *)src);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<float, 1>(float *dst, const float *src) {
|
||||||
|
*dst = *src;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<float, 2>(float *dst, const float *src) {
|
||||||
|
*((float2 *)dst) = *((float2 *)src);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<float, 4>(float *dst, const float *src) {
|
||||||
|
*((float4 *)dst) = *((float4 *)src);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__device__ __inline__ void copy_vector<float, 8>(float *dst, const float *src) {
|
||||||
|
// Since the maximum memory alignment length is 128 bits, we choose float4
|
||||||
|
// here.
|
||||||
|
*((float4 *)dst) = *((float4 *)src);
|
||||||
|
*((float4 *)(dst + 4)) = *((float4 *)(src + 4));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
int get_vec_size(const torch::Tensor &tensor) {
|
||||||
|
uint64_t address = reinterpret_cast<uint64_t>(tensor.data_ptr<T>());
|
||||||
|
const int max_aligned_size = 128;
|
||||||
|
const int dtype_size = sizeof(T) * 8;
|
||||||
|
|
||||||
|
const int vec_size = max_aligned_size / sizeof(T) / 8;
|
||||||
|
|
||||||
|
if (address % (dtype_size * 4) == 0) {
|
||||||
|
return std::min(4, vec_size);
|
||||||
|
} else if (address % (dtype_size * 2) == 0) {
|
||||||
|
return std::min(2, vec_size);
|
||||||
|
} else {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
68
extensions/csrc/cuda/activation_kernel.cu
Normal file
68
extensions/csrc/cuda/activation_kernel.cu
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#include "../common/micros.h"
|
||||||
|
#include "../common/mp_type_traits.h"
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__device__ __forceinline__ T silu_kernel(const T& x) {
|
||||||
|
// x * sigmoid(x)
|
||||||
|
using MT = typename colossalAI::common::MPTypeTrait<T>::Type;
|
||||||
|
return static_cast<T>((static_cast<MT>(x)) / (static_cast<MT>(1.0f) + expf(static_cast<MT>(-x))));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
|
||||||
|
__global__ void act_and_mul_kernel(
|
||||||
|
const scalar_t* __restrict__ ins_data,
|
||||||
|
scalar_t* __restrict__ outs_data,
|
||||||
|
const int64_t numel) {
|
||||||
|
using MT = typename colossalAI::common::MPTypeTrait<scalar_t>::Type;
|
||||||
|
|
||||||
|
int64_t idx = static_cast<int64_t>(threadIdx.x) + static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(blockDim.x);
|
||||||
|
const int64_t grid_size = blockDim.x * gridDim.x;
|
||||||
|
if(idx > numel) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for(int64_t i = idx; i < numel; i += grid_size) {
|
||||||
|
scalar_t x = ins_data[i];
|
||||||
|
scalar_t y = ins_data[i+numel];
|
||||||
|
outs_data[i] = static_cast<scalar_t>(static_cast<MT>(ACT_FN(x)) * static_cast<MT>(y));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note(LiuYang):This func is designed for calculation mode like
|
||||||
|
// silu(x[:half_1stdim]) * (x[half_1stdim:])
|
||||||
|
torch::Tensor silu_and_mul(const torch::Tensor& ins)
|
||||||
|
{
|
||||||
|
auto ins_shape = ins.sizes().vec();
|
||||||
|
|
||||||
|
ins_shape[0] = ins_shape[0]/2;
|
||||||
|
if (ins_shape[0] == 1) {
|
||||||
|
ins_shape.erase(ins_shape.begin());
|
||||||
|
}
|
||||||
|
auto outs = torch::zeros(ins_shape,ins.options());
|
||||||
|
auto outs_shape = ins.sizes().vec();
|
||||||
|
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
// Note(Liuyang): numel of ins must be divisible by 2
|
||||||
|
int64_t numel = ((torch::numel(ins)) >> 1);
|
||||||
|
|
||||||
|
// TODO(LiuYang): Maybe we need to implement a function to get launch config
|
||||||
|
dim3 grid((numel+255)/256);
|
||||||
|
dim3 block(256);
|
||||||
|
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
ins.scalar_type(),
|
||||||
|
"silu_and_mul",
|
||||||
|
act_and_mul_kernel<scalar_t,silu_kernel<scalar_t>><<<grid, block, 0, stream>>>(
|
||||||
|
ins.data_ptr<scalar_t>(),
|
||||||
|
outs.data_ptr<scalar_t>(),
|
||||||
|
numel
|
||||||
|
);)
|
||||||
|
|
||||||
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
|
return outs;
|
||||||
|
}
|
@ -1,15 +0,0 @@
|
|||||||
#include <torch/extension.h>
|
|
||||||
|
|
||||||
void decode_kv_cache_memcpy(
|
|
||||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
|
||||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
|
||||||
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
|
|
||||||
torch::Tensor&
|
|
||||||
value_cache, // [num_blocks, num_heads, block_size, head_size]
|
|
||||||
torch::Tensor& sequence_lengths, // [batch_size]
|
|
||||||
torch::Tensor& block_tables); // [batch_size, max_seq_len]
|
|
||||||
|
|
||||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
||||||
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
|
|
||||||
"Copy the GPU memory of kvcache during the decode stage.");
|
|
||||||
}
|
|
@ -1,10 +0,0 @@
|
|||||||
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
|
|
||||||
#ifndef TORCH_CHECK
|
|
||||||
#define TORCH_CHECK AT_CHECK
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#ifdef VERSION_GE_1_3
|
|
||||||
#define DATA_PTR data_ptr
|
|
||||||
#else
|
|
||||||
#define DATA_PTR data
|
|
||||||
#endif
|
|
@ -1,10 +1,10 @@
|
|||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <stdio.h>
|
|
||||||
|
|
||||||
#include "type_shim.h"
|
#include "../common/vector_copy_utils.h"
|
||||||
|
#include "../common/micros.h"
|
||||||
|
|
||||||
template<typename scalar_t>
|
template<typename scalar_t, int VecSize>
|
||||||
__global__ void decode_kv_cache_memcpy_kernel(
|
__global__ void decode_kv_cache_memcpy_kernel(
|
||||||
const scalar_t* __restrict__ key,
|
const scalar_t* __restrict__ key,
|
||||||
const scalar_t* __restrict__ value,
|
const scalar_t* __restrict__ value,
|
||||||
@ -12,79 +12,146 @@ __global__ void decode_kv_cache_memcpy_kernel(
|
|||||||
scalar_t* __restrict__ value_cache,
|
scalar_t* __restrict__ value_cache,
|
||||||
const int* __restrict__ sequence_lengths,
|
const int* __restrict__ sequence_lengths,
|
||||||
const int* __restrict__ block_tables,
|
const int* __restrict__ block_tables,
|
||||||
const int num_heads,
|
const int head_num,
|
||||||
const int head_size,
|
const int head_dim,
|
||||||
const int block_size,
|
const int block_size,
|
||||||
const int key_stride,
|
const int64_t key_stride,
|
||||||
const int value_stride,
|
const int64_t value_stride,
|
||||||
const int block_table_stride
|
const int block_table_stride
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
const int seq_id = blockIdx.x;
|
const int seq_id = blockIdx.x;
|
||||||
const int seq_len = sequence_lengths[seq_id] - 1;
|
const int seq_len = sequence_lengths[seq_id] - 1;
|
||||||
const int seq_id_in_block_table = seq_len / block_size;
|
|
||||||
const int block_offset = seq_len % block_size;
|
const int block_offset = seq_len % block_size;
|
||||||
const int block_id = block_tables[seq_id * block_table_stride + seq_id_in_block_table];
|
const int block_id = block_tables[seq_id * block_table_stride + seq_len / block_size];
|
||||||
const int hidden_size = num_heads * head_size;
|
const int hidden_size = head_num * head_dim;
|
||||||
|
|
||||||
if ( block_id < 0 ) {
|
if ( block_id < 0 ) {
|
||||||
return ;
|
return ;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
|
for (int i = threadIdx.x * VecSize; i < hidden_size; i += blockDim.x * VecSize) {
|
||||||
const int head_id = i / head_size;
|
const int head_id = i / head_dim;
|
||||||
const int head_offset = i % head_size;
|
const int head_offset = i % head_dim;
|
||||||
const int key_src_id = seq_id * key_stride + i;
|
const int64_t key_src_id = seq_id * key_stride + i;
|
||||||
const int value_src_id = seq_id * value_stride + i;
|
const int64_t value_src_id = seq_id * value_stride + i;
|
||||||
const int target_src_id = block_id * hidden_size * block_size
|
const int64_t target_id = block_id * hidden_size * block_size
|
||||||
+ head_id * block_size * head_size
|
+ head_id * block_size * head_dim
|
||||||
+ block_offset * head_size + head_offset;
|
+ block_offset * head_dim + head_offset;
|
||||||
|
|
||||||
key_cache[target_src_id] = key[key_src_id];
|
copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
|
||||||
value_cache[target_src_id] = value[value_src_id];
|
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void decode_kv_cache_memcpy(
|
template<typename scalar_t>
|
||||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
void apply_decode_kv_cache_memcpy(
|
||||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||||
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
|
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||||
torch::Tensor& value_cache, // [num_blocks, num_heads, block_size, head_size]
|
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||||
torch::Tensor& sequence_lengths, // [batch_size]
|
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||||
torch::Tensor& block_tables) // [batch_size, max_seq_len]
|
at::Tensor& sequence_lengths, // [batch_size]
|
||||||
|
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||||
{
|
{
|
||||||
int num_tokens = key.size(0);
|
int num_tokens = key.size(0);
|
||||||
int num_heads = key.size(1);
|
int head_num = key.size(1);
|
||||||
int head_size = key.size(2);
|
int head_dim = key.size(2);
|
||||||
int block_size = key_cache.size(2);
|
int block_size = key_cache.size(2);
|
||||||
|
|
||||||
int key_stride = key.stride(0);
|
int64_t key_stride = key.stride(0);
|
||||||
int value_stride = value.stride(0);
|
int64_t value_stride = value.stride(0);
|
||||||
int block_table_stride = block_tables.stride(0);
|
int block_table_stride = block_tables.stride(0);
|
||||||
|
|
||||||
|
int vec_size = get_vec_size<scalar_t>(key);
|
||||||
|
|
||||||
|
if (head_dim % vec_size != 0) {
|
||||||
|
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
|
||||||
|
vec_size = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int thread_nums = head_num * head_dim / vec_size;
|
||||||
|
|
||||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
dim3 grid(num_tokens);
|
dim3 grid(num_tokens);
|
||||||
dim3 block(std::min(num_heads * head_size, 512));
|
dim3 block(std::min(thread_nums, 512));
|
||||||
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
|
||||||
key.scalar_type(),
|
switch (vec_size) {
|
||||||
"decode_kv_cache_memcpy",
|
case 1:
|
||||||
decode_kv_cache_memcpy_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
decode_kv_cache_memcpy_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
|
||||||
key.data_ptr<scalar_t>(),
|
key.data_ptr<scalar_t>(),
|
||||||
value.data_ptr<scalar_t>(),
|
value.data_ptr<scalar_t>(),
|
||||||
key_cache.data_ptr<scalar_t>(),
|
key_cache.data_ptr<scalar_t>(),
|
||||||
value_cache.data_ptr<scalar_t>(),
|
value_cache.data_ptr<scalar_t>(),
|
||||||
sequence_lengths.data_ptr<int>(),
|
sequence_lengths.data_ptr<int>(),
|
||||||
block_tables.data_ptr<int>(),
|
block_tables.data_ptr<int>(),
|
||||||
num_heads,
|
head_num,
|
||||||
head_size,
|
head_dim,
|
||||||
block_size,
|
block_size,
|
||||||
key_stride,
|
key_stride,
|
||||||
value_stride,
|
value_stride,
|
||||||
block_table_stride
|
block_table_stride
|
||||||
);)
|
);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
decode_kv_cache_memcpy_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
value.data_ptr<scalar_t>(),
|
||||||
|
key_cache.data_ptr<scalar_t>(),
|
||||||
|
value_cache.data_ptr<scalar_t>(),
|
||||||
|
sequence_lengths.data_ptr<int>(),
|
||||||
|
block_tables.data_ptr<int>(),
|
||||||
|
head_num,
|
||||||
|
head_dim,
|
||||||
|
block_size,
|
||||||
|
key_stride,
|
||||||
|
value_stride,
|
||||||
|
block_table_stride
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
decode_kv_cache_memcpy_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
value.data_ptr<scalar_t>(),
|
||||||
|
key_cache.data_ptr<scalar_t>(),
|
||||||
|
value_cache.data_ptr<scalar_t>(),
|
||||||
|
sequence_lengths.data_ptr<int>(),
|
||||||
|
block_tables.data_ptr<int>(),
|
||||||
|
head_num,
|
||||||
|
head_dim,
|
||||||
|
block_size,
|
||||||
|
key_stride,
|
||||||
|
value_stride,
|
||||||
|
block_table_stride
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
AT_ERROR("Unsupported vectorized size ", vec_size);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
AT_CUDA_CHECK(cudaGetLastError());
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void decode_kv_cache_memcpy(
|
||||||
|
at::Tensor& key, // [num_tokens, head_num, head_dim]
|
||||||
|
at::Tensor& value, // [num_tokens, head_num, head_dim]
|
||||||
|
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||||
|
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||||
|
at::Tensor& sequence_lengths, // [batch_size]
|
||||||
|
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||||
|
{
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
key.scalar_type(),
|
||||||
|
"decode_kv_cache_memcpy",
|
||||||
|
apply_decode_kv_cache_memcpy<scalar_t>(
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
sequence_lengths,
|
||||||
|
block_tables
|
||||||
|
);)
|
||||||
|
}
|
||||||
|
472
extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu
Normal file
472
extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu
Normal file
@ -0,0 +1,472 @@
|
|||||||
|
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
#include "../common/vector_copy_utils.h"
|
||||||
|
#include "../common/micros.h"
|
||||||
|
|
||||||
|
template <typename scalar_t, int VecSize>
|
||||||
|
__device__ void apply_emb_rotary_compute(
|
||||||
|
scalar_t* __restrict__ src, const scalar_t* __restrict__ cos_ptr,
|
||||||
|
const scalar_t* __restrict__ sin_ptr, const int64_t stride,
|
||||||
|
const int token_id, const int shard_block_size, const int half_head_dim,
|
||||||
|
const int head_num, const int head_dim) {
|
||||||
|
scalar_t x[VecSize];
|
||||||
|
scalar_t y[VecSize];
|
||||||
|
scalar_t out_x[VecSize];
|
||||||
|
scalar_t out_y[VecSize];
|
||||||
|
|
||||||
|
for (int i = threadIdx.x * VecSize; i < head_num * half_head_dim;
|
||||||
|
i += blockDim.x * VecSize) {
|
||||||
|
const int head_offset = i % half_head_dim;
|
||||||
|
const int shard_offset =
|
||||||
|
(head_offset / shard_block_size) * shard_block_size +
|
||||||
|
(head_offset % shard_block_size) / VecSize;
|
||||||
|
const int64_t addr_offset =
|
||||||
|
token_id * stride + (i / half_head_dim) * head_dim + head_offset;
|
||||||
|
|
||||||
|
copy_vector<scalar_t, VecSize>(x, src + addr_offset);
|
||||||
|
copy_vector<scalar_t, VecSize>(y, src + addr_offset + half_head_dim);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < VecSize; j++) {
|
||||||
|
out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] -
|
||||||
|
y[j] * sin_ptr[j * 32 + shard_offset];
|
||||||
|
out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] +
|
||||||
|
x[j] * sin_ptr[j * 32 + shard_offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
copy_vector<scalar_t, VecSize>(src + addr_offset, out_x);
|
||||||
|
copy_vector<scalar_t, VecSize>(src + addr_offset + half_head_dim, out_y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t, int VecSize>
|
||||||
|
__device__ void apply_kv_memcopy(
|
||||||
|
scalar_t* __restrict__ src, scalar_t* __restrict__ cache,
|
||||||
|
const int64_t stride, const int token_id, const int block_id,
|
||||||
|
const int hidden_size, const int block_size, const int block_offset,
|
||||||
|
const int head_dim, const int half_head_dim) {
|
||||||
|
for (int i = threadIdx.x * VecSize; i < hidden_size / 2;
|
||||||
|
i += blockDim.x * VecSize) {
|
||||||
|
const int head_id = i / half_head_dim;
|
||||||
|
const int head_offset = i % half_head_dim;
|
||||||
|
const int64_t src_id = token_id * stride + head_id * head_dim + head_offset;
|
||||||
|
const int64_t target_id = block_id * hidden_size * block_size +
|
||||||
|
head_id * block_size * head_dim +
|
||||||
|
block_offset * head_dim + head_offset;
|
||||||
|
|
||||||
|
copy_vector<scalar_t, VecSize>(cache + target_id, src + src_id);
|
||||||
|
copy_vector<scalar_t, VecSize>(cache + target_id + half_head_dim,
|
||||||
|
src + src_id + half_head_dim);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t, int VecSize>
|
||||||
|
__device__ void cos_sin_memory_access(
|
||||||
|
const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin,
|
||||||
|
scalar_t* cos_ptr, scalar_t* sin_ptr, const int token_id,
|
||||||
|
const int shard_block_size, const int cos_stride, const int sin_stride,
|
||||||
|
const int half_head_dim) {
|
||||||
|
for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) {
|
||||||
|
// We assume that the value of head_dim is less than 128*128.
|
||||||
|
const int shard_offset = (i % shard_block_size) / VecSize;
|
||||||
|
const int shard_head =
|
||||||
|
(i / shard_block_size) * shard_block_size + i % VecSize * 32;
|
||||||
|
cos_ptr[shard_head + shard_offset] = cos[token_id * cos_stride + i];
|
||||||
|
sin_ptr[shard_head + shard_offset] = sin[token_id * sin_stride + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t, int VecSize>
|
||||||
|
__device__ void apply_k_rotary_emb_compute(
|
||||||
|
scalar_t* __restrict__ key, scalar_t* __restrict__ value,
|
||||||
|
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
|
||||||
|
const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr,
|
||||||
|
const int* __restrict__ sequence_lengths,
|
||||||
|
const int* __restrict__ block_tables, const int64_t key_stride,
|
||||||
|
const int64_t value_stride, const int token_id,
|
||||||
|
const int block_table_stride, const int head_num, const int head_dim,
|
||||||
|
const int kv_head_num, const int block_size, const int half_head_dim,
|
||||||
|
const int shard_block_size) {
|
||||||
|
const int seq_len = sequence_lengths[token_id] - 1;
|
||||||
|
const int block_offset = seq_len % block_size;
|
||||||
|
const int block_id =
|
||||||
|
block_tables[token_id * block_table_stride + seq_len / block_size];
|
||||||
|
|
||||||
|
if (block_id < 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
scalar_t x[VecSize];
|
||||||
|
scalar_t y[VecSize];
|
||||||
|
scalar_t out_x[VecSize];
|
||||||
|
scalar_t out_y[VecSize];
|
||||||
|
|
||||||
|
for (int i = threadIdx.x * VecSize; i < kv_head_num * half_head_dim;
|
||||||
|
i += blockDim.x * VecSize) {
|
||||||
|
const int head_offset = i % half_head_dim;
|
||||||
|
const int shard_offset =
|
||||||
|
(head_offset / shard_block_size) * shard_block_size +
|
||||||
|
(head_offset % shard_block_size) / VecSize;
|
||||||
|
const int64_t addr_offset =
|
||||||
|
token_id * key_stride + (i / half_head_dim) * head_dim + head_offset;
|
||||||
|
const int64_t target_id = block_id * head_num * head_dim * block_size +
|
||||||
|
(i / half_head_dim) * block_size * head_dim +
|
||||||
|
block_offset * head_dim + head_offset;
|
||||||
|
|
||||||
|
copy_vector<scalar_t, VecSize>(x, key + addr_offset);
|
||||||
|
copy_vector<scalar_t, VecSize>(y, key + addr_offset + half_head_dim);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < VecSize; j++) {
|
||||||
|
out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] -
|
||||||
|
y[j] * sin_ptr[j * 32 + shard_offset];
|
||||||
|
out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] +
|
||||||
|
x[j] * sin_ptr[j * 32 + shard_offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
copy_vector<scalar_t, VecSize>(key_cache + target_id, out_x);
|
||||||
|
copy_vector<scalar_t, VecSize>(key_cache + target_id + half_head_dim,
|
||||||
|
out_y);
|
||||||
|
}
|
||||||
|
|
||||||
|
// apply value memcopy
|
||||||
|
apply_kv_memcopy<scalar_t, VecSize>(
|
||||||
|
value, value_cache, value_stride, token_id, block_id, head_num * head_dim,
|
||||||
|
block_size, block_offset, head_dim, half_head_dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename scalar_t, int VecSize>
|
||||||
|
__global__ void rotary_embedding_and_cache_copy_kernel(
|
||||||
|
scalar_t* __restrict__ query,
|
||||||
|
scalar_t* __restrict__ key,
|
||||||
|
scalar_t* __restrict__ value,
|
||||||
|
const scalar_t* __restrict__ cos,
|
||||||
|
const scalar_t* __restrict__ sin,
|
||||||
|
scalar_t* __restrict__ key_cache,
|
||||||
|
scalar_t* __restrict__ value_cache,
|
||||||
|
const int* __restrict__ sequence_lengths,
|
||||||
|
const int* __restrict__ block_tables,
|
||||||
|
const int64_t query_stride,
|
||||||
|
const int64_t key_stride,
|
||||||
|
const int64_t value_stride,
|
||||||
|
const int64_t half_shard_element_num,
|
||||||
|
const int cos_stride,
|
||||||
|
const int sin_stride,
|
||||||
|
const int block_table_stride,
|
||||||
|
const int head_num,
|
||||||
|
const int head_dim,
|
||||||
|
const int kv_head_num,
|
||||||
|
const int block_size
|
||||||
|
) {
|
||||||
|
|
||||||
|
const int token_id = blockIdx.x;
|
||||||
|
const int half_head_dim = head_dim / 2;
|
||||||
|
const int shard_block_size = VecSize * 32;
|
||||||
|
|
||||||
|
extern __shared__ char shard_ptr[];
|
||||||
|
|
||||||
|
scalar_t *cos_ptr = (scalar_t*)shard_ptr;
|
||||||
|
scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
|
||||||
|
|
||||||
|
// apply cos_sin memcopy
|
||||||
|
cos_sin_memory_access<scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
//compute query
|
||||||
|
apply_emb_rotary_compute<scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
|
||||||
|
|
||||||
|
//compute key and copy kv
|
||||||
|
apply_k_rotary_emb_compute<scalar_t, VecSize>(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename scalar_t, int VecSize>
|
||||||
|
__global__ void rotary_embedding_kernel(
|
||||||
|
scalar_t* __restrict__ query,
|
||||||
|
scalar_t* __restrict__ key,
|
||||||
|
const scalar_t* __restrict__ cos,
|
||||||
|
const scalar_t* __restrict__ sin,
|
||||||
|
const int64_t query_stride,
|
||||||
|
const int64_t key_stride,
|
||||||
|
const int64_t half_shard_element_num,
|
||||||
|
const int cos_stride,
|
||||||
|
const int sin_stride,
|
||||||
|
const int head_num,
|
||||||
|
const int head_dim,
|
||||||
|
const int kv_head_num
|
||||||
|
) {
|
||||||
|
const int token_id = blockIdx.x;
|
||||||
|
const int half_head_dim = head_dim / 2;
|
||||||
|
const int shard_block_size = VecSize * 32;
|
||||||
|
|
||||||
|
extern __shared__ char shard_ptr[];
|
||||||
|
|
||||||
|
scalar_t *cos_ptr = (scalar_t*)shard_ptr;
|
||||||
|
scalar_t *sin_ptr = cos_ptr + half_shard_element_num;
|
||||||
|
|
||||||
|
// apply cos_sin memcopy
|
||||||
|
cos_sin_memory_access<scalar_t, VecSize>(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim);
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
//compute query
|
||||||
|
apply_emb_rotary_compute<scalar_t, VecSize>(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim);
|
||||||
|
|
||||||
|
//compute key
|
||||||
|
apply_emb_rotary_compute<scalar_t, VecSize>(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename scalar_t>
|
||||||
|
void apply_rotary_embedding_and_cache_copy(
|
||||||
|
at::Tensor& query, // [num_tokens, head_num, head_dim]
|
||||||
|
at::Tensor& key, // [num_tokens, kv_head_num, head_dim]
|
||||||
|
at::Tensor& value, // [num_tokens, kv_head_num, head_dim]
|
||||||
|
at::Tensor& cos, // [num_tokens, head_dim]
|
||||||
|
at::Tensor& sin, // [num_tokens, head_dim]
|
||||||
|
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||||
|
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||||
|
at::Tensor& sequence_lengths, // [batch_size]
|
||||||
|
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||||
|
{
|
||||||
|
int num_tokens = query.size(0);
|
||||||
|
int head_num = query.size(1);
|
||||||
|
int head_dim = query.size(2);
|
||||||
|
int kv_head_num = key.size(1);
|
||||||
|
int block_size = key_cache.size(2);
|
||||||
|
|
||||||
|
int64_t query_stride = query.stride(0);
|
||||||
|
int64_t key_stride = key.stride(0);
|
||||||
|
int64_t value_stride = value.stride(0);
|
||||||
|
int cos_stride = cos.stride(0);
|
||||||
|
int sin_stride = sin.stride(0);
|
||||||
|
int block_table_stride = block_tables.stride(0);
|
||||||
|
|
||||||
|
int vec_size = get_vec_size<scalar_t>(query);
|
||||||
|
|
||||||
|
if ((head_dim / 2) % vec_size != 0) {
|
||||||
|
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
|
||||||
|
vec_size = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
int thread_nums = head_num * head_dim / vec_size / 2;
|
||||||
|
const int shard_block_size = vec_size * 32 * 2;
|
||||||
|
|
||||||
|
dim3 grid(num_tokens);
|
||||||
|
dim3 block(std::min(thread_nums, 512));
|
||||||
|
int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ;
|
||||||
|
|
||||||
|
switch (vec_size) {
|
||||||
|
case 1:
|
||||||
|
rotary_embedding_and_cache_copy_kernel<scalar_t, 1><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
value.data_ptr<scalar_t>(),
|
||||||
|
cos.data_ptr<scalar_t>(),
|
||||||
|
sin.data_ptr<scalar_t>(),
|
||||||
|
key_cache.data_ptr<scalar_t>(),
|
||||||
|
value_cache.data_ptr<scalar_t>(),
|
||||||
|
sequence_lengths.data_ptr<int>(),
|
||||||
|
block_tables.data_ptr<int>(),
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
|
value_stride,
|
||||||
|
shard_element_num / 2,
|
||||||
|
cos_stride,
|
||||||
|
sin_stride,
|
||||||
|
block_table_stride,
|
||||||
|
head_num,
|
||||||
|
head_dim,
|
||||||
|
kv_head_num,
|
||||||
|
block_size
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
rotary_embedding_and_cache_copy_kernel<scalar_t, 2><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
value.data_ptr<scalar_t>(),
|
||||||
|
cos.data_ptr<scalar_t>(),
|
||||||
|
sin.data_ptr<scalar_t>(),
|
||||||
|
key_cache.data_ptr<scalar_t>(),
|
||||||
|
value_cache.data_ptr<scalar_t>(),
|
||||||
|
sequence_lengths.data_ptr<int>(),
|
||||||
|
block_tables.data_ptr<int>(),
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
|
value_stride,
|
||||||
|
shard_element_num / 2,
|
||||||
|
cos_stride,
|
||||||
|
sin_stride,
|
||||||
|
block_table_stride,
|
||||||
|
head_num,
|
||||||
|
head_dim,
|
||||||
|
kv_head_num,
|
||||||
|
block_size
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
rotary_embedding_and_cache_copy_kernel<scalar_t, 4><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
value.data_ptr<scalar_t>(),
|
||||||
|
cos.data_ptr<scalar_t>(),
|
||||||
|
sin.data_ptr<scalar_t>(),
|
||||||
|
key_cache.data_ptr<scalar_t>(),
|
||||||
|
value_cache.data_ptr<scalar_t>(),
|
||||||
|
sequence_lengths.data_ptr<int>(),
|
||||||
|
block_tables.data_ptr<int>(),
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
|
value_stride,
|
||||||
|
shard_element_num / 2,
|
||||||
|
cos_stride,
|
||||||
|
sin_stride,
|
||||||
|
block_table_stride,
|
||||||
|
head_num,
|
||||||
|
head_dim,
|
||||||
|
kv_head_num,
|
||||||
|
block_size
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
AT_ERROR("Unsupported vectorized size ", vec_size);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename scalar_t>
|
||||||
|
void apply_rotary_embedding(
|
||||||
|
at::Tensor& query, // [total_tokens, head_num, head_dim]
|
||||||
|
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
|
||||||
|
at::Tensor& cos, // [total_tokens, head_dim]
|
||||||
|
at::Tensor& sin // [total_tokens, head_dim]
|
||||||
|
){
|
||||||
|
int num_tokens = query.size(0);
|
||||||
|
int head_num = query.size(1);
|
||||||
|
int head_dim = query.size(2);
|
||||||
|
int kv_head_num = key.size(1);
|
||||||
|
|
||||||
|
int query_stride = query.stride(0);
|
||||||
|
int key_stride = key.stride(0);
|
||||||
|
int cos_stride = cos.stride(0);
|
||||||
|
int sin_stride = sin.stride(0);
|
||||||
|
|
||||||
|
int vec_size = get_vec_size<scalar_t>(query);
|
||||||
|
|
||||||
|
if ((head_dim / 2) % vec_size != 0) {
|
||||||
|
// Disable vectorized loading optimization when head_dim is not divisible by VecSize.
|
||||||
|
vec_size = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
int thread_nums = head_num * head_dim / vec_size / 2;
|
||||||
|
const int shard_block_size = vec_size * 32 * 2;
|
||||||
|
|
||||||
|
dim3 grid(num_tokens);
|
||||||
|
dim3 block(std::min(thread_nums, 512));
|
||||||
|
int64_t shard_element_num = ((head_dim + shard_block_size - 1) / shard_block_size) * shard_block_size ;
|
||||||
|
|
||||||
|
switch (vec_size) {
|
||||||
|
case 1:
|
||||||
|
rotary_embedding_kernel<scalar_t, 1><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
cos.data_ptr<scalar_t>(),
|
||||||
|
sin.data_ptr<scalar_t>(),
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
|
shard_element_num / 2,
|
||||||
|
cos_stride,
|
||||||
|
sin_stride,
|
||||||
|
head_num,
|
||||||
|
head_dim,
|
||||||
|
kv_head_num
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
rotary_embedding_kernel<scalar_t, 2><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
cos.data_ptr<scalar_t>(),
|
||||||
|
sin.data_ptr<scalar_t>(),
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
|
shard_element_num / 2,
|
||||||
|
cos_stride,
|
||||||
|
sin_stride,
|
||||||
|
head_num,
|
||||||
|
head_dim,
|
||||||
|
kv_head_num
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
rotary_embedding_kernel<scalar_t, 4><<<grid, block, shard_element_num * sizeof(scalar_t), stream>>>(
|
||||||
|
query.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
cos.data_ptr<scalar_t>(),
|
||||||
|
sin.data_ptr<scalar_t>(),
|
||||||
|
query_stride,
|
||||||
|
key_stride,
|
||||||
|
shard_element_num / 2,
|
||||||
|
cos_stride,
|
||||||
|
sin_stride,
|
||||||
|
head_num,
|
||||||
|
head_dim,
|
||||||
|
kv_head_num
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
AT_ERROR("Unsupported vectorized size ", vec_size);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
AT_CUDA_CHECK(cudaGetLastError());
|
||||||
|
}
|
||||||
|
|
||||||
|
void rotary_embedding_and_cache_copy(
|
||||||
|
at::Tensor& query, // [num_tokens, head_num, head_dim]
|
||||||
|
at::Tensor& key, // [num_tokens, kv_head_num, head_dim]
|
||||||
|
at::Tensor& value, // [num_tokens, kv_head_num, head_dim]
|
||||||
|
at::Tensor& cos, // [num_tokens, head_dim]
|
||||||
|
at::Tensor& sin, // [num_tokens, head_dim]
|
||||||
|
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||||
|
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
|
||||||
|
at::Tensor& sequence_lengths, // [batch_size]
|
||||||
|
at::Tensor& block_tables) // [batch_size, max_seq_len]
|
||||||
|
{
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
query.scalar_type(),
|
||||||
|
"rotary_embedding_and_cache_copy",
|
||||||
|
apply_rotary_embedding_and_cache_copy<scalar_t>(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
sequence_lengths,
|
||||||
|
block_tables
|
||||||
|
);)
|
||||||
|
}
|
||||||
|
|
||||||
|
void rotary_embedding(
|
||||||
|
at::Tensor& query, // [total_tokens, head_num, head_dim]
|
||||||
|
at::Tensor& key, // [total_tokens, kv_head_num, head_dim]
|
||||||
|
at::Tensor& cos, // [total_tokens, head_dim]
|
||||||
|
at::Tensor& sin // [total_tokens, head_dim]
|
||||||
|
){
|
||||||
|
DISPATCH_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
query.scalar_type(),
|
||||||
|
"rotary_embedding",
|
||||||
|
apply_rotary_embedding<scalar_t>(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
cos,
|
||||||
|
sin
|
||||||
|
);)
|
||||||
|
}
|
@ -310,3 +310,90 @@ __inline__ __device__ void blockReduce<ReduceType::kMax, 4>(float *pval) {
|
|||||||
}
|
}
|
||||||
warpReduce<ReduceType::kMax, num>(pval);
|
warpReduce<ReduceType::kMax, num>(pval);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__device__ __forceinline__ T reduce_block_into_lanes(
|
||||||
|
T *x, T val, int lanes = 1,
|
||||||
|
bool share_result = false) // lanes is intended to be <= 32.
|
||||||
|
{
|
||||||
|
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
||||||
|
int blockSize =
|
||||||
|
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
||||||
|
|
||||||
|
if (blockSize >= 64) {
|
||||||
|
x[tid] = val;
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
|
||||||
|
if (tid < i) x[tid] = x[tid] + x[tid + i];
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
T final;
|
||||||
|
|
||||||
|
if (tid < 32) {
|
||||||
|
if (blockSize >= 64)
|
||||||
|
final = x[tid] + x[tid + 32];
|
||||||
|
else
|
||||||
|
final = val;
|
||||||
|
// __SYNCWARP();
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 16; i >= lanes; i >>= 1)
|
||||||
|
final = final + __shfl_down_sync(0xffffffff, final, i);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (share_result) {
|
||||||
|
if (tid < lanes) x[tid] = final; // EpilogueOp
|
||||||
|
// Make sure the smem result is visible to all warps.
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
return final;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
|
||||||
|
T *x, T val, int lanes = 1,
|
||||||
|
bool share_result = false) // lanes is intended to be <= 32.
|
||||||
|
{
|
||||||
|
int tid = threadIdx.x + threadIdx.y * blockDim.x;
|
||||||
|
int blockSize =
|
||||||
|
blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.
|
||||||
|
|
||||||
|
if (blockSize >= 64) {
|
||||||
|
x[tid] = val;
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
|
||||||
|
if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
T final;
|
||||||
|
|
||||||
|
if (tid < 32) {
|
||||||
|
if (blockSize >= 64)
|
||||||
|
final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
|
||||||
|
else
|
||||||
|
final = val;
|
||||||
|
// __SYNCWARP();
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 16; i >= lanes; i >>= 1)
|
||||||
|
final =
|
||||||
|
fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (share_result) {
|
||||||
|
if (tid < lanes) x[tid] = final; // EpilogueOp
|
||||||
|
// Make sure the smem result is visible to all warps.
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
return final;
|
||||||
|
}
|
||||||
|
@ -9,7 +9,7 @@
|
|||||||
#include "ATen/AccumulateType.h"
|
#include "ATen/AccumulateType.h"
|
||||||
#include "ATen/cuda/CUDAContext.h"
|
#include "ATen/cuda/CUDAContext.h"
|
||||||
#include "ATen/cuda/DeviceUtils.cuh"
|
#include "ATen/cuda/DeviceUtils.cuh"
|
||||||
#include "type_shim.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
template <typename U>
|
template <typename U>
|
||||||
__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) {
|
__device__ void cuWelfordOnlineSum(const U curr, U& mu, U& sigma2, U& count) {
|
@ -15,7 +15,7 @@
|
|||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
#include "multi_tensor_apply.cuh"
|
#include "multi_tensor_apply.cuh"
|
||||||
#include "type_shim.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
#define BLOCK_SIZE 512
|
#define BLOCK_SIZE 512
|
||||||
#define ILP 4
|
#define ILP 4
|
@ -12,7 +12,7 @@
|
|||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
#include "compat.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
// #include <iostream>
|
// #include <iostream>
|
||||||
|
|
||||||
|
@ -11,7 +11,8 @@
|
|||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
#include "multi_tensor_apply.cuh"
|
#include "multi_tensor_apply.cuh"
|
||||||
#include "type_shim.h"
|
#include "../common/micros.h"
|
||||||
|
#include "include/block_reduce.h"
|
||||||
|
|
||||||
#define BLOCK_SIZE 512
|
#define BLOCK_SIZE 512
|
||||||
#define ILP 4
|
#define ILP 4
|
||||||
|
@ -10,7 +10,7 @@
|
|||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
|
|
||||||
#include "multi_tensor_apply.cuh"
|
#include "multi_tensor_apply.cuh"
|
||||||
#include "type_shim.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
#define BLOCK_SIZE 512
|
#define BLOCK_SIZE 512
|
||||||
#define ILP 4
|
#define ILP 4
|
@ -10,7 +10,7 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
|
|
||||||
#include "multi_tensor_apply.cuh"
|
#include "multi_tensor_apply.cuh"
|
||||||
#include "type_shim.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
#define BLOCK_SIZE 512
|
#define BLOCK_SIZE 512
|
||||||
#define ILP 4
|
#define ILP 4
|
||||||
|
@ -7,7 +7,7 @@
|
|||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
#include <cuda_runtime.h>
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
#include "compat.h"
|
#include "../common/micros.h"
|
||||||
#include "multi_tensor_apply.cuh"
|
#include "multi_tensor_apply.cuh"
|
||||||
|
|
||||||
#define BLOCK_SIZE 512
|
#define BLOCK_SIZE 512
|
||||||
|
59
extensions/csrc/cuda/pybind/inference.cpp
Normal file
59
extensions/csrc/cuda/pybind/inference.cpp
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
void decode_kv_cache_memcpy(
|
||||||
|
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||||
|
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||||
|
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_size]
|
||||||
|
torch::Tensor&
|
||||||
|
value_cache, // [num_blocks, num_heads, block_size, head_size]
|
||||||
|
torch::Tensor& sequence_lengths, // [batch_size]
|
||||||
|
torch::Tensor& block_tables); // [batch_size, max_seq_len]
|
||||||
|
|
||||||
|
void rotary_embedding(
|
||||||
|
torch::Tensor& query, // [total_tokens, head_num, head_dim]
|
||||||
|
torch::Tensor& key, // [total_tokens, kv_head_num, head_dim]
|
||||||
|
torch::Tensor& cos, // [total_tokens, head_dim]
|
||||||
|
torch::Tensor& sin); // [total_tokens, head_dim]
|
||||||
|
|
||||||
|
void rotary_embedding_and_cache_copy(
|
||||||
|
torch::Tensor& query, // [num_tokens, head_num, head_dim]
|
||||||
|
torch::Tensor& key, // [num_tokens, kv_head_num, head_dim]
|
||||||
|
torch::Tensor& value, // [num_tokens, num_heads, head_dim]
|
||||||
|
torch::Tensor& cos, // [num_tokens, head_dim]
|
||||||
|
torch::Tensor& sin, // [num_tokens, head_dim]
|
||||||
|
torch::Tensor& key_cache, // [num_blocks, num_heads, block_size, head_dim]
|
||||||
|
torch::Tensor&
|
||||||
|
value_cache, // [num_blocks, num_heads, block_size, head_dim]
|
||||||
|
torch::Tensor& sequence_lengths, // [batch_size]
|
||||||
|
torch::Tensor& block_tables); // [batch_size, max_seq_len]
|
||||||
|
torch::Tensor silu_and_mul(const torch::Tensor& ins);
|
||||||
|
|
||||||
|
void rms_layernorm(torch::Tensor& out, // [..., hidden_size]
|
||||||
|
torch::Tensor& input, // [..., hidden_size]
|
||||||
|
torch::Tensor& weight, // [hidden_size]
|
||||||
|
float epsilon);
|
||||||
|
|
||||||
|
void fused_add_rms_layernorm(torch::Tensor& input, // [..., hidden_size]
|
||||||
|
torch::Tensor& residual, // [..., hidden_size]
|
||||||
|
torch::Tensor& weight, // [hidden_size]
|
||||||
|
float epsilon);
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
|
||||||
|
"Copy the GPU memory of kvcache during the decode stage.");
|
||||||
|
|
||||||
|
m.def(
|
||||||
|
"rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy,
|
||||||
|
"performing Rotary Embedding-related calculations and KVCache Memcopy.");
|
||||||
|
|
||||||
|
m.def("rotary_embedding", &rotary_embedding,
|
||||||
|
"performing Rotary Embedding-related calculations.");
|
||||||
|
|
||||||
|
m.def("silu_and_mul", &silu_and_mul, "Silu with a following multiply");
|
||||||
|
|
||||||
|
m.def("rms_layernorm", &rms_layernorm,
|
||||||
|
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
|
||||||
|
|
||||||
|
m.def("fused_add_rms_layernorm", &fused_add_rms_layernorm,
|
||||||
|
"In-place fused Add and RMS Normalization.");
|
||||||
|
}
|
@ -7,7 +7,7 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "compat.h"
|
#include "../../common/micros.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
412
extensions/csrc/cuda/rms_layernorm_kernel.cu
Normal file
412
extensions/csrc/cuda/rms_layernorm_kernel.cu
Normal file
@ -0,0 +1,412 @@
|
|||||||
|
/*This code from VLLM:
|
||||||
|
* https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu
|
||||||
|
* with minor changes. */
|
||||||
|
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
|
||||||
|
#include "block_reduce.h"
|
||||||
|
#include "../common/micros.h"
|
||||||
|
#include "../common/cuda_type_utils.h"
|
||||||
|
|
||||||
|
#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \
|
||||||
|
if (DATA_SIZE == 2) { \
|
||||||
|
switch (TYPE) { \
|
||||||
|
case at::ScalarType::Half: { \
|
||||||
|
using scalar_t = at::Half; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
case at::ScalarType::BFloat16: { \
|
||||||
|
using scalar_t = at::BFloat16; \
|
||||||
|
__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
default: \
|
||||||
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||||
|
} \
|
||||||
|
} else { \
|
||||||
|
switch (TYPE) { \
|
||||||
|
case at::ScalarType::Float: { \
|
||||||
|
using scalar_t = float; \
|
||||||
|
general_##__VA_ARGS__; \
|
||||||
|
break; \
|
||||||
|
} \
|
||||||
|
default: \
|
||||||
|
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
|
||||||
|
// optimized for half and bf16
|
||||||
|
template<typename scalar_t, int unroll_factor>
|
||||||
|
__global__ void rms_layernorm_kernel(
|
||||||
|
scalar_t* __restrict__ out, // [..., hidden_size]
|
||||||
|
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||||
|
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||||
|
const float epsilon,
|
||||||
|
const int num_tokens,
|
||||||
|
const int hidden_size) {
|
||||||
|
using scalar2_t = typename TypeConverter<scalar_t>::Type;
|
||||||
|
__shared__ float s_variance;
|
||||||
|
|
||||||
|
/*
|
||||||
|
* since the open-sourced LLM's hidden dimensions mainly range from
|
||||||
|
* 4096 (LLAMA-7B) to 8192 (LLAMA-65B), we thus set the supported
|
||||||
|
* hidden dimension limit to 8192, and each thread's capacity
|
||||||
|
* for caching input tensors to 8 (8192 = 8 * 1024) which
|
||||||
|
* will cause problems for extremely large models, such as
|
||||||
|
* Megatron-Turing NLG 530B with hidden dimensions up to 20480
|
||||||
|
*/
|
||||||
|
scalar2_t x_local[4];
|
||||||
|
|
||||||
|
scalar2_t* out_ptr = (scalar2_t*)out;
|
||||||
|
const scalar2_t* input_ptr = (scalar2_t*)input;
|
||||||
|
const scalar2_t* weight_ptr = (const scalar2_t*)weight;
|
||||||
|
|
||||||
|
float variance = 0.0f;
|
||||||
|
int row_offset = blockIdx.x * hidden_size / 2;
|
||||||
|
|
||||||
|
#pragma unroll unroll_factor
|
||||||
|
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
|
||||||
|
int id = row_offset + idx;
|
||||||
|
x_local[cnt] = input_ptr[id];
|
||||||
|
float v1 = cuda_cast<float>(x_local[cnt].x);
|
||||||
|
float v2 = cuda_cast<float>(x_local[cnt].y);
|
||||||
|
variance += v1 * v1 + v2 * v2;
|
||||||
|
}
|
||||||
|
variance = blockReduceSum<float>(variance);
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
|
||||||
|
#pragma unroll unroll_factor
|
||||||
|
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
|
||||||
|
int id = row_offset + idx;
|
||||||
|
out_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename scalar_t, int unroll_factor>
|
||||||
|
__global__ void general_rms_layernorm_kernel(
|
||||||
|
scalar_t* __restrict__ out, // [..., hidden_size]
|
||||||
|
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||||
|
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||||
|
const float epsilon,
|
||||||
|
const int num_tokens,
|
||||||
|
const int hidden_size) {
|
||||||
|
__shared__ float s_variance;
|
||||||
|
float variance = 0.0f;
|
||||||
|
float x_local[8];
|
||||||
|
|
||||||
|
int row_offset = blockIdx.x * hidden_size;
|
||||||
|
|
||||||
|
#pragma unroll unroll_factor
|
||||||
|
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||||
|
int id = row_offset + idx;
|
||||||
|
x_local[cnt] = (float) input[id];
|
||||||
|
variance += x_local[cnt] * x_local[cnt];
|
||||||
|
}
|
||||||
|
variance = blockReduceSum<float>(variance);
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
#pragma unroll unroll_factor
|
||||||
|
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||||
|
int id = row_offset + idx;
|
||||||
|
out[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// optimized for half and bf16
|
||||||
|
template<typename scalar_t, int unroll_factor>
|
||||||
|
__global__ void fused_add_rms_layernorm_kernel(
|
||||||
|
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||||
|
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||||
|
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||||
|
const float epsilon,
|
||||||
|
const int num_tokens,
|
||||||
|
const int hidden_size) {
|
||||||
|
using scalar2_t = typename TypeConverter<scalar_t>::Type;
|
||||||
|
__shared__ float s_variance;
|
||||||
|
scalar2_t x_local[4];
|
||||||
|
|
||||||
|
scalar2_t* input_ptr = (scalar2_t*)input;
|
||||||
|
scalar2_t* residual_ptr = (scalar2_t*)residual;
|
||||||
|
const scalar2_t* weight_ptr = (const scalar2_t*)weight;
|
||||||
|
|
||||||
|
float variance = 0.0f;
|
||||||
|
int row_offset = blockIdx.x * hidden_size / 2;
|
||||||
|
|
||||||
|
#pragma unroll unroll_factor
|
||||||
|
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
|
||||||
|
int id = row_offset + idx;
|
||||||
|
x_local[cnt] = input_ptr[id];
|
||||||
|
x_local[cnt] = add(x_local[cnt], residual_ptr[id]);
|
||||||
|
float v1 = cuda_cast<float>(x_local[cnt].x);
|
||||||
|
float v2 = cuda_cast<float>(x_local[cnt].y);
|
||||||
|
variance += v1 * v1 + v2 * v2;
|
||||||
|
residual_ptr[id] = x_local[cnt];
|
||||||
|
}
|
||||||
|
variance = blockReduceSum<float>(variance);
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
|
||||||
|
#pragma unroll unroll_factor
|
||||||
|
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
|
||||||
|
int id = row_offset + idx;
|
||||||
|
input_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename scalar_t, int unroll_factor>
|
||||||
|
__global__ void general_fused_add_rms_layernorm_kernel(
|
||||||
|
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||||
|
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||||
|
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||||
|
const float epsilon,
|
||||||
|
const int num_tokens,
|
||||||
|
const int hidden_size) {
|
||||||
|
__shared__ float s_variance;
|
||||||
|
float variance = 0.0f;
|
||||||
|
float x_local[8];
|
||||||
|
|
||||||
|
int row_offset = blockIdx.x * hidden_size;
|
||||||
|
|
||||||
|
#pragma unroll unroll_factor
|
||||||
|
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||||
|
int id = row_offset + idx;
|
||||||
|
x_local[cnt] = (float) input[id];
|
||||||
|
x_local[cnt] += (float) residual[id];
|
||||||
|
variance += x_local[cnt] * x_local[cnt];
|
||||||
|
residual[id] = (scalar_t) x_local[cnt];
|
||||||
|
}
|
||||||
|
variance = blockReduceSum<float>(variance);
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
s_variance = rsqrtf(variance / hidden_size + epsilon);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
#pragma unroll unroll_factor
|
||||||
|
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size; idx += blockDim.x, cnt++) {
|
||||||
|
int id = row_offset + idx;
|
||||||
|
input[id] = ((scalar_t) (x_local[cnt] * s_variance)) * weight[idx];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void rms_layernorm(
|
||||||
|
torch::Tensor& out, // [..., hidden_size]
|
||||||
|
torch::Tensor& input, // [..., hidden_size]
|
||||||
|
torch::Tensor& weight, // [hidden_size]
|
||||||
|
float epsilon) {
|
||||||
|
int hidden_size = input.size(-1);
|
||||||
|
int num_tokens = input.numel() / hidden_size;
|
||||||
|
|
||||||
|
dim3 grid(num_tokens);
|
||||||
|
dim3 block(std::min(hidden_size, 1024));
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
if (num_tokens >= 512) {
|
||||||
|
if (input.scalar_type() == at::ScalarType::Float) {
|
||||||
|
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.element_size(),
|
||||||
|
input.scalar_type(),
|
||||||
|
"rms_layernorm_kernel",
|
||||||
|
rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
|
||||||
|
out.data_ptr<scalar_t>(),
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
} else {
|
||||||
|
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.element_size(),
|
||||||
|
input.scalar_type(),
|
||||||
|
"rms_layernorm_kernel",
|
||||||
|
rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
|
||||||
|
out.data_ptr<scalar_t>(),
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
int unroll_factor = (hidden_size + block.x - 1) / block.x;
|
||||||
|
if (input.scalar_type() != at::ScalarType::Float) {
|
||||||
|
block.x = std::min(hidden_size / 2, 1024);
|
||||||
|
unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
|
||||||
|
}
|
||||||
|
switch (unroll_factor) {
|
||||||
|
case 1:
|
||||||
|
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.element_size(),
|
||||||
|
input.scalar_type(),
|
||||||
|
"rms_layernorm_kernel",
|
||||||
|
rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
|
||||||
|
out.data_ptr<scalar_t>(),
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.element_size(),
|
||||||
|
input.scalar_type(),
|
||||||
|
"rms_layernorm_kernel",
|
||||||
|
rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
|
||||||
|
out.data_ptr<scalar_t>(),
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.element_size(),
|
||||||
|
input.scalar_type(),
|
||||||
|
"rms_layernorm_kernel",
|
||||||
|
rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
|
||||||
|
out.data_ptr<scalar_t>(),
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
break;
|
||||||
|
case 8:
|
||||||
|
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.element_size(),
|
||||||
|
input.scalar_type(),
|
||||||
|
"rms_layernorm_kernel",
|
||||||
|
rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
|
||||||
|
out.data_ptr<scalar_t>(),
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
AT_ERROR("unroll_factor must be 1, 2, 4 or 8");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void fused_add_rms_layernorm(
|
||||||
|
torch::Tensor& input, // [..., hidden_size]
|
||||||
|
torch::Tensor& residual, // [..., hidden_size]
|
||||||
|
torch::Tensor& weight, // [hidden_size]
|
||||||
|
float epsilon) {
|
||||||
|
int hidden_size = input.size(-1);
|
||||||
|
int num_tokens = input.numel() / hidden_size;
|
||||||
|
|
||||||
|
dim3 grid(num_tokens);
|
||||||
|
dim3 block(std::min(hidden_size, 1024));
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||||
|
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||||
|
|
||||||
|
if (num_tokens >= 512) {
|
||||||
|
if (input.scalar_type() == at::ScalarType::Float) {
|
||||||
|
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.element_size(),
|
||||||
|
input.scalar_type(),
|
||||||
|
"fused_add_rms_layernorm_kernel",
|
||||||
|
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, hidden_size / 8, 0, stream>>>(
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
residual.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
} else {
|
||||||
|
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.element_size(),
|
||||||
|
input.scalar_type(),
|
||||||
|
"fused_add_rms_layernorm_kernel",
|
||||||
|
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, hidden_size / 8, 0, stream>>>(
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
residual.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
int unroll_factor = (hidden_size + block.x - 1) / block.x;
|
||||||
|
if (input.scalar_type() != at::ScalarType::Float) {
|
||||||
|
block.x = std::min(hidden_size / 2, 1024);
|
||||||
|
unroll_factor = (hidden_size / 2 + block.x - 1) / block.x;
|
||||||
|
}
|
||||||
|
switch (unroll_factor) {
|
||||||
|
case 1:
|
||||||
|
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.element_size(),
|
||||||
|
input.scalar_type(),
|
||||||
|
"fused_add_rms_layernorm_kernel",
|
||||||
|
fused_add_rms_layernorm_kernel<scalar_t, 1><<<grid, block, 0, stream>>>(
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
residual.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.element_size(),
|
||||||
|
input.scalar_type(),
|
||||||
|
"fused_add_rms_layernorm_kernel",
|
||||||
|
fused_add_rms_layernorm_kernel<scalar_t, 2><<<grid, block, 0, stream>>>(
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
residual.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.element_size(),
|
||||||
|
input.scalar_type(),
|
||||||
|
"fused_add_rms_layernorm_kernel",
|
||||||
|
fused_add_rms_layernorm_kernel<scalar_t, 4><<<grid, block, 0, stream>>>(
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
residual.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
break;
|
||||||
|
case 8:
|
||||||
|
DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(
|
||||||
|
input.element_size(),
|
||||||
|
input.scalar_type(),
|
||||||
|
"fused_add_rms_layernorm_kernel",
|
||||||
|
fused_add_rms_layernorm_kernel<scalar_t, 8><<<grid, block, 0, stream>>>(
|
||||||
|
input.data_ptr<scalar_t>(),
|
||||||
|
residual.data_ptr<scalar_t>(),
|
||||||
|
weight.data_ptr<scalar_t>(),
|
||||||
|
epsilon,
|
||||||
|
num_tokens,
|
||||||
|
hidden_size);)
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
AT_ERROR("unroll_factor must be 1, 2, 4 or 8");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -10,7 +10,7 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
#include "scaled_masked_softmax.h"
|
#include "scaled_masked_softmax.h"
|
||||||
#include "type_shim.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
namespace multihead_attn {
|
namespace multihead_attn {
|
||||||
namespace fused_softmax {
|
namespace fused_softmax {
|
@ -10,7 +10,7 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
|
|
||||||
#include "scaled_upper_triang_masked_softmax.h"
|
#include "scaled_upper_triang_masked_softmax.h"
|
||||||
#include "type_shim.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
namespace multihead_attn {
|
namespace multihead_attn {
|
||||||
namespace fused_softmax {
|
namespace fused_softmax {
|
36
extensions/csrc/cuda/utils/gpu_launch_config.h
Normal file
36
extensions/csrc/cuda/utils/gpu_launch_config.h
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
namespace colossalAI {
|
||||||
|
namespace cuda {
|
||||||
|
namespace utils {
|
||||||
|
|
||||||
|
GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size);
|
||||||
|
|
||||||
|
// TODO(LiuYang): to be implemented
|
||||||
|
GPULaunchConfig GPUGetGPULaunchConfig2D(int64_t numel, int vec_size);
|
||||||
|
|
||||||
|
// TODO(LiuYang): to be implemented
|
||||||
|
GPULaunchConfig GPUGetGPULaunchConfig3D(int64_t numel, int vec_size);
|
||||||
|
|
||||||
|
class GPULaunchConfig {
|
||||||
|
public:
|
||||||
|
GPULaunchConfig(){};
|
||||||
|
GPULaunchConfig(const dim3& block, const dim3& grid)
|
||||||
|
: block_(block), grid_(grid) {}
|
||||||
|
friend GPULaunchConfig GPUGetGPULaunchConfig1D(int64_t numel, int vec_size);
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void set_block(const dim3& dim) { block_ = dim; }
|
||||||
|
void set_grid(const dim3& dim) { grid_ = dim; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
dim3 block_(1, 1, 1);
|
||||||
|
dim3 grid_(1, 1, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace utils
|
||||||
|
} // namespace cuda
|
||||||
|
} // namespace colossalAI
|
12
extensions/csrc/cuda/utils/micros.h
Normal file
12
extensions/csrc/cuda/utils/micros.h
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
#define CUDA_CHECK(func) \
|
||||||
|
{ \
|
||||||
|
auto status = func; \
|
||||||
|
if (status != cudaSuccess) { \
|
||||||
|
LOG(FATAL) << "CUDA Error : " << cudaGetErrorString(status); \
|
||||||
|
} \
|
||||||
|
}
|
45
extensions/csrc/cuda/utils/nvgpu_dev_info.cc
Normal file
45
extensions/csrc/cuda/utils/nvgpu_dev_info.cc
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
#include "nvgpu_dev_info.h"
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
|
||||||
|
namespace colossalAI {
|
||||||
|
namespace cuda {
|
||||||
|
namespace utils {
|
||||||
|
|
||||||
|
std::array<int, 3> NVGPUDevInfo::GetMaxGridDims() const {
|
||||||
|
std::array<int, 3> ret;
|
||||||
|
ret[0] = prop_->maxGridSize[0];
|
||||||
|
ret[1] = prop_->maxGridSize[1];
|
||||||
|
ret[2] = prop_->maxGridSize[2];
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<int, 3> NVGPUDevInfo::GetMaxBlockDims() const {
|
||||||
|
std::array<int, 3> ret;
|
||||||
|
ret[0] = prop_->maxThreadsDim[0];
|
||||||
|
ret[1] = prop_->maxThreadsDim[1];
|
||||||
|
ret[2] = prop_->maxThreadsDim[2];
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<int, 2> NVGPUDevInfo::GetCapability() const {
|
||||||
|
std::array<int, 2> ret;
|
||||||
|
ret[0] = prop_.major;
|
||||||
|
ret[1] = prop_.minor;
|
||||||
|
}
|
||||||
|
|
||||||
|
int NVGPUDevInfo::GetMultiProcessorCount() const {
|
||||||
|
return prop_->multiProcessorCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
int NVGPUDevInfo::GetMaxThreadsPerMultiProcessor() const {
|
||||||
|
return prop_->maxThreadsPerMultiProcessor;
|
||||||
|
}
|
||||||
|
|
||||||
|
int NVGPUDevInfo::GetMaxThreadsPerBlock() const {
|
||||||
|
return prop_->maxThreadsPerBlock;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace utils
|
||||||
|
} // namespace cuda
|
||||||
|
} // namespace colossalAI
|
37
extensions/csrc/cuda/utils/nvgpu_dev_info.h
Normal file
37
extensions/csrc/cuda/utils/nvgpu_dev_info.h
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cuda.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
|
||||||
|
#include <ostream>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "micros.h"
|
||||||
|
#include "target.h"
|
||||||
|
|
||||||
|
namespace colossalAI {
|
||||||
|
namespace cuda {
|
||||||
|
namespace utils {
|
||||||
|
|
||||||
|
class NVGPUDevInfo {
|
||||||
|
public:
|
||||||
|
explicit NVGPUDevInfo(int device_num) : device_num_(device_num) {
|
||||||
|
CUDA_CALL(cudaGetDeviceProperties(prop_, device));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::array<int, 3> GetMaxGridDims() const;
|
||||||
|
std::array<int, 3> GetMaxBlockDims() const;
|
||||||
|
std::array<int, 2> GetCapability() const;
|
||||||
|
int GetMultiProcessorCount() const;
|
||||||
|
int GetMaxThreadsPerMultiProcessor() const;
|
||||||
|
int GetMaxThreadsPerBlock() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
int device_num_;
|
||||||
|
cudaDeviceProp* prop_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace utils
|
||||||
|
} // namespace cuda
|
||||||
|
} // namespace colossalAI
|
@ -10,14 +10,17 @@ class InferenceOpsCudaExtension(_CudaExtension):
|
|||||||
ret = [
|
ret = [
|
||||||
self.csrc_abs_path(fname)
|
self.csrc_abs_path(fname)
|
||||||
for fname in [
|
for fname in [
|
||||||
"cuda/colossal_inference_C_frontend.cpp",
|
"cuda/pybind/inference.cpp",
|
||||||
"cuda/decode_kv_cache_memcpy_kernel.cu",
|
"cuda/decode_kv_cache_memcpy_kernel.cu",
|
||||||
|
"cuda/fused_rotary_emb_and_cache_kernel.cu",
|
||||||
|
"cuda/activation_kernel.cu",
|
||||||
|
"cuda/rms_layernorm_kernel.cu",
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def include_dirs(self):
|
def include_dirs(self):
|
||||||
ret = [self.get_cuda_home_include()]
|
ret = [self.csrc_abs_path("cuda/include"), self.get_cuda_home_include()]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def cxx_flags(self):
|
def cxx_flags(self):
|
||||||
|
@ -7,7 +7,7 @@ class LayerNormCudaExtension(_CudaExtension):
|
|||||||
super().__init__(name="layernorm_cuda")
|
super().__init__(name="layernorm_cuda")
|
||||||
|
|
||||||
def sources_files(self):
|
def sources_files(self):
|
||||||
ret = [self.csrc_abs_path(fname) for fname in ["cuda/layer_norm_cuda.cpp", "cuda/layer_norm_cuda_kernel.cu"]]
|
ret = [self.csrc_abs_path(fname) for fname in ["cuda/pybind/layer_norm.cpp", "cuda/layer_norm_kernel.cu"]]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def include_dirs(self):
|
def include_dirs(self):
|
||||||
|
@ -11,7 +11,7 @@ class MoeCudaExtension(_CudaExtension):
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
def sources_files(self):
|
def sources_files(self):
|
||||||
ret = [self.csrc_abs_path(fname) for fname in ["cuda/moe_cuda.cpp", "cuda/moe_cuda_kernel.cu"]]
|
ret = [self.csrc_abs_path(fname) for fname in ["cuda/pybind/moe.cpp", "cuda/moe_kernel.cu"]]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def cxx_flags(self):
|
def cxx_flags(self):
|
||||||
|
@ -10,12 +10,12 @@ class FusedOptimizerCudaExtension(_CudaExtension):
|
|||||||
ret = [
|
ret = [
|
||||||
self.csrc_abs_path(fname)
|
self.csrc_abs_path(fname)
|
||||||
for fname in [
|
for fname in [
|
||||||
"cuda/colossal_C_frontend.cpp",
|
"cuda/pybind/optimizer.cpp",
|
||||||
"cuda/multi_tensor_sgd_kernel.cu",
|
"cuda/multi_tensor_sgd_kernel.cu",
|
||||||
"cuda/multi_tensor_scale_kernel.cu",
|
"cuda/multi_tensor_scale_kernel.cu",
|
||||||
"cuda/multi_tensor_adam.cu",
|
"cuda/multi_tensor_adam_kernel.cu",
|
||||||
"cuda/multi_tensor_l2norm_kernel.cu",
|
"cuda/multi_tensor_l2norm_kernel.cu",
|
||||||
"cuda/multi_tensor_lamb.cu",
|
"cuda/multi_tensor_lamb_kernel.cu",
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
return ret
|
return ret
|
||||||
|
@ -9,7 +9,7 @@ class ScaledMaskedSoftmaxCudaExtension(_CudaExtension):
|
|||||||
def sources_files(self):
|
def sources_files(self):
|
||||||
ret = [
|
ret = [
|
||||||
self.csrc_abs_path(fname)
|
self.csrc_abs_path(fname)
|
||||||
for fname in ["cuda/scaled_masked_softmax.cpp", "cuda/scaled_masked_softmax_cuda.cu"]
|
for fname in ["cuda/pybind/scaled_masked_softmax.cpp", "cuda/scaled_masked_softmax_kernel.cu"]
|
||||||
]
|
]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
@ -13,8 +13,8 @@ class ScaledUpperTriangleMaskedSoftmaxCudaExtension(_CudaExtension):
|
|||||||
ret = [
|
ret = [
|
||||||
self.csrc_abs_path(fname)
|
self.csrc_abs_path(fname)
|
||||||
for fname in [
|
for fname in [
|
||||||
"cuda/scaled_upper_triang_masked_softmax.cpp",
|
"cuda/pybind/scaled_upper_triang_masked_softmax.cpp",
|
||||||
"cuda/scaled_upper_triang_masked_softmax_cuda.cu",
|
"cuda/scaled_upper_triang_masked_softmax_kernel.cu",
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
return ret
|
return ret
|
||||||
|
@ -22,15 +22,11 @@ def setup_seed(seed):
|
|||||||
def check_inference_engine(use_engine=False, prompt_template=None):
|
def check_inference_engine(use_engine=False, prompt_template=None):
|
||||||
setup_seed(20)
|
setup_seed(20)
|
||||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||||
model = (
|
model = LlamaForCausalLM(
|
||||||
LlamaForCausalLM(
|
LlamaConfig(
|
||||||
LlamaConfig(
|
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
|
||||||
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
.cuda()
|
).cuda()
|
||||||
.half()
|
|
||||||
)
|
|
||||||
model = model.eval()
|
model = model.eval()
|
||||||
|
|
||||||
inputs = [
|
inputs = [
|
||||||
@ -44,7 +40,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
|
|||||||
top_k = 50
|
top_k = 50
|
||||||
|
|
||||||
if use_engine:
|
if use_engine:
|
||||||
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
|
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
|
||||||
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
||||||
assert inference_engine.generation_config.max_new_tokens == output_len
|
assert inference_engine.generation_config.max_new_tokens == output_len
|
||||||
inference_engine.add_request(prompts=inputs)
|
inference_engine.add_request(prompts=inputs)
|
||||||
|
51
tests/test_infer/test_ops/cuda/test_rms_layernorm.py
Normal file
51
tests/test_infer/test_ops/cuda/test_rms_layernorm.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm
|
||||||
|
|
||||||
|
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
inference_ops = InferenceOpsLoader().load()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("M", [2, 4, 8, 16])
|
||||||
|
@pytest.mark.parametrize("N", [64, 128, 512])
|
||||||
|
def test_rms_layernorm(M: int, N: int):
|
||||||
|
torch.manual_seed(123)
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
|
device = get_current_device()
|
||||||
|
|
||||||
|
dtype = torch.float16
|
||||||
|
eps = 1e-5
|
||||||
|
x_shape = (M, N)
|
||||||
|
w_shape = (x_shape[-1],)
|
||||||
|
weight = torch.ones(w_shape, dtype=dtype, device=device)
|
||||||
|
residual = torch.rand(x_shape, dtype=dtype, device=device)
|
||||||
|
residual_copy = residual.clone()
|
||||||
|
rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
|
||||||
|
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
|
||||||
|
x_copy = x.clone()
|
||||||
|
|
||||||
|
y_cuda = torch.empty_like(x)
|
||||||
|
inference_ops.rms_layernorm(y_cuda, x, weight, eps)
|
||||||
|
y_llama = rms_norm.forward(x).to(dtype)
|
||||||
|
|
||||||
|
assert y_cuda.shape == y_llama.shape
|
||||||
|
assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
|
||||||
|
|
||||||
|
inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)
|
||||||
|
y_cuda = x
|
||||||
|
|
||||||
|
x = x_copy + residual_copy
|
||||||
|
y_llama = rms_norm.forward(x).to(dtype)
|
||||||
|
|
||||||
|
assert y_cuda.shape == y_llama.shape
|
||||||
|
assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
|
||||||
|
assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_rms_layernorm(16, 512)
|
91
tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py
Normal file
91
tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
|
||||||
|
|
||||||
|
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||||
|
|
||||||
|
inference_ops = InferenceOpsLoader().load()
|
||||||
|
|
||||||
|
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
|
||||||
|
from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("BATCH_SIZE", [4])
|
||||||
|
@pytest.mark.parametrize("SEQ_LEN", [64])
|
||||||
|
@pytest.mark.parametrize("H", [32])
|
||||||
|
@pytest.mark.parametrize("D", [64])
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float16])
|
||||||
|
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
|
||||||
|
torch.manual_seed(10)
|
||||||
|
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
|
||||||
|
# our crafted op equals to Transformers
|
||||||
|
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
|
||||||
|
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
|
||||||
|
emb = LlamaRotaryEmbedding(D)
|
||||||
|
cos, sin = emb(x0, TOTAL_TOKENS)
|
||||||
|
cos_2 = cos[:, : D // 2]
|
||||||
|
sin_2 = sin[:, : D // 2]
|
||||||
|
position_ids = torch.arange(TOTAL_TOKENS)
|
||||||
|
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
|
||||||
|
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
|
||||||
|
assert torch.allclose(embd_x0, embd_stimulated_x)
|
||||||
|
|
||||||
|
# create data
|
||||||
|
block_size = 32
|
||||||
|
max_blocks_per_sequence = (TOTAL_TOKENS + block_size - 1) // block_size
|
||||||
|
q_shape = (TOTAL_TOKENS, H, D)
|
||||||
|
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
|
||||||
|
k_shape = (TOTAL_TOKENS, H, D)
|
||||||
|
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
|
||||||
|
cos_shape = (TOTAL_TOKENS, D // 2)
|
||||||
|
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
|
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||||
|
cache_shape = (BATCH_SIZE * max_blocks_per_sequence, H, block_size, D)
|
||||||
|
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
|
||||||
|
v = torch.randn_like(k)
|
||||||
|
v_cache = torch.zeros_like(k_cache)
|
||||||
|
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
|
||||||
|
block_tables = mock_alloc_block_table_and_kvcache_v2(
|
||||||
|
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size
|
||||||
|
)
|
||||||
|
new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
|
||||||
|
new_q = torch.randn_like(new_k)
|
||||||
|
new_v = torch.randn_like(new_k)
|
||||||
|
|
||||||
|
kv_seq_lengths = past_kv_seq_lengths + 1
|
||||||
|
block_tables = block_tables.to(device="cuda")
|
||||||
|
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
|
||||||
|
k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
|
||||||
|
|
||||||
|
new_q_copy = new_q.clone()
|
||||||
|
new_k_copy = new_k.clone()
|
||||||
|
|
||||||
|
inference_ops.rotary_embedding_and_cache_copy(
|
||||||
|
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
|
||||||
|
)
|
||||||
|
|
||||||
|
inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin)
|
||||||
|
|
||||||
|
past_kv_seq_len = kv_seq_lengths - 1
|
||||||
|
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
|
||||||
|
offsets_in_block = past_kv_seq_len % block_size
|
||||||
|
k_target = k_cache[target_block_ids, :, offsets_in_block, :].squeeze()
|
||||||
|
k_source = new_k_copy.squeeze()
|
||||||
|
v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze()
|
||||||
|
v_source = new_v.squeeze()
|
||||||
|
|
||||||
|
assert torch.allclose(new_q, q_ref, atol=1e-6, rtol=1e-6)
|
||||||
|
assert torch.allclose(k_target, k_ref, atol=1e-6, rtol=1e-6)
|
||||||
|
|
||||||
|
assert torch.allclose(new_q_copy, q_ref, atol=1e-6, rtol=1e-6)
|
||||||
|
assert torch.allclose(new_k_copy, k_ref, atol=1e-6, rtol=1e-6)
|
||||||
|
|
||||||
|
assert k_target.shape == k_source.shape
|
||||||
|
assert torch.allclose(k_target, k_source, atol=1e-6, rtol=1e-6)
|
||||||
|
|
||||||
|
assert v_target.shape == v_source.shape
|
||||||
|
assert torch.equal(v_target, v_source)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_rotary_emb(16, 512, 4, 128, torch.float16)
|
33
tests/test_infer/test_ops/cuda/test_silu_and_mul.py
Normal file
33
tests/test_infer/test_ops/cuda/test_silu_and_mul.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from colossalai.kernel.kernel_loader import InferenceOpsLoader
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
inference_ops = InferenceOpsLoader().load()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("SHAPE_X", [2])
|
||||||
|
@pytest.mark.parametrize("SHAPE_Y", [64])
|
||||||
|
@pytest.mark.parametrize("SHAPE_Z", [11008])
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
|
||||||
|
def test_silu_and_mul(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype):
|
||||||
|
torch.manual_seed(5)
|
||||||
|
device = get_current_device()
|
||||||
|
ref_input = torch.randn(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype=dtype, device=device)
|
||||||
|
origin_input = ref_input.clone()
|
||||||
|
|
||||||
|
act_out = torch.nn.functional.silu(ref_input[0], inplace=True)
|
||||||
|
ref_out = act_out * ref_input[1]
|
||||||
|
|
||||||
|
origin_out = inference_ops.silu_and_mul(origin_input)
|
||||||
|
|
||||||
|
if dtype == torch.float32:
|
||||||
|
assert torch.allclose(origin_out, ref_out, atol=1e-5, rtol=1e-5)
|
||||||
|
else:
|
||||||
|
assert torch.allclose(origin_out, ref_out, atol=1e-3, rtol=1e-3)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_silu_and_mul(2, 64, 11008, torch.float32)
|
||||||
|
test_silu_and_mul(2, 64, 11008, torch.float16)
|
Loading…
Reference in New Issue
Block a user