[Inference] Fix flash-attn import and add model test (#5794)

* Fix torch int32 dtype

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Fix flash-attn import

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Add generalized model test

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Remove exposed path to model

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Add default value for use_flash_attn

Signed-off-by: char-1ee <xingjianli59@gmail.com>

* Rename model test

Signed-off-by: char-1ee <xingjianli59@gmail.com>

---------

Signed-off-by: char-1ee <xingjianli59@gmail.com>
Author: Li Xingjian
Committed by: GitHub
Date: 2024-06-12 14:13:50 +08:00
Commit: 8554585a5f (parent: aac941ef78)
7 changed files with 171 additions and 8 deletions


@@ -200,8 +200,6 @@ class NopadBaichuanAttention(ParallelModule):
        self.pre_attention_backend.decode(
            attn_metadata,
            cos=cos_sin[0],
            sin=cos_sin[1],
            q_len=q_len,
        )
        attn_output = self.attention_backend.decode(
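
For context, the cos/sin tensors passed into the decode call above are rotary position embedding caches. Below is a minimal sketch of how such a rotary embedding is typically applied to query/key states; the helper name and tensor shapes are assumptions for illustration, not ColossalAI's actual fused kernel.

    import torch

    def apply_rotary_embedding(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Rotate channel pairs by position-dependent angles (RoPE).

        Assumed shapes: x [num_tokens, num_heads, head_dim],
        cos/sin [num_tokens, head_dim // 2].
        """
        x1, x2 = x.chunk(2, dim=-1)  # split the head dimension in half
        cos = cos.unsqueeze(1)       # broadcast over the head axis
        sin = sin.unsqueeze(1)
        return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)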


@@ -114,7 +114,7 @@ def llama_model_forward(
    elif use_cuda_kernel:
        if can_use_flash_attn2(inputmetadata.dtype):
-            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
+            cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
        hidden_dim = self._cos_cached.size(-1)
        total_length = hidden_states.size(0)
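
The fixed line builds the cumulative sequence lengths in torch.int32, the dtype flash-attn's variable-length kernels expect for cu_seqlens. A minimal sketch of the guarded import plus the corrected construction follows; can_use_flash_attn2 is the helper named in the hunk above, while the rest is illustrative and assumes flash-attn 2.x as an optional dependency.

    import torch
    import torch.nn.functional as F

    try:
        from flash_attn import flash_attn_varlen_func  # optional dependency
        HAS_FLASH_ATTN = True
    except ImportError:
        HAS_FLASH_ATTN = False

    def can_use_flash_attn2(dtype: torch.dtype) -> bool:
        # flash-attn 2 only supports half-precision activations
        return HAS_FLASH_ATTN and dtype in (torch.float16, torch.bfloat16)

    def build_cu_seqlens(sequence_lengths: torch.Tensor) -> torch.Tensor:
        # Prefix sum of per-sequence token counts, left-padded with a leading 0,
        # e.g. [3, 5, 2] -> [0, 3, 8, 10]; kept in int32 as the kernel requires.
        return F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))

The resulting offsets are what flash_attn_varlen_func consumes as cu_seqlens_q / cu_seqlens_k, alongside the corresponding maximum sequence lengths, when attention runs over a packed (no-padding) batch.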
@@ -265,7 +265,7 @@ class NopadLlamaMLP(LlamaMLP, ParallelModule):
        mlp_dproj: ParallelModule = None,
        process_group: ProcessGroup = None,
    ):
-        """A Unified Layer for
+        """Replacement of LlamaMLP layer.
        Args:
            config (LlamaConfig): Holding the Llama model config.
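
The "Add generalized model test" and "Remove exposed path to model" changes are not excerpted above. As a purely hypothetical sketch of that pattern, one could read the checkpoint from an environment variable (so no local path is committed) and parametrize the test over prompts; the MODEL_PATH variable and the tiny fallback checkpoint are assumptions, not the test added in this PR.

    import os

    import pytest
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Hypothetical: CI can point MODEL_PATH at a private/local checkpoint without
    # exposing the path in the repository; default to a small public test model.
    MODEL_PATH = os.environ.get("MODEL_PATH", "hf-internal-testing/tiny-random-LlamaForCausalLM")

    @pytest.mark.parametrize("prompt", ["Introduce some landmarks in Beijing.", "What is deep learning?"])
    def test_model_generate(prompt: str):
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
        model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).eval()

        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=16, do_sample=False)

        # Only checks that generation runs and emits new tokens; a real test would
        # compare engine output against a reference baseline.
        assert outputs.shape[-1] > inputs["input_ids"].shape[-1]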