Optimized the execution interval time between CUDA kernels caused by view and memcopy (#5390)

* opt_view_and_memcopy

* fix bugs in ci

* fix ci bugs

* update benchmark scripts

* fix ci bugs
Author: yuehuayingxueluo
Date: 2024-02-21 13:23:57 +08:00
Committed by: GitHub
Parent: 730103819d
Commit: 2a718c8be8
8 changed files with 141 additions and 55 deletions
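Read together, the three diffs shown below follow one pattern: buffers that used to be created in an intermediate shape and then zero-filled, reshaped, or combined with a residual by separate small operations between the main CUDA kernels are now allocated once, uninitialized, in their final layout, and the residual add is fused into the RMSNorm kernel. A minimal PyTorch sketch of the allocation side of the idea (the shapes and device handling are illustrative, not taken from the commit):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
num_tokens, num_heads, head_dim = 16, 32, 128  # illustrative sizes only
q = torch.randn(num_tokens, num_heads, head_dim, device=device)

# Before: a zero-filled 3D buffer (an extra fill kernel) that is flattened to
# [num_tokens, num_heads * head_dim] later, just before the output projection.
out_old = torch.zeros_like(q)
out_old_2d = out_old.view(num_tokens, num_heads * head_dim)

# After: an uninitialized buffer allocated directly in the flattened 2D layout
# that the following GEMM consumes; the attention kernel writes straight into it.
out_new = torch.empty(num_tokens, num_heads * head_dim, device=device)

# A contiguous [N, H, D] tensor and its [N, H*D] view share storage, so the
# kernel can still address the 2D buffer as if it were 3D (hence the literal
# strides `head_dim, 1` passed in the diffs below).
assert out_new.stride() == (num_heads * head_dim, 1)
```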


@@ -205,7 +205,7 @@ def context_attention_unpadded(
assert k_cache.shape == v_cache.shape
assert context_lengths.shape[0] == block_tables.shape[0]
-    num_tokens, num_heads, _ = q.shape
+    num_tokens, num_heads, head_dim = q.shape
num_kv_heads = k.shape[-2]
assert num_kv_heads > 0 and num_heads % num_kv_heads == 0
num_kv_group = num_heads // num_kv_heads
@@ -213,7 +213,9 @@ def context_attention_unpadded(
num_seqs, max_blocks_per_seq = block_tables.shape
max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len
sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
-    output = torch.zeros_like(q) if output is None else output
+    output = (
+        torch.empty((num_tokens, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output
+    )
# NOTE For now, BLOCK_M and BLOCK_N are supposed to be equivalent with
# the size of physical cache block (i.e. `block_size`)
@@ -243,8 +245,8 @@ def context_attention_unpadded(
v.stride(1),
v.stride(2),
output.stride(0),
-        output.stride(1),
-        output.stride(2),
+        head_dim,
+        1,
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
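The two hunks above go together: `output` is now a flat [num_tokens, num_heads * head_dim] buffer, so it has no third stride, and the kernel is handed the literal strides `head_dim` and `1` for the head and head-dim axes instead of `output.stride(1)` and `output.stride(2)`. For a contiguous row-major tensor these are exactly the strides the old 3D output had, as this small check (with arbitrary example sizes) illustrates:

```python
import torch

num_tokens, num_heads, head_dim = 4, 8, 64  # arbitrary example sizes

out_3d = torch.empty(num_tokens, num_heads, head_dim)
out_2d = torch.empty(num_tokens, num_heads * head_dim)

# The old call passed the strides of the 3D output:
assert out_3d.stride() == (num_heads * head_dim, head_dim, 1)

# The new call passes output.stride(0), head_dim, 1 for the 2D output,
# which addresses element [token, head, dim] at the same flat offset:
token, head, dim = 2, 5, 17
offset_3d = token * out_3d.stride(0) + head * out_3d.stride(1) + dim * out_3d.stride(2)
offset_2d = token * out_2d.stride(0) + head * head_dim + dim * 1
assert offset_3d == offset_2d
```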


@@ -211,7 +211,7 @@ def flash_decoding_attention(
records the (kv) sequence lengths incorporating past kv sequence lengths.
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
max_seq_len_in_batch (int): Maximum sequence length in the batch.
-        output (torch.Tensor): [bsz, num_heads, head_dim]
+        output (torch.Tensor): [bsz, num_heads * head_dim]
mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
@@ -220,7 +220,7 @@ def flash_decoding_attention(
num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
Returns:
-        Output tensor with shape [bsz, num_heads, head_dim]
+        Output tensor with shape [bsz, num_heads * head_dim]
"""
q = q.squeeze() if q.dim() == 4 else q
assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
@@ -261,7 +261,7 @@ def flash_decoding_attention(
# NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
# To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
-    output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
+    output = torch.empty((bsz, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output
_flash_decoding_fwd_kernel[grid](
q,
@@ -294,7 +294,7 @@ def flash_decoding_attention(
BLOCK_SIZE=block_size,
HEAD_DIM=head_dim,
)
grid = (triton.next_power_of_2(bsz), num_heads)
_flash_decoding_fwd_reduce_kernel[grid](
@@ -311,8 +311,8 @@ def flash_decoding_attention(
mid_output_lse.stride(1),
mid_output_lse.stride(2),
output.stride(0),
-        output.stride(1),
-        output.stride(2),
+        head_dim,
+        1,
BLOCK_KV=block_size,
HEAD_DIM=head_dim,
)
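The decode path gets the same treatment: the docstring, the default allocation, and the reduce kernel's output strides all switch to the flattened [bsz, num_heads * head_dim] layout, so the result can feed the output projection without an intermediate view or copy. A sketch of how a caller might consume the new layout; `o_proj` and the sizes are illustrative stand-ins, not code from this commit:

```python
import torch

bsz, num_heads, head_dim = 2, 32, 128  # illustrative sizes
hidden_size = num_heads * head_dim

# Stand-in for what flash_decoding_attention now writes/returns:
attn_out = torch.randn(bsz, hidden_size)  # [bsz, num_heads * head_dim]

# The flattened output can go straight into the output projection; the old
# [bsz, num_heads, head_dim] tensor had to be viewed/reshaped first.
o_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
hidden_states = o_proj(attn_out)  # [bsz, hidden_size]
```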


@@ -49,7 +49,50 @@ if HAS_TRITON:
# Write output
tl.store(Y + cols, y.to(tl.float16), mask=mask)
-    def rms_layernorm(x, weight, eps, norm_output=None):
+    @triton.jit
+    def _rmsnorm_with_residual_kernel(
+        X,  # pointer to the input
+        Y,  # pointer to the output
+        R,  # pointer to the residual
+        W,  # pointer to the weights
+        stride,  # how much to increase the pointer when moving by 1 row
+        N,  # number of columns in X
+        eps,  # epsilon to avoid division by zero
+        BLOCK_SIZE: tl.constexpr,
+    ):
+        # This triton kernel implements Root Mean Square Layer Norm (RMSNorm).
+        # Map the program id to the row of X and Y it should compute.
+        row = tl.program_id(0)
+        Y += row * stride
+        X += row * stride
+        R += row * stride
+        # Compute variance
+        _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+        for off in range(0, N, BLOCK_SIZE):
+            cols = off + tl.arange(0, BLOCK_SIZE)
+            x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+            x = tl.where(cols < N, x, 0.0)
+            r = tl.load(R + cols, mask=cols < N, other=0.0).to(tl.float32)
+            r = tl.where(cols < N, r, 0.0)
+            x = x + r
+            _var += x * x
+            mask = cols < N
+            tl.store(X + cols, x.to(tl.float16), mask=mask)
+        var = tl.sum(_var, axis=0) / N
+        rstd = 1 / tl.sqrt(var + eps)
+        # Normalize and apply linear transformation
+        for off in range(0, N, BLOCK_SIZE):
+            cols = off + tl.arange(0, BLOCK_SIZE)
+            mask = cols < N
+            w = tl.load(W + cols, mask=mask)
+            x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
+            x_hat = x * rstd
+            y = x_hat * w
+            # Write output
+            tl.store(Y + cols, y.to(tl.float16), mask=mask)
+
+    def rms_layernorm(x, weight, eps, norm_output=None, residual=None):
# allocate output
y = torch.empty_like(x) if norm_output is None else norm_output
M, N = x.shape
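The new `_rmsnorm_with_residual_kernel` fuses the residual add into the normalization: the first loop adds the residual, accumulates the sum of squares, and stores the residual-added values back into `X` in place (so the caller keeps the updated hidden states for the next residual connection); the second loop re-reads the updated `X`, normalizes, and scales by the weights. An unfused PyTorch reference of the same per-row computation, for comparison only (`residual_rmsnorm_ref` is not part of the commit):

```python
import torch

def residual_rmsnorm_ref(x, residual, weight, eps):
    # Pass 1 (first loop in the kernel): add the residual, accumulate the sum
    # of squares, and write the residual-added values back into X in place.
    x = x + residual
    var = x.pow(2).sum(dim=-1, keepdim=True) / x.shape[-1]
    rstd = torch.rsqrt(var + eps)
    # Pass 2 (second loop): re-read the updated X, normalize, scale by W.
    y = x * rstd * weight
    return y, x  # the kernel writes y to Y and leaves the updated values in X

x = torch.randn(4, 64)
r = torch.randn(4, 64)
w = torch.ones(64)
y, x_updated = residual_rmsnorm_ref(x, r, w, eps=1e-6)
assert torch.allclose(x_updated, x + r)
```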
@@ -64,5 +107,10 @@ if HAS_TRITON:
num_warps = min(max(triton.next_power_of_2(N) // 256, 8), 32)
# enqueue kernel
-        _rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
-        return y
+        if residual is None:
+            _rmsnorm_kernel[(M,)](x, y, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
+        else:
+            _rmsnorm_with_residual_kernel[(M,)](
+                x, y, residual, weight, x.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
+            )
+        return y, x
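At the call site, the wrapper now always returns a pair: the normalized tensor and the input with any residual already folded in, which the caller threads through as the residual of the next block instead of running a separate elementwise add between kernels. A hedged sketch of that calling convention using a CPU stand-in; `rms_layernorm_ref` and the placeholder attention/MLP blocks are illustrative only:

```python
import torch

def rms_layernorm_ref(x, weight, eps, norm_output=None, residual=None):
    # CPU stand-in with the same contract as the Triton wrapper above:
    # returns (normalized output, input with the residual already added).
    if residual is not None:
        x = x + residual
    y = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight
    return y, x

weight, eps = torch.ones(64), 1e-6
hidden = torch.randn(4, 64)

# Input norm of a decoder layer: no residual to fold in yet.
normed, residual = rms_layernorm_ref(hidden, weight, eps)
attn_out = normed * 0.5  # placeholder for the attention block

# Post-attention norm: the residual add now happens inside the norm kernel
# rather than as a separate elementwise kernel between the two.
normed, residual = rms_layernorm_ref(attn_out, weight, eps, residual=residual)
mlp_out = normed * 2.0  # placeholder for the MLP block
```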