[Inference] Add the logic of the inference engine (#5173)

* add infer_struct and infer_config

* update code

* change InferConfig

* Add hf_model_config to the engine

* rm _get_hf_model_config

* update code

* make adjustments according to reviewer feedback

* update code

* add ci test for config and struct

* Add the logic of the inference engine

* update engine and test

* Recover cache_manager.py

* add logger

* fix conflict

* update code

* update code

* update model and tokenizer

* add the shardformer logic

* change kvcache_manager docstring

* add policy

* fix ci bug in test_kvcache_manager.py

* remove code related to tokenizer and move model_policy

* fix code style

* add ordered_set to requirements-infer.txt

* Delete extra empty lines

* add ordered_set to requirements-test.txt
This commit is contained in:
yuehuayingxueluo
2023-12-18 10:40:47 +08:00
committed by FrankLeeeee
parent 93aeacca34
commit 8daee26989
13 changed files with 555 additions and 172 deletions

View File

@@ -1,65 +1,232 @@
from logging import Logger
from typing import Optional
from itertools import count
from typing import List, Optional, Union
from transformers import AutoConfig
import torch
import torch.nn as nn
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
from colossalai.cluster import ProcessGroupMesh
from colossalai.inference.config import InferenceConfig
from colossalai.inference.modeling.policy import model_policy_map
from colossalai.inference.struct import Sequence
from colossalai.logging import get_dist_logger
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
from .request_handler import RequestHandler
PP_AXIS, TP_AXIS = 0, 1
_supported_models = [
"LlamaForCausalLM",
]
class InferenceEngine:
"""
InferenceEngine is the core component for Inference.
It is responsible for launching the inference process, including:
- Initialize model and distributed training environment (if needed)
- Launch request_handler and corresponding kv cache manager
- Receive requests and generate texts.
- Log the generation process
"""
InferenceEngine which manages the inference process.
Args:
tokenizer: Path of the tokenizer to use.
inference_config: We provide a unified config API that wraps all the configs. You can use it to replace the configs below.
model (nn.Module): The nn.Module instance of the model to run inference on.
tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Path of the tokenizer to use.
inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
verbose (bool): Determine whether or not to log the generation process.
model_policy ("Policy"): The policy used to shard the model with ShardFormer. It will be determined by the model type if not provided.
"""
def __init__(
self,
tokenizer: str = None,
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
inference_config: Optional["InferenceConfig"] = None,
verbose: bool = False,
model_policy: Policy = None,
) -> None:
assert inference_config, "Please provide inference_config."
self._init_model()
# cache_config may need to be modified later.
# self.request_handler = RequestHandler(cache_config)
self.tokenizer = tokenizer
self.hf_model_config = AutoConfig.from_pretrained(
self.model, trust_remote_code=self.trust_remote_code, revision=self.revision
self.inference_config = inference_config
self.model_config = model.config
if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
self.dtype = torch.float32
elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
self.dtype = torch.float16
model.half()
else:
self.dtype = torch.bfloat16
model.to(torch.bfloat16)
if model_policy is None:
model_policy = model_policy_map[self.model_config.model_type]()
pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
self.model = self._shardformer(
model,
model_policy,
None,
pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
)
self.verbose = verbose
if verbose:
self.logger = Logger()
self.logger = get_dist_logger(__name__)
def _init_model(self):
self.request_handler = RequestHandler(self.inference_config, self.model_config)
self.counter = count()
def _verify_config(self) -> None:
"""
Initialize model and distributed training environment (if needed).
May need to provide two different initialization methods:
1. User-defined (from local path)
2. Load from a checkpoint (Hugging Face)
Verify the input config
"""
if not isinstance(self.model, nn.Module):
raise TypeError(f"the model type must be nn.Module, but got {type(self.model)}")
if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
self.tokenizer, PreTrainedTokenizer
):
raise TypeError(
f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
)
assert (
self.model.__class__.__name__ in _supported_models
), f"Model {self.model.__class__.__name__} is not supported."
def _shardformer(
self,
model: nn.Module,
model_policy: Policy,
stage_manager: PipelineStageManager = None,
tp_group: ProcessGroupMesh = None,
) -> nn.Module:
"""
Initialize ShardConfig and replace the model with shardformer.
Args:
model (nn.Module): The model to be sharded.
model_policy (Policy): The policy used to shard the model, determined by the model type.
stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
tp_group (ProcessGroupMesh, optional): Used to manage the process TP group mesh. Defaults to None.
Returns:
nn.Module: The model optimized by ShardFormer.
"""
shardconfig = ShardConfig(
tensor_parallel_process_group=tp_group,
pipeline_stage_manager=stage_manager,
enable_tensor_parallelism=(self.inference_config.tp_size > 1),
enable_fused_normalization=False,
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
extra_kwargs={"quant": self.inference_config.quant_mode},
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model.cuda()
def generate(
self,
generation_config: GenerationConfig = None,
) -> List[str]:
"""
Execute the inference step.
Args:
generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
Returns:
List[str]: Inference result returned by one generation.
"""
def _verify_config(self):
self.generation_config = generation_config
output_list = []
while self.request_handler.check_unfinished_seqs():
output_list += self.step()
return output_list
def add_request(
self,
requests_id: List[int] = None,
prompts: List[str] = None,
prompts_token_ids: List[List[int]] = None,
) -> None:
"""
Verify the configuration to avoid potential bugs.
Add requests.
Args:
requests_id (List[int], optional): The request IDs. Defaults to None.
prompts (List[str], optional): Input prompts. Defaults to None.
prompts_token_ids (List[List[int]], optional): Token IDs of the input prompts. Defaults to None.
"""
def generate(self):
pass
block_size = self.inference_config.block_size
def step(self):
if prompts_token_ids is None:
assert prompts, "When prompts_token_ids is None, the input prompt list must be provided."
prompts_token_ids = []
for prompt in prompts:
prompts_token_ids.append(self.tokenizer.encode(prompt))
prompts_num = len(prompts_token_ids)
for i in range(prompts_num):
if requests_id:
request_id = requests_id[i]
else:
request_id = next(self.counter)
if prompts is None:
prompt = None
else:
prompt = prompts[i]
sequence = Sequence(
request_id,
prompt,
prompts_token_ids[i],
block_size,
None,
None,
self.tokenizer.eos_token_id,
self.inference_config.max_output_len,
)
self.request_handler.add_sequence(sequence)
def step(self) -> List[str]:
"""
In each step, do the following:
1. Run request_handler to update the kv cache and running input_ids
1. Run RequestHandler.schedule() and get the batch used for inference.
2. Run model to generate the next token
3. Check whether there are finished requests and decode
3. Update waiting list and running list in RequestHandler and get finished sequences.
4. Decode and return finished sequences.
Returns:
List[str]: Decoded finished sequences generated by one step.
"""
if self.verbose:
self.logger.info("Running generation step")
output_list = []
self.request_handler.schedule()
# Uncomment if the development of RequestHandler is completed.
# logits = self.model(batch)
# self.request_handler.search_tokens(logits, self.generation_config)
finished_sequences = self.request_handler.update()
# Decode completed sentences.
for seq in finished_sequences:
if seq.prompt:
output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
output_list.append(seq.prompt + output_str)
else:
output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
output_list.append(output_str)
return output_list
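
Taken together, the new engine keeps a small public surface: build it from a model, a tokenizer, and an InferenceConfig, queue prompts with add_request(), and drain them with generate(), which loops over step() until the request handler reports no unfinished sequences. A minimal usage sketch of that flow follows; the InferenceConfig keyword arguments and the engine's import path are assumptions for illustration, not taken from this diff.

# Hypothetical usage sketch; InferenceConfig kwargs and the engine import path are assumed.
from transformers import AutoTokenizer, LlamaForCausalLM

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine  # assumed module path

# Placeholder checkpoint path; any LlamaForCausalLM checkpoint should work,
# since "LlamaForCausalLM" is the only entry in _supported_models here.
model = LlamaForCausalLM.from_pretrained("path/to/llama")
tokenizer = AutoTokenizer.from_pretrained("path/to/llama")

# Assumed keyword arguments, chosen to match the attributes the engine reads
# (dtype, max_output_len, block_size, tp_size, pp_size, quant_mode).
inference_config = InferenceConfig(dtype="fp16", max_output_len=64)

engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
engine.add_request(prompts=["Introduce some landmarks in Beijing"])
outputs = engine.generate()  # runs step() until no unfinished sequences remain
print(outputs)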

View File

@@ -1,5 +1,7 @@
from typing import List
from colossalai.inference.struct import BatchInfo, Sequence
class RequestHandler:
"""
@@ -7,14 +9,17 @@ class RequestHandler:
During generation process, we call schedule function each iteration to update current batch.
Args:
cache_config: Configuration for initialize and manage kv cache.
inference_config: Store the configuration information related to inference.
model_config: The huggingface model config.
"""
def __init__(self, cache_config) -> None:
self.cache_config = cache_config
def __init__(self, inference_config, model_config) -> None:
self.inference_config = inference_config
self.model_config = model_config
self._init_cache()
self.waiting_list: List["Reqseq"] = []
self.running_list: List["Reqseq"] = []
self.waiting_list: List["Sequence"] = []
self.running_list: List["Sequence"] = []
self.batch = BatchInfo.init_batch()
def _init_cache(self):
"""
@@ -25,12 +30,17 @@ class RequestHandler:
"""
The main logic of request handler.
"""
# The code below is only used for testing engine and will be modified.
if self.waiting_list:
self.running_list = self.waiting_list
self.batch.add_seqs(self.running_list)
return self.batch
def add_sequence(self, reqseq: "Reqseq"):
def add_sequence(self, req_seq: "Sequence"):
"""
Add the request to waiting list.
"""
self.waiting_list.append(reqseq)
self.waiting_list.append(req_seq)
def abort_sequence(self, seq_id: str):
"""
@@ -39,10 +49,23 @@ class RequestHandler:
self._find_sequence(seq_id)
return
def _find_sequence(self, seq_id: str) -> "Reqseq":
def _find_sequence(self, seq_id: str) -> "Sequence":
"""
Find the request by seq_id.
"""
def check_unfinished_seqs(self) -> bool:
return self.waiting_list or self.running_list
return len(self.waiting_list) != 0 or len(self.running_list) != 0
def update(self):
"""
Update the waiting list and running list.
"""
# The code below is only used for testing engine and will be modified.
self.waiting_list = []
self.running_list = []
finished_sequences = list(self.batch.sequences_set)
self.batch.clear_batch()
return finished_sequences
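
For readers tracing the control flow, the placeholder scheduling in this revision reduces to: schedule() promotes everything on the waiting list to the running list and folds it into the batch, and update() clears both lists and reports every batched sequence as finished. The self-contained toy sketch below mirrors that hand-off with stand-in classes (not the real Sequence or BatchInfo types) so the schedule()/update() cycle is easy to follow.

# Toy illustration of the simplified schedule()/update() flow shown above.
# The classes are stand-ins, not the actual colossalai implementations.
from typing import List


class ToyBatch:
    def __init__(self):
        self.sequences_set = set()

    def add_seqs(self, seqs: List[str]):
        self.sequences_set.update(seqs)

    def clear_batch(self):
        self.sequences_set.clear()


class ToyRequestHandler:
    def __init__(self):
        self.waiting_list: List[str] = []
        self.running_list: List[str] = []
        self.batch = ToyBatch()

    def add_sequence(self, seq: str):
        self.waiting_list.append(seq)

    def schedule(self):
        # Testing-phase logic: promote every waiting sequence at once.
        if self.waiting_list:
            self.running_list = self.waiting_list
            self.batch.add_seqs(self.running_list)
        return self.batch

    def update(self):
        # Testing-phase logic: every batched sequence is treated as finished.
        self.waiting_list, self.running_list = [], []
        finished = list(self.batch.sequences_set)
        self.batch.clear_batch()
        return finished


handler = ToyRequestHandler()
handler.add_sequence("request-0")
handler.schedule()
print(handler.update())  # ['request-0']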