mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
[Inference] Fix flash-attn import and add model test (#5794)
* Fix torch int32 dtype Signed-off-by: char-1ee <xingjianli59@gmail.com> * Fix flash-attn import Signed-off-by: char-1ee <xingjianli59@gmail.com> * Add generalized model test Signed-off-by: char-1ee <xingjianli59@gmail.com> * Remove exposed path to model Signed-off-by: char-1ee <xingjianli59@gmail.com> * Add default value for use_flash_attn Signed-off-by: char-1ee <xingjianli59@gmail.com> * Rename model test Signed-off-by: char-1ee <xingjianli59@gmail.com> --------- Signed-off-by: char-1ee <xingjianli59@gmail.com>
This commit is contained in:
@@ -26,7 +26,7 @@ def prepare_data(
|
||||
num_tokens = torch.sum(context_lengths).item()
|
||||
|
||||
max_seq_len_in_batch = context_lengths.max()
|
||||
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
|
||||
cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.int32), (1, 0))
|
||||
|
||||
kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
|
||||
key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
|
||||
|
Reference in New Issue
Block a user