diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index 99d6b3b85..730a358cd 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -57,6 +57,9 @@ class RunningList:
     def is_empty(self):
         return not self.decoding and not self.prefill
 
+    def total_seq_num(self):
+        return len(self.decoding) + len(self.prefill)
+
 
 class RequestHandler:
     """
@@ -105,7 +108,13 @@ class RequestHandler:
                                 f"the prompt(Request id = {seq.request_id}) length is longer than max_input_len, abort this sequence."
                             )
                             self.abort_sequence(seq.request_id)
+                            remove_list.append(seq)
                             break
+
+                        # stop feeding new sequences into the running list, to ensure there are enough free cache blocks for the running sequences
+                        if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num():
+                            break
+
                         # Try to allocate cache blocks for the sequence.
                         if (
                             self.cache_manager.check_allocation(seq)
@@ -115,7 +124,7 @@ class RequestHandler:
                             # If succeed, add the sequence to running list.
                             remove_list.append(seq)
                             self.running_list.append(seq)
-                            self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len)
+                            self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.sentence_len)
                     for seq in remove_list:
                         lst.remove(seq)
         if self.running_list.ready_for_prefill():
@@ -126,7 +135,13 @@ class RequestHandler:
 
         if not self.running_batch.is_empty:
            for seq in self.running_batch.sequences_set:
-                self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
+                recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len)
+                if recycle:
+                    seq.recycle()
+                    self.running_batch.del_seq(seq)
+                    self.running_list.remove(seq)
+                    self.waiting_list[-1].append(seq)
+                    # the recycled sequences are handled with the highest priority.
 
         return self.running_batch
 
diff --git a/colossalai/inference/modeling/layers/attention.py b/colossalai/inference/modeling/layers/attention.py
index 41e50f40d..7fc9d1553 100644
--- a/colossalai/inference/modeling/layers/attention.py
+++ b/colossalai/inference/modeling/layers/attention.py
@@ -69,7 +69,7 @@ def convert_kvcache(cache, lengths, block_tables, pad_id=0):
         )
         padding = seq_len - _cache.size(0)
         if padding > 0:
-            _cache = F.pad(_cache, (0, 0, 0, 0, 0, 1), value=pad_id)
+            _cache = F.pad(_cache, (0, 0, 0, 0, 0, padding), value=pad_id)
         padded_cache.append(_cache)
     return torch.stack(padded_cache, dim=0)
 
diff --git a/colossalai/inference/modeling/models/llama.py b/colossalai/inference/modeling/models/llama.py
index bbdb2f407..f3cfb3860 100644
--- a/colossalai/inference/modeling/models/llama.py
+++ b/colossalai/inference/modeling/models/llama.py
@@ -173,7 +173,10 @@ def llama_attn_forward(
     key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
+    kv_seq_len = max(sequence_lengths).item()
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
     query_states = query_states.transpose(1, 2)
 
diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py
index 54560d046..05ab72bf4 100644
--- a/colossalai/inference/struct.py
+++ b/colossalai/inference/struct.py
@@ -29,6 +29,9 @@ class RequestStatus(enum.Enum):
     COMPLETED = enum.auto()
     LENGTH_CAPPED = enum.auto()
 
+    # recycle status
+    RECYCLED = enum.auto()
+
     @staticmethod
     def is_finished(status: "RequestStatus") -> bool:
         return status in [
@@ -119,7 +122,9 @@ class Sequence:
         """
         Set status for prefill reqs.
         """
-        assert self.status == RequestStatus.WAITING, "Sequence is not in WAITTING STATUS"
+        assert (
+            self.status == RequestStatus.WAITING or self.status == RequestStatus.RECYCLED
+        ), "Sequence is not in WAITING/RECYCLED STATUS"
         self.status = RequestStatus.RUNNING
 
     def mark_finished(self) -> None:
@@ -139,10 +144,10 @@ class Sequence:
         """
         Recycle a running sequnce to waiitting list
         """
         assert (
-            not self.status.is_finished and not self.status == RequestStatus.ABORTED
+            not self.check_finish() and not self.status == RequestStatus.ABORTED
         ), "The running sequence \
         is already done but it still in running list"
-        self.status = RequestStatus.WAITING
+        self.status = RequestStatus.RECYCLED
 
     def __repr__(self) -> str:
         return (
@@ -162,7 +167,7 @@ class BatchInfo:
     Information to be passed and used for a batch of sequences.
     """
 
-    sequences_set: OrderedSet["Sequence"] = None
+    sequences_set: OrderedSet[Sequence] = None
     is_prompts: bool = True
     device: torch.device = None
 
@@ -207,12 +212,20 @@ class BatchInfo:
 
     def clear_batch(self) -> None:
         """
-        Clear sequence set and block table.
+        Clear the sequence set and block table if we need to abort this batch.
+        Prefill: clear the sequence set; the sequences are moved to the running batch (handled externally).
+        Decoding: mark unfinished sequences as aborted.
""" - for seq in self.sequences_set: - if not seq.check_finish(): - seq.status = RequestStatus.ABORTED - self.sequences_set.clear() + if self.is_prompts: + self.sequences_set.clear() + + else: + for seq in self.sequences_set: + seq.mark_aborted() + if seq.check_finish(): + seq.mark_finished() + + self.sequences_set.clear() def fliter_batch(self) -> List["Sequence"]: """ @@ -255,6 +268,12 @@ class BatchInfo: continue self.sequences_set.add(seq) + def del_seq(self, seq: Sequence) -> Sequence: + """ + Delete sequence in batch + """ + self.sequences_set.discard(seq) + @property def is_empty(self) -> None: """ @@ -297,11 +316,19 @@ class BatchInfo: for seq in self.sequences_set: if self.is_prompts: - input_list.append(seq.input_token_id) + if seq.output_len > 0: + print(seq.output_token_id) + seq_data = seq.input_token_id + seq.output_token_id + print(seq_data) + input_list.append(seq.input_token_id + seq.output_token_id) + else: + input_list.append(seq.input_token_id) else: input_list.append([seq.output_token_id[-1]]) - return torch.tensor(input_list, dtype=torch.long, device=self.device) + max_seq_len = max(len(sub_list) for sub_list in input_list) + + return _make_tensor_with_pad(input_list, max_seq_len, 0, dtype=torch.int) def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]: """ @@ -340,12 +367,27 @@ class BatchInfo: for seq in self.sequences_set: past_values.append(seq.input_token_id + seq.output_token_id) - attn_mask = torch.tensor(past_values, dtype=torch.int, device=self.device).ne(padding_id).long() + max_seq_len = max(len(sub_list) for sub_list in past_values) + attn_mask = _make_tensor_with_pad(past_values, max_seq_len, 0, dtype=torch.int, device=self.device) - if torch.any(attn_mask == 0): - return attn_mask - else: - return None + return attn_mask.ne(padding_id).long() def __repr__(self) -> str: return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})" + + +def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: + assert len(x) <= max_len + return x + [pad] * (max_len - len(x)) + + +def _make_tensor_with_pad( + x: Union[List[List[int]], List[int]], + max_len: int, + pad: int, + dtype: torch.dtype, + device: Union[str, torch.device] = "cuda", + pin_memory: bool = False, +): + padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x] + return torch.tensor(padded_x, dtype=dtype, device=device, pin_memory=pin_memory and str(device) == "cpu") diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py index 2b3733c61..457546a7f 100644 --- a/examples/inference/benchmark_llama.py +++ b/examples/inference/benchmark_llama.py @@ -95,11 +95,10 @@ def benchmark_inference(args): if args.dtype == "fp16": model = model.half() - elif args.dtype == "bf16": + elif args.dtype == "fp16": model = model.to(torch.bfloat16) - # mbsz = args.mbsz - mbsz = args.batch_size + mbsz = args.mbsz if args.mode == "caiinference": inference_config = InferenceConfig( dtype=args.dtype, diff --git a/tests/test_infer/test_config_and_struct.py b/tests/test_infer/test_config_and_struct.py index a89776b6e..348cd5d21 100755 --- a/tests/test_infer/test_config_and_struct.py +++ b/tests/test_infer/test_config_and_struct.py @@ -2,7 +2,7 @@ import pytest import colossalai from colossalai.inference.config import InferenceConfig -from colossalai.inference.struct import BatchInfo, Sequence +from colossalai.inference.struct import BatchInfo, RequestStatus, Sequence from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -41,6 +41,10 @@ def 
         eos_token_id=2,
         max_output_len=256,
     )
+    sequence.mark_running()
+    assert sequence.status == RequestStatus.RUNNING
+    sequence.recycle()
+    assert sequence.status == RequestStatus.RECYCLED
 
     assert sequence.sentence_len == 3
     assert sequence.input_len == 3
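
For reference, the padding helpers added to colossalai/inference/struct.py above can be exercised in isolation. Below is a minimal, standalone sketch (not part of the patch) that mirrors _pad_to_max/_make_tensor_with_pad; the device and pin_memory handling are simplified here for illustration, whereas the helper in the diff defaults to "cuda".

# Standalone sketch (not part of the patch): mirrors the padding helpers added in
# colossalai/inference/struct.py; device defaults to "cpu" here for illustration only.
from typing import List, Union

import torch


def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    # Right-pad a token-id list to max_len with the pad id.
    assert len(x) <= max_len
    return x + [pad] * (max_len - len(x))


def _make_tensor_with_pad(
    x: List[List[int]],
    max_len: int,
    pad: int,
    dtype: torch.dtype,
    device: Union[str, torch.device] = "cpu",
):
    # Pad every sub-list to the same length, then build one rectangular tensor.
    padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
    return torch.tensor(padded_x, dtype=dtype, device=device)


if __name__ == "__main__":
    # Two prompts of different lengths become a single [2, 4] batch tensor, which is
    # what get_batch_inputs()/get_attn_mask() now rely on instead of building a tensor
    # from ragged lists.
    token_ids = [[1, 2, 3, 4], [5, 6]]
    batch = _make_tensor_with_pad(token_ids, max_len=4, pad=0, dtype=torch.long)
    print(batch)
    # tensor([[1, 2, 3, 4],
    #         [5, 6, 0, 0]])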