Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[Inference] Fused the gate and up proj in mlp, and optimized the autograd process. (#5365)
* fused the gate and up proj in mlp (see the sketch below)
* fix code styles
* opt auto_grad
* rollback test_inference_engine.py
* modifications based on the review feedback
* fix bugs in flash attn
* Change reshape to view
* fix test_rmsnorm_triton.py
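For context on the first bullet: the gate and up projections of a SwiGLU-style MLP consume the same input, so their weight matrices can be stacked and applied as a single GEMM instead of two. Below is a minimal PyTorch sketch of that idea, with illustrative names rather than the actual ColossalAI modeling code.

# Minimal sketch of fusing the gate and up projections of a SwiGLU MLP.
# Class/method names are illustrative; this is not the ColossalAI implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedGateUpMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # One linear emits [gate | up] in a single GEMM instead of two.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    @classmethod
    def from_separate(cls, gate_proj: nn.Linear, up_proj: nn.Linear, down_proj: nn.Linear) -> "FusedGateUpMLP":
        # Build the fused module by stacking the existing gate/up weights row-wise.
        fused = cls(gate_proj.in_features, gate_proj.out_features)
        with torch.no_grad():
            fused.gate_up_proj.weight.copy_(torch.cat([gate_proj.weight, up_proj.weight], dim=0))
            fused.down_proj.weight.copy_(down_proj.weight)
        return fused

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)   # split the fused output back into gate/up
        return self.down_proj(F.silu(gate) * up)           # SwiGLU: silu(gate) * up, then down projection

Because nn.Linear stores weights as (out_features, in_features), concatenating along dim 0 and chunking the output along the last dimension reproduces the two separate projections exactly.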
@@ -220,7 +220,7 @@ def flash_decoding_attention(
         num_kv_group (int, optional): Number of key/value groups. Defaults to 1.

     Returns:
-        Output tensor with shape [bsz, num_heads, q_len, head_dim]
+        Output tensor with shape [bsz, num_heads, head_dim]
     """
     q = q.squeeze() if q.dim() == 4 else q
     assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"

@@ -261,6 +261,8 @@ def flash_decoding_attention(
     # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
     # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
     grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
+    output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
+
     _flash_decoding_fwd_kernel[grid](
         q,
         k_cache,

@@ -292,9 +294,7 @@ def flash_decoding_attention(
         BLOCK_SIZE=block_size,
         HEAD_DIM=head_dim,
     )

-    output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
-
     grid = (triton.next_power_of_2(bsz), num_heads)

     _flash_decoding_fwd_reduce_kernel[grid](
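The launch-grid sizing in the second hunk can also be read on its own: both the batch size and the longest KV length in the batch are rounded up to powers of two before sizing the grid (the NOTE in the source ties this to Triton's caching behavior, since fewer distinct launch shapes mean fewer recompilations). A standalone sketch with made-up sizes, not taken from the repository:

# Standalone sketch of the grid computation above; the concrete sizes are invented.
import triton

bsz, num_heads = 5, 32          # decoding batch: one query token per sequence
max_seq_len_in_batch = 1000     # longest cached KV length in the batch
BLOCK_KV = 64                   # KV positions handled per program in the split-KV pass

grid = (
    triton.next_power_of_2(bsz),                                          # 8 (next power of two >= 5)
    num_heads,                                                            # 32
    triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),  # ceil(1024 / 64) = 16
)
print(grid)  # (8, 32, 16)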
Reference in New Issue
Block a user