[Inference] Add Nopadding Llama Modeling (#5327)

* add nopadding llama modeling

* add nopadding_llama.py

* rm unused code

* fix bugs in test_xine_copy.py

* fix code style
Author: yuehuayingxueluo
Date: 2024-01-30 10:31:46 +08:00
Committed by: GitHub
Parent: c7c104cb7c
Commit: e8f0642f28
9 changed files with 386 additions and 49 deletions


@@ -358,21 +358,16 @@ class BatchInfo:
         Flattening the input tokens.
         """
         input_list = []
-        input_len_list = []
         assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
         for seq in self.sequences_set:
             if self.is_prompts:
                 input_list.extend(seq.input_token_id)
-                input_len_list.append(seq.sentence_len)
             else:
                 input_list.append(seq.output_token_id[-1])
-                input_len_list.append(1)
-        return torch.tensor(input_list, dtype=torch.long, device=self.device), torch.tensor(
-            input_len_list, dtype=torch.int, device=self.device
-        )
+        return torch.tensor(input_list, dtype=torch.long, device=self.device)
 
     def get_sequence_lengths(self):
         """
@@ -401,7 +396,9 @@ class BatchInfo:
             past_values.append(seq.input_token_id + seq.output_token_id)
         max_seq_len = max(len(sub_list) for sub_list in past_values)
-        attn_mask = _make_tensor_with_pad(past_values, max_seq_len, 0, dtype=torch.int, device=self.device)
+        attn_mask = _make_tensor_with_pad(
+            past_values, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int, device=self.device
+        )
 
         return attn_mask.ne(padding_id).long()
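
A hedged sketch of why this hunk replaces the hardcoded 0 pad value with the sequence's pad_token_id: the mask is derived with .ne(pad), so if 0 is a legitimate token id, padding with 0 would wrongly mask real tokens. The helper and values below are simplified stand-ins, not the repo implementation.

import torch

def _make_tensor_with_pad(seqs, max_len, pad, dtype=torch.int, device="cpu"):
    # Right-pad each sequence of token ids to max_len with the given pad value.
    padded = [list(s) + [pad] * (max_len - len(s)) for s in seqs]
    return torch.tensor(padded, dtype=dtype, device=device)

pad_token_id = 2                      # hypothetical pad id taken from the tokenizer
past_values = [[0, 15, 27], [0, 8]]   # note: 0 is a legitimate token id here
max_seq_len = max(len(s) for s in past_values)

attn_mask = _make_tensor_with_pad(past_values, max_seq_len, pad_token_id)
print(attn_mask.ne(pad_token_id).long())
# tensor([[1, 1, 1],
#         [1, 1, 0]])  -> only the padded position is masked out;
# with a hardcoded pad of 0, the genuine token id 0 would also have been masked.
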