From 62fd08ee4425e031f8f1c43b25bf1ba5e7e33e8d Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Tue, 26 Dec 2023 21:34:27 +0800
Subject: [PATCH] Fixed a bug in the inference framework

---
 colossalai/inference/config.py                |   3 +
 colossalai/inference/core/engine.py           |  20 ++-
 colossalai/inference/core/request_handler.py  |  37 ++--
 .../inference/kv_cache/kvcache_manager.py     |   4 +-
 colossalai/inference/modeling/models/llama.py |  48 ++----
 colossalai/inference/modeling/policy/llama.py | 160 +++++++++++++++++-
 colossalai/inference/struct.py                |  66 +++++---
 tests/test_infer/test_inference_engine.py     |  13 +-
 8 files changed, 261 insertions(+), 90 deletions(-)

diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index c4adba82b..f88120965 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -97,3 +97,6 @@ class InferenceConfig:
         ], "dtype should be one of 'fp16', 'fp32', 'bf16', torch.float32, torch.float16, torch.bfloat16"
         assert self.max_batch_size <= 64, "Max batch size exceeds the constraint"
         assert self.quant_mode in ["smoothquant", "gptq", None], "quant should be one of 'smoothquant', 'gptq'"
+        assert (
+            self.max_input_len + self.max_output_len <= self.max_seq_len
+        ), "The sum of max_input_len and max_output_len must not exceed max_seq_len."
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index 7ac804c1c..0f6705157 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -49,6 +49,7 @@ class InferenceEngine:
             self.tokenizer.pad_token = self.tokenizer.eos_token
         self.inference_config = inference_config
         self.model_config = model.config
+        self.device = torch.device("cuda")

         if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
             self.dtype = torch.float32
@@ -76,6 +77,7 @@ class InferenceEngine:

         self.logger = get_dist_logger(__name__)
         self.request_handler = RequestHandler(self.inference_config, self.model_config)
+        self.k_cache, self.v_cache = self.request_handler.get_kvcache()
         self.counter = count()

     def _verify_config(self) -> None:
@@ -170,7 +172,11 @@ class InferenceEngine:

         if prompts_token_ids is None:
             assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
-            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts)["input_ids"]
+            prompts_token_ids = self.tokenizer.batch_encode_plus(prompts, padding=True)["input_ids"]
+
+        assert (
+            len(prompts_token_ids[0]) < self.inference_config.max_input_len
+        ), "The length of input prompts must be less than max_input_len."

         prompts_num = len(prompts_token_ids)

@@ -183,13 +189,14 @@ class InferenceEngine:
                 prompt = None
             else:
                 prompt = prompts[i]
+            block_table = torch.full([self.inference_config.max_seq_len], -1, device=self.device)
             sequence = Sequence(
                 request_id,
                 prompt,
                 prompts_token_ids[i],
                 block_size,
                 None,
-                None,
+                block_table,
                 self.tokenizer.eos_token_id,
                 self.inference_config.max_output_len,
             )
@@ -211,14 +218,15 @@ class InferenceEngine:
         self.logger.info("Running generation step")

         output_list = []
-        batch, k_cache, v_cache = self.request_handler.schedule()
+        batch = self.request_handler.schedule()

         logits = self.model(
             batch,
-            k_cache,
-            v_cache,
+            self.k_cache,
+            self.v_cache,
         )
-        self.request_handler.search_tokens(logits, self.generation_config)
+
+        self.request_handler.search_tokens(self.generation_config, logits)

         finished_sequences = self.request_handler.update()

diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index 585b430d4..3cc203470 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -5,7 +5,6 @@ from transformers.configuration_utils import PretrainedConfig

 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.kv_cache import KVCacheManager
-from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
 from colossalai.inference.struct import BatchInfo, Sequence

@@ -49,7 +48,7 @@ class RunningList:
     def ready_for_prefill(self):
         if not self.decoding:
             return len(self.prefill) > 0
-        return len(self.prefill) / len(self.decoding) >= self.ratio
+        return len(self.prefill) / len(self.decoding) >= self.prefill_ratio

     def is_empty(self):
         return not self.decoding and not self.prefill
