Add padding llama model

Author: yuehuayingxueluo
Date: 2023-12-25 14:07:43 +08:00
Committed by: FrankLeeeee
Parent: 0e616462a7
Commit: 86853a37d5
5 changed files with 262 additions and 11 deletions


@@ -183,13 +183,16 @@ class BatchInfo:
        return cls(sequences_set=sequences_set)

    def get_block_table_tensor(self) -> torch.Tensor:
        tensor_list = []
        for seq in self.sequences_set:
            block_table = seq.block_table
            assert block_table is not None, f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
            tensor_list.append(seq.block_table)
        assert tensor_list, "Batch has not been initialized yet. Please initialize batch first."
        block_table = torch.concat(tensor_list)
        return block_table

    def clear_batch(self) -> None:
        """
@@ -271,3 +274,38 @@ class BatchInfo:
        Get batch_size of this batch
        """
        return len(self.sequences_set)

    def get_batch_inputs(self) -> torch.LongTensor:
        """
        Get batch inputs for forward inference computation.
        """
        input_list = []
        for seq in self.sequences_set:
            if self.is_prompts:
                input_list.append(seq.input_token_id)
            else:
                input_list.append([seq.output_token_id[-1]])
        return torch.tensor(input_list, dtype=torch.long)
    def get_1D_inputs(self) -> torch.LongTensor:
        """
        Flatten the input tokens into a 1D tensor.
        """
        input_list = []
        for seq in self.sequences_set:
            if self.is_prompts:
                input_list.extend(seq.input_token_id)
            else:
                input_list.append(seq.output_token_id[-1])
        return torch.tensor(input_list, dtype=torch.long)
    def get_sequence_lengths(self) -> torch.Tensor:
        """
        Get the length of each sequence in this batch.
        """
        len_list = []
        for seq in self.sequences_set:
            len_list.append(seq.get_sentence_len())
        return torch.tensor(len_list, dtype=torch.int)
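
A quick usage sketch of the three new batch helpers, again with a hypothetical _Seq stand-in for the repo's Sequence class and assuming equal-length (padded) prompts: in the prefill stage get_batch_inputs returns the full [batch, prompt_len] prompt matrix, while in the decode stage it returns only the last generated token per sequence; get_1D_inputs flattens the tokens, and get_sequence_lengths reports per-sequence lengths.

import torch

# Hypothetical stand-in for the real Sequence class; only the fields the
# three helpers read are modeled here.
class _Seq:
    def __init__(self, input_token_id, output_token_id):
        self.input_token_id = input_token_id    # prompt token ids
        self.output_token_id = output_token_id  # tokens generated so far

    def get_sentence_len(self):
        return len(self.input_token_id) + len(self.output_token_id)

seqs = [_Seq([5, 6, 7], [11]), _Seq([8, 9, 10], [12])]

# Prefill stage (is_prompts=True): whole prompts, shape [batch, prompt_len].
prompt_inputs = torch.tensor([s.input_token_id for s in seqs], dtype=torch.long)
print(prompt_inputs.shape)  # torch.Size([2, 3])

# Decode stage (is_prompts=False): only the last generated token per sequence.
decode_inputs = torch.tensor([[s.output_token_id[-1]] for s in seqs], dtype=torch.long)
print(decode_inputs)  # tensor([[11], [12]])

# Flattened 1D form and per-sequence lengths, mirroring get_1D_inputs /
# get_sequence_lengths.
flat = torch.tensor([t for s in seqs for t in s.input_token_id], dtype=torch.long)
lengths = torch.tensor([s.get_sentence_len() for s in seqs], dtype=torch.int)
print(flat.shape, lengths)  # torch.Size([6]) tensor([4, 4], dtype=torch.int32)
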