[Fix/Inference] Fix GQA Triton and Support Llama3 (#5624)

* [fix] GQA calling of flash decoding triton

* fix kv cache alloc shape

* fix rotary triton - GQA

* fix sequence max length assignment

* Sequence max length logic

* fix scheduling and spec-dec

* skip without import error

* fix pytest - skip without ImportError

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Yuanheng Zhao
Date: 2024-04-23 13:09:55 +08:00
Committed by: GitHub
Commit: 5d4c1fe8f5 (parent ccf72797e3)
9 changed files with 183 additions and 194 deletions
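Note on the "fix kv cache alloc shape" item: under GQA the KV cache is sized by the number of key/value heads rather than query heads. A minimal sketch of the idea, with hypothetical names and an assumed block layout for illustration only (not the repository's actual allocator):

import torch

# Hypothetical GQA-aware KV cache allocation: the head dimension of the cache
# follows num_kv_heads, which Llama3-style GQA keeps smaller than the number
# of query heads.
def alloc_kv_cache(num_blocks, block_size, num_kv_heads, head_dim,
                   dtype=torch.float16, device="cpu"):
    shape = (num_blocks, num_kv_heads, block_size, head_dim)
    k_cache = torch.zeros(shape, dtype=dtype, device=device)
    v_cache = torch.zeros(shape, dtype=dtype, device=device)
    return k_cache, v_cache

# Example: a Llama3-8B-like config has 32 query heads but only 8 KV heads,
# so 32 // 8 = 4 query heads share each cached KV head.
k_cache, v_cache = alloc_kv_cache(num_blocks=128, block_size=16, num_kv_heads=8, head_dim=128)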

@@ -447,9 +447,9 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
                 attn_qproj_w.dist_layout
             )  # NOTE this is a hack for the right load/shard of qkv_weight (used in _load_from_state_dict)
         else:
-            self.q_proj_weight = attn_qproj_w
-            self.k_proj_weight = attn_kproj_w
-            self.v_proj_weight = attn_vproj_w
+            self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous())
+            self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous())
+            self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous())

     @staticmethod
     def from_native_module(
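The non-tensor-parallel branch now stores each projection weight pre-transposed as an nn.Parameter, presumably so the hot path can use a plain matmul without transposing on every forward call; under GQA the k/v projections also have a smaller output width than q, which is why they stay as separate parameters. A rough sketch of the equivalence, with hypothetical shapes (not the repository's code):

import torch
import torch.nn as nn

# nn.Linear stores weight as (out_features, in_features), so y = x @ W.T.
# Pre-transposing once at load time lets the decode path call torch.mm(x, W_t)
# directly, with W_t of shape (in_features, out_features).
hidden = torch.randn(4, 4096)               # (num_tokens, hidden_size), illustrative sizes
linear = nn.Linear(4096, 1024, bias=False)  # e.g. a k_proj with fewer KV heads than q heads

w_t = nn.Parameter(linear.weight.transpose(0, 1).contiguous())  # as in the diff above
out_ref = linear(hidden)
out_mm = torch.mm(hidden, w_t)
assert torch.allclose(out_ref, out_mm, atol=1e-5)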
@@ -638,6 +638,7 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
             mid_output=fd_inter_tensor.mid_output,
             mid_output_lse=fd_inter_tensor.mid_output_lse,
             sm_scale=sm_scale,
+            kv_group_num=self.num_key_value_groups,
             q_len=q_len,
         )
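The added kv_group_num argument tells the flash-decoding call how many query heads share each KV head (num_key_value_groups = num_q_heads // num_kv_heads). A plain-PyTorch sketch of that head mapping, not the Triton kernel's actual signature; names and shapes here are assumptions for illustration:

import torch

# Each query head attends against the KV head at index q_head // kv_group_num.
def gqa_decode_attention(q, k_cache, v_cache, kv_group_num, sm_scale):
    # q:        (num_q_heads, head_dim)        -- a single decoding token
    # k_cache:  (num_kv_heads, seq_len, head_dim)
    # v_cache:  (num_kv_heads, seq_len, head_dim)
    out = torch.empty_like(q)
    for q_head in range(q.shape[0]):
        kv_head = q_head // kv_group_num                     # GQA head mapping
        scores = (k_cache[kv_head] @ q[q_head]) * sm_scale   # (seq_len,)
        probs = torch.softmax(scores, dim=-1)
        out[q_head] = probs @ v_cache[kv_head]               # (head_dim,)
    return out

Mapping query heads onto a smaller set of KV heads this way is what lets the KV cache and the rotary/flash-decoding kernels operate on num_kv_heads entries, which is the shape assumption the other fixes in this commit line up with.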