[Inference] Replace the Attention and MLP layers with shardformer to optimize the weight transpose operation; add fused_qkv and fused linear_add (#5340)

* add fused qkv (see the sketch after this commit log)

* replace attn and mlp with shardformer

* fix bugs in mlp

* add docstrings

* fix test_inference_engine.py

* add unbind optimization

* add fused_addmm

* rm squeeze(1)

* refactor codes

* fix ci bugs

* rename ShardFormerLlamaMLP and ShardFormerLlamaAttention

* Removed the dependency on LlamaFlashAttention2

* rollback test_inference_engine.py
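
As a minimal, hypothetical sketch of the fused-QKV and fused linear+add ideas listed above (not the ColossalAI implementation from this commit; the sizes and the names fused_qkv_weight, w_o, residual and attn_context are illustrative only):

import torch

# Hypothetical sketch of the fused-QKV / fused-addmm idea; not the code from this commit.
bsz, hidden_size, num_heads = 4, 4096, 32          # illustrative sizes
head_dim = hidden_size // num_heads

hidden_states = torch.randn(bsz, hidden_size)

# Fused QKV: stack W_q, W_k, W_v into one [3, hidden_size, hidden_size] weight
# so a single batched matmul replaces three separate projections.
fused_qkv_weight = torch.randn(3, hidden_size, hidden_size)
qkv = torch.matmul(hidden_states, fused_qkv_weight)       # [3, bsz, hidden_size]

# unbind(0) splits the result into views, avoiding an extra copy per projection.
query, key, value = torch.unbind(qkv, dim=0)
query = query.view(bsz, num_heads, head_dim)               # decode step: one token per
key = key.view(bsz, num_heads, head_dim)                   # sequence, so no length-1 dim
value = value.view(bsz, num_heads, head_dim)               # (cf. the removed squeeze(1))

# Fused linear + add: torch.addmm folds the residual addition into the output
# projection in one call (the "fused_addmm" bullet).
w_o = torch.randn(hidden_size, hidden_size)
residual = torch.randn(bsz, hidden_size)
attn_context = torch.randn(bsz, hidden_size)
out = torch.addmm(residual, attn_context, w_o)             # residual + attn_context @ w_o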
yuehuayingxueluo
2024-02-01 15:49:39 +08:00
committed by GitHub
parent f8e456d202
commit 249644c23b
8 changed files with 510 additions and 341 deletions


@@ -143,8 +143,7 @@ def _flash_decoding_fwd_reduce_kernel(
stride_o_lset,
stride_o_lseh,
stride_o_lseb,
-stride_ob,
-stride_ol,
+stride_ot,
stride_oh,
stride_od,
BLOCK_KV: tl.constexpr,
@@ -180,7 +179,7 @@ def _flash_decoding_fwd_reduce_kernel(
m_i = m_ij
acc = acc / l
-offsets_O = cur_seq_idx * stride_ob + cur_head_idx * stride_oh + offsets_dmodel
+offsets_O = cur_seq_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
tl.store(O + offsets_O, acc.to(O.type.element_ty))
return
@@ -212,7 +211,7 @@ def flash_decoding_attention(
records the (kv) sequence lengths incorporating past kv sequence lengths.
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
max_seq_len_in_batch (int): Maximum sequence length in the batch.
-output (torch.Tensor): [bsz, 1, num_heads, head_dim]
+output (torch.Tensor): [bsz, num_heads, head_dim]
mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
@@ -294,7 +293,7 @@ def flash_decoding_attention(
HEAD_DIM=head_dim,
)
-output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
+output = torch.empty((bsz, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
grid = (triton.next_power_of_2(bsz), num_heads)
@@ -314,7 +313,6 @@ def flash_decoding_attention(
output.stride(0),
output.stride(1),
output.stride(2),
-output.stride(3),
BLOCK_KV=block_size,
HEAD_DIM=head_dim,
)
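
The caller-side effect of this diff, sketched under assumptions: the tensor shapes follow the docstring above; bsz, num_heads, head_dim, block_size and max_seq_len_in_batch are illustrative values, and kv_max_split_num is assumed to be the number of KV blocks per sequence, matching BLOCK_KV=block_size in the launch.

import torch

# Illustrative sizes; not taken from the commit.
bsz, num_heads, head_dim = 4, 32, 128
block_size, max_seq_len_in_batch = 16, 256
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16

# Assumption: one partial result per KV block, since the kernel is launched
# with BLOCK_KV=block_size.
kv_max_split_num = (max_seq_len_in_batch + block_size - 1) // block_size

# Buffers with the shapes documented in the docstring above.
mid_output = torch.empty(bsz, num_heads, kv_max_split_num, head_dim, dtype=dtype, device=device)
mid_output_lse = torch.empty(bsz, num_heads, kv_max_split_num, dtype=dtype, device=device)

# Before this commit the decode output kept a length-1 token dimension and the
# attention layer had to squeeze it away:
#     output = torch.empty(bsz, 1, num_heads, head_dim, ...)   # old shape
#     attn_out = output.squeeze(1).reshape(bsz, -1)
# After this commit the reduce kernel writes a 3-D tensor directly:
output = torch.empty(bsz, num_heads, head_dim, dtype=dtype, device=device)
attn_out = output.reshape(bsz, num_heads * head_dim)  # no squeeze(1) needed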