@@ -72,8 +71,9 @@ class RequestHandler:
         self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
         self.waiting_list: List[List] = [[], [], []]
         self.done_list: List[Sequence] = []
-        self.running_batch = BatchInfo(is_prompts=False)
-        self.prefill_batch = BatchInfo(is_prompts=True)
+        device = torch.cuda.current_device()
+        self.running_batch = BatchInfo(is_prompts=False, device=device)
+        self.prefill_batch = BatchInfo(is_prompts=True, device=device)

     def _init_cache(self, model_config):
         self.cache_manager = KVCacheManager(self.inference_config, model_config)
@@ -81,6 +81,9 @@ class RequestHandler:
     def _has_waiting(self) -> bool:
         return any(lst for lst in self.waiting_list)

+    def get_kvcache(self):
+        return self.cache_manager.get_kv_cache()
+
     def schedule(self):
         """
         The main logic of request handler.
@@ -90,7 +93,7 @@
             for lst in reversed(self.waiting_list):
                 if lst:
                     for seq in lst:
-                        if seq.prompt_len > self.inference_config.max_input_len:
+                        if seq.input_len > self.inference_config.max_input_len:
                             # If the prompt length is longer than max_input_len, abort the sequence.
                             self.abort_sequence(seq.request_id)
                             break
@@ -98,9 +101,8 @@
                         if self.cache_manager.check_allocation(seq):
                             # If succeed, add the sequence to running list.
self.running_list.append(seq) - self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.prompt_len) - lst.remove(seq) - + self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len) + lst.clear() if self.running_list.ready_for_prefill(): for seq in self.running_list.prefill: seq.mark_running() @@ -115,10 +117,9 @@ class RequestHandler: """ assert not self._find_sequence(req.request_id), f"Sequence {req.request_id} already exists." assert ( - req.prompt_len < self.inference_config.max_input_len + req.input_len < self.inference_config.max_input_len ), f"Sequence {req.request_id} exceeds input length limit" - - self.waiting_list[req.prompt_len * 3 // self.inference_config.max_input_len].append(req) + self.waiting_list[req.input_len * 3 // self.inference_config.max_input_len].append(req) def abort_sequence(self, request_id: str): """ @@ -178,9 +179,12 @@ class RequestHandler: """ # do logit processor # NOTE: need to decide the granularity to process logits (sequence or batch) - for type in ["top_p", "top_k", "min_p"]: - if type in generation_config: - logits = logit_processor(type, logits) + # for type in ["top_p", "top_k", "min_p"]: + # config_dict = generation_config.to_dict() + # if type in config_dict: + # logits = logit_processor(type, logits, config_dict[type]) + + torch.cuda.synchronize() # calculate probs probs = torch.softmax(logits, dim=-1, dtype=torch.float) @@ -188,7 +192,10 @@ class RequestHandler: # sample the next tokens sample_tokens = self._sample(probs, logprobs, generation_config) - self.running_batch.update_batch_tokens(sample_tokens) + if not self.prefill_batch.is_empty: + self.prefill_batch.update_batch_tokens(sample_tokens) + else: + self.running_batch.update_batch_tokens(sample_tokens) def update(self): """ diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py index 50eac0854..1fee4958d 100644 --- a/colossalai/inference/kv_cache/kvcache_manager.py +++ b/colossalai/inference/kv_cache/kvcache_manager.py @@ -112,7 +112,7 @@ class KVCacheManager: def get_kv_cache(self): """Get k_cache and v_cache""" - return self._kv_cache[0], self._kv_cache[1] + return self._kv_caches[0], self._kv_caches[1] def get_max_blocks_per_sequence(self) -> int: """Get the maximum number of blocks that can be allocated for a single sequence.""" @@ -122,7 +122,7 @@ class KVCacheManager: return self.max_blocks_per_sequence def check_allocation(self, seq: Sequence) -> bool: - num_blocks_needed = (seq.prompt_len + self.max_output_length + self.block_size - 1) // self.block_size + num_blocks_needed = (seq.input_len + self.max_output_length + self.block_size - 1) // self.block_size return num_blocks_needed <= self.num_available_blocks def get_block_kv_ptrs(self, block_id: int, layer_id: int) -> Tuple[List[int], List[int]]: diff --git a/colossalai/inference/modeling/models/llama.py b/colossalai/inference/modeling/models/llama.py index 6c1d844d0..21d934f1c 100644 --- a/colossalai/inference/modeling/models/llama.py +++ b/colossalai/inference/modeling/models/llama.py @@ -70,7 +70,10 @@ def llama_model_forward( seq_length = input_ids.shape[1] device = input_ids.device - past_key_values_length = len(block_tables.shape[1]) + if batch.is_prompts: + past_key_values_length = 0 + else: + past_key_values_length = sequence_lengths[0].item() - 1 position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device @@ -163,26 +166,17 @@ def llama_attn_forward( 
key_states = key_states.view(-1, self.num_heads, self.head_dim) value_states = value_states.view(-1, self.num_heads, self.head_dim) - block_size = k_cache.shape[-1] + k_cache.shape[-1] - memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size) + # memcpy_to_block(key_states, value_states, k_cache, v_cache, block_tables, block_size, sequence_lengths) - if is_prompts: - attn_output = context_attention_unpadded( - query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size - ) - else: - attn_output = torch.empty(bsz, self.num_heads, self.head_dim) - decoding_attention( - query_states, - k_cache, - v_cache, - block_tables, - sequence_lengths, - attn_output, - block_tables.shape[1], - block_size, - ) + # if is_prompts: + # attn_output = context_attention_unpadded(query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size) + # else: + # attn_output = torch.empty(bsz, self.num_heads, self.head_dim) + # decoding_attention(query_states, k_cache, v_cache, block_tables, sequence_lengths, attn_output, block_tables.shape[1], block_size) + + attn_output = query_states attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -190,19 +184,3 @@ def llama_attn_forward( attn_output = self.o_proj(attn_output) return attn_output - - -def memcpy_to_block(key, value, k_cache, v_cache, block_tables, block_size): - block_table_list = block_tables.tolist() - batch_size, seq_len, num_heads, head_dim = key - - reshape_key = key.reshape(batch_size, seq_len, block_size, num_heads, head_dim).tensor.permute(0, 2, 3, 1) - reshape_value = value.reshape(batch_size, seq_len, block_size, num_heads, head_dim).tensor.permute(0, 2, 3, 1) - if seq_len == 1: - for i in range(batch_size): - k_cache[block_table_list[i][-1], :] = reshape_key[i] - v_cache[block_table_list[i][-1], :] = reshape_value[i] - else: - for i in range(batch_size): - k_cache[block_table_list[i], :] = reshape_key[i] - v_cache[block_table_list[i], :] = reshape_value[i] diff --git a/colossalai/inference/modeling/policy/llama.py b/colossalai/inference/modeling/policy/llama.py index f747eedef..6e4d074db 100644 --- a/colossalai/inference/modeling/policy/llama.py +++ b/colossalai/inference/modeling/policy/llama.py @@ -1,7 +1,165 @@ +from functools import partial + +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaFlashAttention2, + LlamaForCausalLM, + LlamaModel, + LlamaSdpaAttention, +) + +from colossalai.inference.modeling.models.llama import ( + llama_attn_forward, + llama_causal_lm_forward, + llama_decoder_layer_forward, + llama_model_forward, +) +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription + +# import colossalai from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy class LlamaModelInferPolicy(LlamaForCausalLMPolicy): - # The code here just for test and will be modified later. 
def __init__(self) -> None: super().__init__() + + def module_policy(self): + policy = super().module_policy() + decoder_attribute_replacement = { + "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "self_attn.num_key_value_heads": self.model.config.num_key_value_heads + // self.shard_config.tensor_parallel_size, + } + if self.shard_config.extra_kwargs.get("quant", None) == "gptq": + from colossalai.inference.quant.gptq.cai_gptq import ColCaiQuantLinear, RowCaiQuantLinear + + policy[LlamaDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=RowCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=ColCaiQuantLinear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=RowCaiQuantLinear, + kwargs={"split_num": 1}, + ), + ], + ) + + elif self.shard_config.extra_kwargs.get("quant", None) == "smoothquant": + from colossalai.inference.quant.smoothquant.models.llama import LlamaSmoothquantDecoderLayer + from colossalai.inference.quant.smoothquant.models.parallel_linear import ( + ColW8A8BFP32OFP32Linear, + RowW8A8B8O8Linear, + RowW8A8BFP32O32LinearSiLU, + RowW8A8BFP32OFP32Linear, + ) + + policy[LlamaSmoothquantDecoderLayer] = ModulePolicyDescription( + attribute_replacement=decoder_attribute_replacement, + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="self_attn.q_proj", + target_module=RowW8A8B8O8Linear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.k_proj", + target_module=RowW8A8B8O8Linear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.v_proj", + target_module=RowW8A8B8O8Linear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="self_attn.o_proj", + target_module=ColW8A8BFP32OFP32Linear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.gate_proj", + target_module=RowW8A8BFP32O32LinearSiLU, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.up_proj", + target_module=RowW8A8BFP32OFP32Linear, + kwargs={"split_num": 1}, + ), + SubModuleReplacementDescription( + suffix="mlp.down_proj", + target_module=ColW8A8BFP32OFP32Linear, + kwargs={"split_num": 1}, + ), + ], + ) + self.shard_config._infer() + + infer_forward = llama_causal_lm_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaForCausalLM + ) + + infer_forward = llama_model_forward + method_replacement = {"forward": partial(infer_forward)} + 
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=LlamaModel) + + infer_forward = llama_decoder_layer_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaDecoderLayer + ) + + infer_forward = llama_attn_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaAttention + ) + + infer_forward = llama_attn_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaFlashAttention2 + ) + + infer_forward = llama_attn_forward + method_replacement = {"forward": partial(infer_forward)} + self.append_or_create_method_replacement( + description=method_replacement, policy=policy, target_key=LlamaSdpaAttention + ) + + return policy diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index 3c616c6ce..6133008fe 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -1,6 +1,6 @@ import enum from dataclasses import dataclass -from typing import Any, List, Union +from typing import Any, List, Tuple, Union import torch from ordered_set import OrderedSet @@ -74,13 +74,6 @@ class Sequence: self.output_token_id = [] self.status = RequestStatus.WAITING - @property - def prompt_len(self) -> int: - """ - Get length of prompts - """ - return len(self.input_token_id) - @property def sentence_len(self) -> int: """ @@ -113,7 +106,7 @@ class Sequence: return True if self.output_token_id: - if self.output_token_id[-1] == self.eos_token_id or len(self.output_token_id) == self.max_output_len: + if self.output_token_id[-1] >= self.eos_token_id or len(self.output_token_id) == self.max_output_len: self.status = RequestStatus.COMPLETED return True @@ -143,11 +136,13 @@ class Sequence: def __repr__(self) -> str: return ( - f"Request ID(request_id={self.request_id}, " + f"(request_id={self.request_id}, " f"prompt={self.prompt}, " f"status={self.status.name}, " f"sample_params={self.sample_params}, " - f"logical block number={len(self.block_table_index)}" + f"logical_block_number={self.block_table.shape[0]}," + f"input_len={self.input_len})," + f"output_len={self.output_len})" ) @@ -159,9 +154,15 @@ class BatchInfo: sequences_set: OrderedSet["Sequence"] = None is_prompts: bool = True + device: torch.device = None - @classmethod - def init_batch(cls, seqs: List["Sequence"] = None) -> "BatchInfo": + def __post_init__(self): + if self.device is None: + self.device = torch.cuda.current_device() + if self.sequences_set is None: + self.sequences_set = OrderedSet() + + def init_batch(self, seqs: List["Sequence"] = None): """ Initializes inference batches by input sentence list. @@ -169,29 +170,29 @@ class BatchInfo: seqs (List["Sequence"]): List of input sequence. """ - sequences_set = OrderedSet() + assert len(self.sequences_set) == 0, "Sequences set has been initialized." 
if seqs is not None: if not isinstance(seqs, list): seqs = [seqs] for seq in seqs: - if seq in sequences_set: + if seq in self.sequences_set: logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.") continue - sequences_set.add(seq) - - return cls(sequences_set=sequences_set) + self.sequences_set.add(seq) def get_block_table_tensor(self) -> None: tesnor_list = [] block_table = None for seq in self.sequences_set: block_table = seq.block_table - assert block_table, f"The sequence(request_id {seq.request_id}) has not initialized the block_table." + assert ( + block_table is not None + ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table." tesnor_list.append(seq.block_table) assert tesnor_list, "Batch has not been initialized yet. Please initialize batch first." - block_table = torch.concat(tesnor_list) + block_table = torch.stack(tesnor_list) return block_table def clear_batch(self) -> None: @@ -239,7 +240,7 @@ class BatchInfo: seqs = [seqs] for seq in seqs: - if seq in self.sequences_set: + if self.sequences_set and seq in self.sequences_set: logger.warning(f"The sequence(request_id {seq.request_id}) is already in sequences_set.") continue self.sequences_set.add(seq) @@ -251,7 +252,7 @@ class BatchInfo: """ return not self.sequences_set - def update_batch_tokens(self, tokens: Union[List[int], List[List[int]]]) -> None: + def update_batch_tokens(self, tokens: Union[List[int], List[List[int]], torch.Tensor]) -> None: """ Add an output token for each sentence in the batch. @@ -259,6 +260,9 @@ class BatchInfo: tokens (List[int]): A batch of tokens """ + if isinstance(tokens, torch.Tensor): + tokens = tokens.tolist() + assert self.get_batch_size() == len(tokens), "The number of tokens does not match batch_size." for seq, token in zip(self.sequences_set, tokens): @@ -287,19 +291,25 @@ class BatchInfo: else: input_list.append([seq.output_token_id[-1]]) - return torch.tensor(input_list, dtype=torch.long) + return torch.tensor(input_list, dtype=torch.long, device=self.device) def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]: """ Flattening the input tokens. 
""" input_list = [] + input_len_list = [] for seq in self.sequences_set: if self.is_prompts: input_list.extend(seq.input_token_id) + input_len_list.append(seq.sentence_len) else: input_list.append(seq.output_token_id[-1]) - return torch.tensor(input_list, dtype=torch.long) + input_len_list.append(1) + + return torch.tensor(input_list, dtype=torch.long, device=self.device), torch.tensor( + input_len_list, dtype=torch.int, device=device + ) def get_sequence_lengths(self): """ @@ -307,5 +317,9 @@ class BatchInfo: """ len_list = [] for seq in self.sequences_set: - len_list.append(seq.get_sentence_len()) - return torch.tensor(len_list, dtype=torch.int) + len_list.append(seq.sentence_len) + + return torch.tensor(len_list, dtype=torch.int, device=self.device) + + def __repr__(self) -> str: + return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})" diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py index ce7eec588..26c9d5f96 100755 --- a/tests/test_infer/test_inference_engine.py +++ b/tests/test_infer/test_inference_engine.py @@ -1,6 +1,6 @@ import pytest import transformers -from transformers import AutoTokenizer +from transformers import AutoTokenizer, GenerationConfig import colossalai from colossalai.inference.config import InferenceConfig @@ -11,21 +11,24 @@ from colossalai.testing import spawn def check_inference_engine(): model = transformers.LlamaForCausalLM( transformers.LlamaConfig( - vocab_size=20000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4 + vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=4 ) ) tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") - inference_config = InferenceConfig() + inference_config = InferenceConfig(max_output_len=5) inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) inputs = [ - "介绍一下北京", + "介绍一下今天的北京", "介绍一下武汉", ] inference_engine.add_request(prompts=inputs) assert inference_engine.request_handler._has_waiting() - # outputs = inference_engine.generate(None) + generation_config = GenerationConfig(top_k=2, top_p=0.8, do_sample=True) + outputs = inference_engine.generate(generation_config) + + print("outputs: ", outputs) # Engine still gets some bug