[Inference] Adapted to the triton attn kernels (#5264)

* adapted to the triton attn kernels

* fix pad input

* adapted to copy_kv_to_blocked_cache

* fix ci test

* update kv memcpy

* remove print
This commit is contained in:
yuehuayingxueluo
2024-01-17 16:03:10 +08:00
committed by GitHub
parent 0f2b46a41c
commit 86b63f720c
7 changed files with 221 additions and 101 deletions

View File

@@ -332,12 +332,20 @@ class BatchInfo:
return torch.tensor(len_list, dtype=torch.int, device=self.device)
def get_attn_mask(self, padding_id: int) -> "torch.Tensor | None":
    """Generate and return an attention mask for the batch, or None if unneeded.

    Builds one row per sequence from its full token history
    (``input_token_id + output_token_id``), marking positions whose token
    equals ``padding_id`` with 0 and every other position with 1.

    NOTE(review): ``torch.tensor`` requires a rectangular list-of-lists, so
    this assumes all sequences in the batch share the same total length
    (i.e. shorter ones are already padded) — confirm against callers.

    Args:
        padding_id: Token id used for padding; those positions are masked out.

    Returns:
        An int64 tensor of shape (num_sequences, total_len) with 0/1 entries,
        or ``None`` when the batch contains no padding at all, letting callers
        skip masking entirely.
    """
    past_values = [seq.input_token_id + seq.output_token_id for seq in self.sequences_set]
    attn_mask = torch.tensor(past_values, dtype=torch.int, device=self.device).ne(padding_id).long()
    # No zero entries means no padding anywhere: return None so the caller
    # can take the unmasked (faster) attention path.
    return attn_mask if torch.any(attn_mask == 0) else None
def __repr__(self) -> str:
    """Return a debug string showing the batch's sequence set and prompt flag."""
    return f"(sequences_set={self.sequences_set}, is_prompts={self.is_prompts})"