mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
Add padding llama model
This commit is contained in:
committed by
FrankLeeeee
parent
0e616462a7
commit
86853a37d5
@@ -183,13 +183,16 @@ class BatchInfo:
|
||||
|
||||
return cls(sequences_set=sequences_set)
|
||||
|
||||
def get_block_table_tensor(self):
|
||||
def get_block_table_tensor(self) -> None:
|
||||
tesnor_list = []
|
||||
block_table = None
|
||||
for seq in self.sequences_set:
|
||||
block_table = seq.block_table
|
||||
assert block_table, f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
|
||||
tesnor_list.append(seq.block_table)
|
||||
return torch.concat(tesnor_list)
|
||||
assert tesnor_list, "Batch has not been initialized yet. Please initialize batch first."
|
||||
block_table = torch.concat(tesnor_list)
|
||||
return block_table
|
||||
|
||||
def clear_batch(self) -> None:
|
||||
"""
|
||||
@@ -271,3 +274,38 @@ class BatchInfo:
|
||||
Get batch_size of this batch
|
||||
"""
|
||||
return len(self.sequences_set)
|
||||
|
||||
def get_batch_inputs(self) -> torch.LongTensor:
|
||||
"""
|
||||
Get bacth inputs for forward inference computation.
|
||||
"""
|
||||
input_list = []
|
||||
|
||||
for seq in self.sequences_set:
|
||||
if self.is_prompts:
|
||||
input_list.append(seq.input_token_id)
|
||||
else:
|
||||
input_list.append([seq.output_token_id[-1]])
|
||||
|
||||
return torch.tensor(input_list, dtype=torch.long)
|
||||
|
||||
def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
|
||||
"""
|
||||
Flattening the input tokens.
|
||||
"""
|
||||
input_list = []
|
||||
for seq in self.sequences_set:
|
||||
if self.is_prompts:
|
||||
input_list.extend(seq.input_token_id)
|
||||
else:
|
||||
input_list.append(seq.output_token_id[-1])
|
||||
return torch.tensor(input_list, dtype=torch.long)
|
||||
|
||||
def get_sequence_lengths(self):
|
||||
"""
|
||||
Get the input_len of each sentence in this batch.
|
||||
"""
|
||||
len_list = []
|
||||
for seq in self.sequences_set:
|
||||
len_list.append(seq.get_sentence_len())
|
||||
return torch.tensor(len_list, dtype=torch.int)
|
||||
|
Reference in New Issue
Block a user