[Inference] Replace the Attention layer and MLP layer via shardformer to optimize the weight transpose operation; add fused_qkv and fused linear_add (#5340)
* add fused qkv
* replace attn and mlp by shardformer
* fix bugs in mlp
* add docstrings
* fix test_inference_engine.py
* add optimize unbind
* add fused_addmm
* rm squeeze(1)
* refactor codes
* fix ci bugs
* rename ShardFormerLlamaMLP and ShardFormerLlamaAttention
* Removed the dependency on LlamaFlashAttention2
* rollback test_inference_engine.py
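The fused_qkv and fused linear_add items above follow a standard pattern: collapse the separate q/k/v projections into one wider GEMM whose result is split afterwards, and fold a residual/bias add into a matmul with `torch.addmm`. Below is a minimal sketch of that pattern in plain PyTorch; `FusedQKVLinear`, `fused_linear_add`, and the shapes used are illustrative names and assumptions, not the shardformer modules added by this PR.

```python
import torch
import torch.nn as nn


class FusedQKVLinear(nn.Module):
    """Illustrative fused q/k/v projection: one wide matmul, then a split."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.head_dim = hidden_size // num_heads
        # One weight of shape [3 * hidden_size, hidden_size] replaces the three
        # separate q_proj / k_proj / v_proj weights (avoids three small GEMMs).
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: [num_tokens, hidden_size]
        qkv = self.qkv_proj(hidden_states)      # [num_tokens, 3 * hidden_size]
        q, k, v = qkv.chunk(3, dim=-1)          # one split instead of three matmuls
        return q, k, v


def fused_linear_add(x: torch.Tensor, weight: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
    """Fuse `x @ weight.T + residual` into a single torch.addmm call."""
    # residual: [num_tokens, out_features], x: [num_tokens, in_features],
    # weight: [out_features, in_features]
    return torch.addmm(residual, x, weight.t())
```

According to the commit message, the real replacements are wired in through shardformer (cf. ShardFormerLlamaAttention / ShardFormerLlamaMLP); the sketch only shows the fusion idea itself.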
@@ -143,8 +143,7 @@ def _flash_decoding_fwd_reduce_kernel(
     stride_o_lset,
     stride_o_lseh,
     stride_o_lseb,
-    stride_ob,
-    stride_ol,
+    stride_ot,
     stride_oh,
     stride_od,
     BLOCK_KV: tl.constexpr,

@@ -180,7 +179,7 @@ def _flash_decoding_fwd_reduce_kernel(
         m_i = m_ij

     acc = acc / l
-    offsets_O = cur_seq_idx * stride_ob + cur_head_idx * stride_oh + offsets_dmodel
+    offsets_O = cur_seq_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
     tl.store(O + offsets_O, acc.to(O.type.element_ty))
     return

@@ -212,7 +211,7 @@ def flash_decoding_attention(
             records the (kv) sequence lengths incorporating past kv sequence lengths.
         block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
         max_seq_len_in_batch (int): Maximum sequence length in the batch.
-        output (torch.Tensor): [bsz, 1, num_heads, head_dim]
+        output (torch.Tensor): [bsz, num_heads, head_dim]
         mid_output (torch.Tensor): [max_bsz, num_heads, kv_max_split_num, head_dim]
             Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
         mid_output_lse (torch.Tensor): [max_bsz, num_heads, kv_max_split_num]

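For reference, the `mid_output` and `mid_output_lse` buffers documented in this hunk hold the split-KV partial results that the reduce kernel above combines into `output`. A hedged allocation sketch consistent with the shapes in the docstring; the concrete `kv_max_split_num`, dtypes, and device are assumptions, not values taken from the PR:

```python
import torch

max_bsz, num_heads, head_dim = 8, 32, 128
# Assumed value: in the library this is derived from the max KV length and the
# KV block size; any value >= the actual number of splits works for the sketch.
kv_max_split_num = 16

# Partial attention outputs, one slice per KV split, reduced into `output` later.
mid_output = torch.empty(
    (max_bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float16, device="cuda"
)
# Log-sum-exp of each split, used to renormalize when the splits are combined.
mid_output_lse = torch.empty(
    (max_bsz, num_heads, kv_max_split_num), dtype=torch.float32, device="cuda"
)
```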
@@ -294,7 +293,7 @@ def flash_decoding_attention(
         HEAD_DIM=head_dim,
     )

-    output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
+    output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output

     grid = (triton.next_power_of_2(bsz), num_heads)

@@ -314,7 +313,6 @@ def flash_decoding_attention(
         output.stride(0),
         output.stride(1),
         output.stride(2),
-        output.stride(3),
         BLOCK_KV=block_size,
         HEAD_DIM=head_dim,
     )
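Taken together, these hunks drop the length-1 query dimension from the decode-phase output: the docstring, the default allocation, and the strides passed to the reduce kernel now all treat `output` as `[bsz, num_heads, head_dim]` instead of `[bsz, 1, num_heads, head_dim]`, which is what the "rm squeeze(1)" item in the commit message refers to. A hedged caller-side sketch of the shape change; the remaining arguments of `flash_decoding_attention` are elided because their exact signature is not shown in this diff:

```python
import torch

bsz, num_heads, head_dim = 4, 32, 128

# Before this PR the pre-allocated decode output carried a length-1 query dim
# and the extra dim was squeezed away afterwards (cf. the "rm squeeze(1)" item):
#     output = torch.empty((bsz, 1, num_heads, head_dim), dtype=torch.float16, device="cuda")
#     flash_decoding_attention(..., output=output, ...)  # followed by output.squeeze(1)

# With this PR the buffer is allocated directly as 3-D, matching the updated
# docstring and the three output strides now passed to the reduce kernel:
output = torch.empty((bsz, num_heads, head_dim), dtype=torch.float16, device="cuda")
# flash_decoding_attention(..., output=output, ...)  # no squeeze(1) needed
```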