Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-10 05:20:33 +00:00)
[Fix/Inference] Fix GQA Triton and Support Llama3 (#5624)
* [fix] GQA calling of flash decoding triton
* fix kv cache alloc shape
* fix rotary triton - GQA
* fix sequence max length assigning
* Sequence max length logic
* fix scheduling and spec-dec
* skip without import error
* fix pytest - skip without ImportError

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
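The last two items concern test collection: the Triton-backed tests should be skipped, rather than fail with an ImportError, on machines where Triton is not installed. A minimal sketch of that pattern, with a hypothetical test name (the tests actually touched by this commit may guard differently):

    import pytest

    try:
        import triton  # noqa: F401

        HAS_TRITON = True
    except ImportError:
        HAS_TRITON = False


    @pytest.mark.skipif(not HAS_TRITON, reason="requires the triton package")
    def test_flash_decoding_with_gqa():
        # Hypothetical placeholder; the real tests exercise the Triton decoding kernels.
        ...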
@@ -447,9 +447,9 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
                 attn_qproj_w.dist_layout
             )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
         else:
-            self.q_proj_weight = attn_qproj_w
-            self.k_proj_weight = attn_kproj_w
-            self.v_proj_weight = attn_vproj_w
+            self.q_proj_weight = nn.Parameter(attn_qproj_w.transpose(0, 1).contiguous())
+            self.k_proj_weight = nn.Parameter(attn_kproj_w.transpose(0, 1).contiguous())
+            self.v_proj_weight = nn.Parameter(attn_vproj_w.transpose(0, 1).contiguous())
 
     @staticmethod
     def from_native_module(
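Reading the first hunk: on the GQA path (where the numbers of query and key/value heads differ, so the three projections cannot be fused into a single qkv weight) the projection weights are now stored transposed and wrapped as nn.Parameter, presumably so the forward pass can matmul activations against them directly instead of going through F.linear, which transposes internally. That reading is an editorial interpretation, not text from the commit; the sketch below only demonstrates the numerical equivalence, with hypothetical shapes:

    import torch
    import torch.nn.functional as F

    hidden = torch.randn(4, 128)          # (num_tokens, hidden_size), hypothetical shapes
    w = torch.randn(128, 128)             # weight as stored by nn.Linear: (out_features, in_features)

    w_t = w.transpose(0, 1).contiguous()  # pre-transposed copy, as in the hunk above
    out_linear = F.linear(hidden, w)      # nn.Linear-style projection: hidden @ w.T
    out_mm = torch.mm(hidden, w_t)        # plain matmul against the pre-transposed weight

    assert torch.allclose(out_linear, out_mm, atol=1e-5)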
@@ -638,6 +638,7 @@ class NopadLlamaAttention(ParallelModule, LlamaAttention):
             mid_output=fd_inter_tensor.mid_output,
             mid_output_lse=fd_inter_tensor.mid_output_lse,
             sm_scale=sm_scale,
+            kv_group_num=self.num_key_value_groups,
             q_len=q_len,
         )
 