[Inference] Replace Attention layer and MLP layer with shardformer to optimize the weight transpose operation, add fused_qkv and fused linear_add (#5340)

* add fused qkv (sketched below)

* replace attn and mlp with shardformer

* fix bugs in mlp

* add docstrings

* fix test_inference_engine.py

* add optimized unbind

* add fused_addmm (sketched below)

* rm squeeze(1)

* refactor code

* fix ci bugs

* rename ShardFormerLlamaMLP and ShardFormerLlamaAttention

* remove the dependency on LlamaFlashAttention2

* roll back test_inference_engine.py
Author: yuehuayingxueluo
Date: 2024-02-01 15:49:39 +08:00
Committed by: GitHub
Parent commit: f8e456d202
Commit: 249644c23b
8 changed files with 510 additions and 341 deletions
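Before the diffs, a minimal sketch of the two fusions named in the commit bullets, using bare tensors and hypothetical sizes rather than ColossalAI's actual modules: fused QKV packs the three projection weights into a single matmul, and fused linear-add uses torch.addmm to fold the residual addition into the output projection.

import torch

# A sketch only, not the PR's implementation; names and sizes are made up.
bsz, hidden_size = 2, 4096
x = torch.randn(bsz, hidden_size)

# Fused QKV: one (hidden, 3 * hidden) weight replaces three separate
# q/k/v projections, so a single matmul produces all three tensors.
w_qkv = torch.randn(hidden_size, 3 * hidden_size)
q, k, v = torch.mm(x, w_qkv).chunk(3, dim=-1)

# Fused linear + add: torch.addmm computes residual + x @ w_o in one call,
# folding the residual addition into the output projection.
w_o = torch.randn(hidden_size, hidden_size)
residual = torch.randn(bsz, hidden_size)
out = torch.addmm(residual, x, w_o)
assert out.shape == (bsz, hidden_size)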


@@ -69,6 +69,7 @@ def torch_attn_ref(
         f"`attn_output` should be of size {(bsz, num_heads, seq_len, head_dim)}, but is" f" {out.size()}"
     )
     out = out.transpose(1, 2).contiguous()
+    out = out.squeeze(1)
     return out
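My reading of the hunk header (-69,6 +69,7, a net one-line addition): out.squeeze(1) is the added line. The decoding kernels in this PR now return a 3D (bsz, num_heads, head_dim) output, so the torch reference drops its singleton query-length dim to match. A shape-only sketch with made-up sizes:

import torch

bsz, num_heads, head_dim = 4, 8, 64             # hypothetical sizes
out = torch.randn(bsz, num_heads, 1, head_dim)  # decode step: q_len == 1
out = out.transpose(1, 2).contiguous()          # (bsz, 1, num_heads, head_dim)
out = out.squeeze(1)                            # (bsz, num_heads, head_dim)
assert out.shape == (bsz, num_heads, head_dim)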


@@ -94,7 +94,7 @@ def test_flash_decoding(
     max_seq_len_in_b = kv_seq_lengths.max().item()
     # The maximum block length split on kv should be the kv cache block size
     kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
-    output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
+    output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
     mid_output = torch.empty(
         size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
     )
@@ -189,7 +189,7 @@ def bench_kernel(
     block_tables = block_tables.to(device=device)
     # the maximum block length split on kv should be the kv cache block size
     kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
-    output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
+    output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
     mid_output = torch.empty(
         size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
     )
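kv_max_split_num in both hunks is a plain ceiling division: how many block_size-sized kv-cache blocks are needed to cover the longest kv sequence in the batch. A quick check with made-up numbers:

max_seq_len_in_b, block_size = 130, 64
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
assert kv_max_split_num == 3  # 130 tokens span three 64-token blocks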