fix bugs in attention.py and request_handler.py

This commit is contained in:
yuehuayingxueluo
2024-01-08 12:35:06 +08:00
committed by FrankLeeeee
parent bfd9b1b494
commit 47e53eaa1c
6 changed files with 208 additions and 60 deletions

View File

@@ -321,5 +321,13 @@ class BatchInfo:
return torch.tensor(len_list, dtype=torch.int, device=self.device)
def get_attn_mask(self, padding_id: int) -> torch.Tensor:
    """Build a 0/1 attention mask for every sequence in the batch.

    Each sequence's input and output token ids are concatenated; positions
    whose token id differs from ``padding_id`` become 1, padding positions
    become 0.

    Args:
        padding_id: Token id that marks padding.

    Returns:
        torch.Tensor: int64 mask of shape (num_sequences, total_seq_len)
        on ``self.device``.
    """
    # NOTE(review): assumes all sequences share the same total length —
    # torch.tensor would raise on ragged lists; confirm with callers.
    token_ids = [seq.input_token_id + seq.output_token_id for seq in self.sequences_set]
    batch_tokens = torch.tensor(token_ids, dtype=torch.int, device=self.device)
    return batch_tokens.ne(padding_id).long()
def __repr__(self) -> str:
    """Return a compact debug string showing the batch's sequences and mode."""
    return "(sequences_set={}, is_prompts={})".format(self.sequences_set, self.is_prompts)