[Inference] Fused the gate and up proj in MLP, and optimized the autograd process. (#5365)

* fused the gate and up proj in mlp (see the sketch after this list)

* fix code styles

* opt auto_grad

* rollback test_inference_engine.py

* modifications based on the review feedback.

* fix bugs in flash attn

* Change reshape to view

* fix test_rmsnorm_triton.py
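
A minimal sketch of the kind of fusion the first item refers to, assuming a LLaMA-style SwiGLU MLP: the separate gate and up projections are folded into a single linear layer whose output is split in two, so each forward pass issues one GEMM instead of two and the autograd graph records a single matmul node for both projections. The class and attribute names (`FusedSwiGLUMLP`, `gate_up_proj`, `down_proj`) are illustrative and not taken from this repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FusedSwiGLUMLP(nn.Module):
    """Illustrative LLaMA-style MLP with the gate and up projections fused into one linear."""

    def __init__(self, hidden_size: int, intermediate_size: int) -> None:
        super().__init__()
        # One weight of shape [2 * intermediate_size, hidden_size] replaces the two
        # separate gate_proj / up_proj weights, so the forward pass issues one GEMM.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate_up = self.gate_up_proj(x)
        # Split the fused output back into its gate and up halves.
        gate, up = gate_up.chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)


# Usage sketch: same input/output shapes as an unfused MLP.
mlp = FusedSwiGLUMLP(hidden_size=64, intermediate_size=172)
y = mlp(torch.randn(2, 8, 64))
assert y.shape == (2, 8, 64)
```

Loading an existing unfused checkpoint into such a module would require concatenating the pretrained gate and up weights along the output dimension.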
Author: yuehuayingxueluo
Date: 2024-02-06 19:38:25 +08:00
Committed by: GitHub
Parent: 1dedb57747
Commit: 35382a7fbf
10 changed files with 484 additions and 50 deletions

View File

@@ -220,7 +220,7 @@ def flash_decoding_attention(
         num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
 
     Returns:
-        Output tensor with shape [bsz, num_heads, q_len, head_dim]
+        Output tensor with shape [bsz, num_heads, head_dim]
     """
     q = q.squeeze() if q.dim() == 4 else q
     assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
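
To make the documented shapes concrete: during decoding `q_len == 1`, so a 4-D query collapses to 3-D and the output carries no `q_len` dimension, which is what the corrected docstring states. A small illustrative sketch (the sizes below are placeholders, not values from the repository):

```python
import torch

bsz, num_heads, head_dim = 2, 8, 64  # placeholder sizes

# During decoding q_len == 1, so the 4-D query collapses to 3-D,
# matching the corrected docstring shape [bsz, num_heads, head_dim].
q = torch.randn(bsz, 1, num_heads, head_dim)
q = q.squeeze() if q.dim() == 4 else q
assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"

# The output buffer uses the same 3-D shape (and is reused if the caller
# already passed one in), as the hunks below show.
output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device)
assert output.shape == (bsz, num_heads, head_dim)
```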
@@ -261,6 +261,8 @@ def flash_decoding_attention(
     # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
     # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
     grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
+    output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
+
     _flash_decoding_fwd_kernel[grid](
         q,
         k_cache,
@@ -292,9 +294,7 @@ def flash_decoding_attention(
         BLOCK_SIZE=block_size,
         HEAD_DIM=head_dim,
     )
 
-    output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
-
     grid = (triton.next_power_of_2(bsz), num_heads)
 
     _flash_decoding_fwd_reduce_kernel[grid](