mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
[Inference]Adapted to the triton attn kernels (#5264)
* adapted to the triton attn kernels
* fix pad input
* adapted to copy_kv_to_blocked_cache
* fix ci test
* update kv memcpy
* remove print
This commit is contained in:
@@ -332,12 +332,20 @@ class BatchInfo:
|
||||
return torch.tensor(len_list, dtype=torch.int, device=self.device)
|
||||
|
||||
def get_attn_mask(self, padding_id: int) -> torch.Tensor:
    """
    Generate and return the attention mask for the batch.

    Each row is a sequence's prompt token ids followed by its generated
    token ids; positions whose token equals ``padding_id`` are masked to 0.

    Args:
        padding_id (int): Token id used as padding; those positions get 0.

    Returns:
        A ``(batch, seq_len)`` int64 tensor of 1s (attend) and 0s (padding),
        or ``None`` when no position is padded at all, so callers can skip
        masking entirely.  NOTE(review): the annotation says ``torch.Tensor``
        but ``None`` is a possible return — callers must handle it.
    """
    # Full token history per sequence: prompt ids + generated ids.
    # assumes all rows end up the same length (otherwise torch.tensor
    # raises) — presumably guaranteed by upstream padding; TODO confirm.
    past_values = [seq.input_token_id + seq.output_token_id for seq in self.sequences_set]

    # 1 where the token differs from padding_id, 0 at padded positions.
    attn_mask = torch.tensor(past_values, dtype=torch.int, device=self.device).ne(padding_id).long()

    # No padded position anywhere -> signal "no mask needed" with None.
    if torch.any(attn_mask == 0):
        return attn_mask
    return None
|
||||
|
||||
def __repr__(self) -> str:
    """Return a compact summary: the batch's sequence set and prompt flag."""
    return f"(sequences_set={self.sequences_set}, is_prompts={self.is_prompts})"
|
||||
|
Reference in New Issue
Block a user