Mirror of https://github.com/hpcaitech/ColossalAI.git
[Inference]Add Nopadding Llama Modeling (#5327)
* add nopadding llama modeling
* add nopadding_llama.py
* rm unused codes
* fix bugs in test_xine_copy.py
* fix code style
@@ -358,21 +358,16 @@ class BatchInfo:
         Flattening the input tokens.
         """
         input_list = []
-        input_len_list = []
 
         assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
 
         for seq in self.sequences_set:
             if self.is_prompts:
                 input_list.extend(seq.input_token_id)
-                input_len_list.append(seq.sentence_len)
             else:
                 input_list.append(seq.output_token_id[-1])
-                input_len_list.append(1)
 
-        return torch.tensor(input_list, dtype=torch.long, device=self.device), torch.tensor(
-            input_len_list, dtype=torch.int, device=self.device
-        )
+        return torch.tensor(input_list, dtype=torch.long, device=self.device)
 
     def get_sequence_lengths(self):
         """
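A minimal sketch of the no-padding input layout that the updated get_1D_inputs above produces, for illustration only. The field names input_token_id and output_token_id and the prefill/decode branching come from the diff; the toy Seq class, the function name get_1d_inputs, and the sample values are hypothetical stand-ins, not ColossalAI code.

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class Seq:
    input_token_id: List[int]   # prompt token ids
    output_token_id: List[int]  # generated token ids


def get_1d_inputs(sequences, is_prompts: bool, device: str = "cpu") -> torch.Tensor:
    """Flatten a batch into one 1D tensor, with no padding between sequences."""
    input_list = []
    for seq in sequences:
        if is_prompts:
            # prefill: concatenate every prompt token of every sequence
            input_list.extend(seq.input_token_id)
        else:
            # decode: only the latest generated token of each sequence
            input_list.append(seq.output_token_id[-1])
    return torch.tensor(input_list, dtype=torch.long, device=device)


seqs = [Seq([1, 2, 3], [7]), Seq([4, 5], [8])]
print(get_1d_inputs(seqs, is_prompts=True))   # tensor([1, 2, 3, 4, 5])
print(get_1d_inputs(seqs, is_prompts=False))  # tensor([7, 8])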
@@ -401,7 +396,9 @@ class BatchInfo:
             past_values.append(seq.input_token_id + seq.output_token_id)
 
         max_seq_len = max(len(sub_list) for sub_list in past_values)
-        attn_mask = _make_tensor_with_pad(past_values, max_seq_len, 0, dtype=torch.int, device=self.device)
+        attn_mask = _make_tensor_with_pad(
+            past_values, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int, device=self.device
+        )
 
         return attn_mask.ne(padding_id).long()
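A minimal sketch of the mask-construction pattern in the hunk above: pad the per-sequence token lists to the batch max length with the pad token id, then compare against that id to get a 0/1 attention mask. The make_tensor_with_pad helper below is a hypothetical stand-in for _make_tensor_with_pad; only the argument order (values, max length, pad value, dtype, device) is taken from the call in the diff, and pad_token_id here plays the role of both the pad value and padding_id.

import torch


def make_tensor_with_pad(values, max_len, pad, dtype, device):
    # right-pad every sub-list with `pad` up to max_len, then stack into a 2D tensor
    padded = [row + [pad] * (max_len - len(row)) for row in values]
    return torch.tensor(padded, dtype=dtype, device=device)


pad_token_id = 2  # illustrative pad id, e.g. a tokenizer's pad token
past_values = [[5, 6, 7, 9], [5, 6]]
max_seq_len = max(len(row) for row in past_values)

attn_mask = make_tensor_with_pad(past_values, max_seq_len, pad_token_id, dtype=torch.int, device="cpu")
print(attn_mask.ne(pad_token_id).long())
# tensor([[1, 1, 1, 1],
#         [1, 1, 0, 0]])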