Mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference] Replace Attention layer and MLP layer by shardformer to optimize the weight transpose operation; add fused_qkv and fused linear_add (#5340)

* add fused qkv
* replace attn and mlp by shardformer
* fix bugs in mlp
* add docstrings
* fix test_inference_engine.py
* add optimized unbind
* add fused_addmm
* rm squeeze(1)
* refactor code
* rename ShardFormerLlamaMLP and ShardFormerLlamaAttention
* remove the dependency on LlamaFlashAttention2
* roll back test_inference_engine.py
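For context, the fused-QKV and fused linear+add optimizations the commit describes can be sketched as follows. This is a minimal illustration of the technique, not the shardformer implementation; all names and shapes below are assumptions.

```python
import torch

bsz, seq_len, hidden_size = 2, 1, 64
hidden_states = torch.randn(bsz * seq_len, hidden_size)
w_q = torch.randn(hidden_size, hidden_size)
w_k = torch.randn(hidden_size, hidden_size)
w_v = torch.randn(hidden_size, hidden_size)

# Fused QKV: stack the three projection weights once, then replace three
# separate matmuls with a single batched matmul.
qkv_weight = torch.stack([w_q, w_k, w_v], dim=0)     # (3, hidden, hidden)
fused_qkv = torch.matmul(hidden_states, qkv_weight)  # (3, bsz*seq_len, hidden)

# unbind returns views along dim 0, so splitting Q/K/V adds no extra copy.
q, k, v = fused_qkv.unbind(0)

# Fused linear + add: torch.addmm computes residual + x @ W in one call
# instead of a matmul followed by a separate elementwise add.
residual = torch.randn(bsz * seq_len, hidden_size)
w_o = torch.randn(hidden_size, hidden_size)
out = torch.addmm(residual, q, w_o)
assert out.shape == (bsz * seq_len, hidden_size)
```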
@@ -69,6 +69,7 @@ def torch_attn_ref(
         f"`attn_output` should be of size {(bsz, num_heads, seq_len, head_dim)}, but is" f" {out.size()}"
     )
     out = out.transpose(1, 2).contiguous()
+    out = out.squeeze(1)
     return out
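The hunk above adds `out = out.squeeze(1)` to the torch reference: during decoding the query sequence length is 1, so after the transpose the reference output carries a singleton axis that the optimized kernels no longer produce. A standalone shape check (illustrative values, not repo code):

```python
import torch

bsz, num_heads, head_dim = 2, 8, 64
out = torch.randn(bsz, num_heads, 1, head_dim)  # (bsz, num_heads, seq_len=1, head_dim)
out = out.transpose(1, 2).contiguous()          # (bsz, 1, num_heads, head_dim)
out = out.squeeze(1)                            # (bsz, num_heads, head_dim)
assert out.shape == (bsz, num_heads, head_dim)
```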
@@ -94,7 +94,7 @@ def test_flash_decoding(
     max_seq_len_in_b = kv_seq_lengths.max().item()
     # The maximum block length split on kv should be the kv cache block size
     kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
-    output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
+    output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
     mid_output = torch.empty(
         size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
     )
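`kv_max_split_num` is a ceiling division: the longest KV sequence in the batch bounds how many block-sized chunks the flash-decoding kernel may split attention over, and `mid_output` reserves one float32 partial result per split. A quick check of the arithmetic with illustrative numbers:

```python
block_size = 16
max_seq_len_in_b = 53
kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
assert kv_max_split_num == 4  # ceil(53 / 16)
```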
@@ -189,7 +189,7 @@ def bench_kernel(
     block_tables = block_tables.to(device=device)
     # the maximum block length split on kv should be the kv cache block size
     kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
-    output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
+    output = torch.empty((bsz, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
     mid_output = torch.empty(
         size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
     )
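The benchmark gets the same 4D-to-3D output change. As for why `mid_output` exists at all: flash decoding computes attention over each KV split independently and then reduces the splits. A hedged sketch of that reduction, assuming a per-split log-sum-exp (`mid_lse` below is my name for it; the kernel's actual bookkeeping is not shown in this diff):

```python
import torch

bsz, num_heads, kv_max_split_num, head_dim = 2, 4, 3, 8
mid_output = torch.randn(bsz, num_heads, kv_max_split_num, head_dim, dtype=torch.float32)
mid_lse = torch.randn(bsz, num_heads, kv_max_split_num, dtype=torch.float32)

# Each split's normalized partial output is reweighted by its share of the
# global softmax mass (softmax over the log-sum-exps) and summed away.
weights = torch.softmax(mid_lse, dim=-1)               # (bsz, num_heads, splits)
out = (weights.unsqueeze(-1) * mid_output).sum(dim=2)  # (bsz, num_heads, head_dim)
assert out.shape == (bsz, num_heads, head_dim)
```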