Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-15 22:19:38 +00:00)
[Inference] Add the logic of the inference engine (#5173)
* add infer_struct and infer_config
* update codes
* change InferConfig
* Add hf_model_config to the engine
* rm _get_hf_model_config
* update codes
* made adjustments according to the feedback from the reviewer.
* update codes
* add ci test for config and struct
* Add the logic of the inference engine
* update engine and test
* Recover cache_manager.py
* add logger
* fix conflict
* update codes
* update codes
* update model and tokenizer
* fix add the logic about shardformer
* change kvcache_manager docstring
* add policy
* fix ci bug in test_kvcache_manager.py
* remove codes related to tokenizer and move model_policy
* fix code style
* add ordered_set to requirements-infer.txt
* Delete extra empty lines
* add ordered_set to requirements-test.txt
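For orientation, here is a minimal usage sketch of the API this commit adds; it is not part of the commit itself. It assumes a Llama checkpoint (the diff's `_supported_models` only lists `LlamaForCausalLM`) and assumes `InferenceConfig` accepts the fields the engine reads below (`dtype`, `max_output_len`, ...); the exact constructor arguments, module path, and distributed-launch call depend on the ColossalAI version on this branch.

# Hypothetical usage sketch for the InferenceEngine added in this PR.
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine  # module path assumed

# NOTE: ProcessGroupMesh requires an initialized distributed environment,
# e.g. via colossalai.launch_from_torch(...); omitted here for brevity.

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")  # example checkpoint
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Field names assumed from the attributes the engine reads (dtype, block_size, max_output_len, ...).
inference_config = InferenceConfig(dtype="fp16", max_output_len=64)

engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)

# Prompts are tokenized, wrapped into Sequence objects, and queued in the RequestHandler.
engine.add_request(prompts=["Introduce some landmarks in Beijing"])

# generate() keeps calling step() until the RequestHandler reports no unfinished sequences.
outputs = engine.generate()
print(outputs)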
Committed by FrankLeeeee
Parent: 93aeacca34
Commit: 8daee26989
colossalai/inference/core/engine.py
@@ -1,65 +1,232 @@
-from logging import Logger
-from typing import Optional
+from itertools import count
+from typing import List, Optional, Union

-from transformers import AutoConfig
+import torch
+import torch.nn as nn
+from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast

+from colossalai.cluster import ProcessGroupMesh
 from colossalai.inference.config import InferenceConfig
+from colossalai.inference.modeling.policy import model_policy_map
+from colossalai.inference.struct import Sequence
+from colossalai.logging import get_dist_logger
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer.policies.base_policy import Policy

 from .request_handler import RequestHandler

+PP_AXIS, TP_AXIS = 0, 1
+
+_supported_models = [
+    "LlamaForCausalLM",
+]
+

 class InferenceEngine:

     """
-    InferenceEngine is the core component for Inference.
-
-    It is responsible for launch the inference process, including:
-        - Initialize model and distributed training environment(if needed)
-        - Launch request_handler and corresponding kv cache manager
-        - Receive requests and generate texts.
-        - Log the generation process
+    InferenceEngine which manages the inference process.

     Args:
-        tokenizer: Path of the tokenizer to use.
-        inference_config: We provide a unified config api for that wrapped all the configs. You can use it to replace the below configs.
+        model (nn.Module): Path or nn.Module of this model.
+        tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Path of the tokenizer to use.
+        inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
         verbose (bool): Determine whether or not to log the generation process.
+        model_policy ("Policy"): the policy to shardformer model. It will be determined by the model type if not provided.
     """

     def __init__(
         self,
-        tokenizer: str = None,
+        model: nn.Module,
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         inference_config: Optional["InferenceConfig"] = None,
         verbose: bool = False,
+        model_policy: Policy = None,
     ) -> None:
         assert inference_config, "Please provide inference_config."

-        self._init_model()
-        # cache_config may need to be modified later.
-        # self.request_handler = RequestHandler(cache_config)
         self.tokenizer = tokenizer
-        self.hf_model_config = AutoConfig.from_pretrained(
-            self.model, trust_remote_code=self.trust_remote_code, revision=self.revision
-        )
+        self.inference_config = inference_config
+        self.model_config = model.config
+
+        if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
+            self.dtype = torch.float32
+        elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
+            self.dtype = torch.float16
+            model.half()
+        else:
+            self.dtype = torch.bfloat16
+            model.to(torch.bfloat16)
+
+        if model_policy is None:
+            model_policy = model_policy_map[self.model_config.model_type]()
+
+        pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
+
+        self.model = self._shardformer(
+            model,
+            model_policy,
+            None,
+            pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
+        )

+        self.verbose = verbose
         if verbose:
-            self.logger = Logger()
+            self.logger = get_dist_logger(__name__)
+
+        self.request_handler = RequestHandler(self.inference_config, self.model_config)
+        self.counter = count()

-    def _init_model(self):
-        """
-        Initialize model and distributed training environment(if needed).
-        May need to provide two different initialization methods:
-            1. user-defined (from local path)
-            2. load from checkpoint (Hugging Face)
-        """
-
-    def _verify_config(self):
-        """
-        Verify the configuration to avoid potential bugs.
-        """
-
+    def _verify_config(self) -> None:
+        """
+        Verify the input config.
+        """
+        if not isinstance(self.model, nn.Module):
+            raise TypeError(f"the model type must be nn.Module, but get {type(self.model)}")
+        if not isinstance(self.tokenizer, PreTrainedTokenizerFast) and not isinstance(
+            self.tokenizer, PreTrainedTokenizer
+        ):
+            raise TypeError(
+                f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but get {type(self.tokenizer)}"
+            )
+        assert (
+            self.model.__class__.__name__ in _supported_models
+        ), f"Model {self.model.__class__.__name__} is not supported."
+
+    def _shardformer(
+        self,
+        model: nn.Module,
+        model_policy: Policy,
+        stage_manager: PipelineStageManager = None,
+        tp_group: ProcessGroupMesh = None,
+    ) -> nn.Module:
+        """
+        Initialize ShardConfig and replace the model with shardformer.
+
+        Args:
+            model (nn.Module): Path or nn.Module of this model.
+            model_policy (Policy): The policy to shardformer model which is determined by the model type.
+            stage_manager (PipelineStageManager, optional): Used to manage pipeline stages. Defaults to None.
+            tp_group (ProcessGroupMesh, optional): Used to manage the TP process group mesh. Defaults to None.
+
+        Returns:
+            nn.Module: The model optimized by ShardFormer.
+        """
+        shardconfig = ShardConfig(
+            tensor_parallel_process_group=tp_group,
+            pipeline_stage_manager=stage_manager,
+            enable_tensor_parallelism=(self.inference_config.tp_size > 1),
+            enable_fused_normalization=False,
+            enable_all_optimization=False,
+            enable_flash_attention=False,
+            enable_jit_fused=False,
+            enable_sequence_parallelism=False,
+            extra_kwargs={"quant": self.inference_config.quant_mode},
+        )
+        shardformer = ShardFormer(shard_config=shardconfig)
+        shard_model, _ = shardformer.optimize(model, model_policy)
+        return shard_model.cuda()

+    def generate(
+        self,
+        generation_config: GenerationConfig = None,
+    ) -> List[str]:
+        """
+        Executing the inference step.
+
+        Args:
+            generation_config (GenerationConfig, optional): Huggingface GenerationConfig used for inference. Defaults to None.
+
+        Returns:
+            List[str]: Inference result returned by one generation.
+        """
+
+        self.generation_config = generation_config
+
+        output_list = []
+
+        while self.request_handler.check_unfinished_seqs():
+            output_list += self.step()
+
+        return output_list

-    def generate(self):
-        pass
-
+    def add_request(
+        self,
+        requests_id: List[int] = None,
+        prompts: List[str] = None,
+        prompts_token_ids: List[int] = None,
+    ) -> None:
+        """
+        Add requests.
+
+        Args:
+            requests_id (List[int], optional): The request ID. Defaults to None.
+            prompts (List[str], optional): Input prompts. Defaults to None.
+            prompts_token_ids (List[List[int]], optional): token ids of input prompts. Defaults to None.
+        """
+
+        block_size = self.inference_config.block_size
+
+        if prompts_token_ids is None:
+            assert prompts, "When the prompts_token_ids is none, the input prompt list must be provided."
+            prompts_token_ids = []
+            for prompt in prompts:
+                prompts_token_ids.append(self.tokenizer.encode(prompt))
+
+        prompts_num = len(prompts_token_ids)
+
+        for i in range(prompts_num):
+            if requests_id:
+                request_id = requests_id[i]
+            else:
+                request_id = next(self.counter)
+            if prompts == None:
+                prompt = None
+            else:
+                prompt = prompts[i]
+            sequence = Sequence(
+                request_id,
+                prompt,
+                prompts_token_ids[i],
+                block_size,
+                None,
+                None,
+                self.tokenizer.eos_token_id,
+                self.inference_config.max_output_len,
+            )
+            self.request_handler.add_sequence(sequence)

-    def step(self):
+    def step(self) -> List[str]:
         """
         In each step, do the follows:
-            1. Run request_handler to update the kv cache and running input_ids
+            1. Run RequestHandler.schedule() and get the batch used for inference.
             2. Run model to generate the next token
-            3. Check whether there is finied request and decode
+            3. Update waiting list and running list in RequestHandler and get finished sequences.
+            4. Decode and return finished sequences.
+
+        Returns:
+            List[str]: Decoded finished sequences generated by one step.
         """
+
+        if self.verbose:
+            self.logger.info("Running generation step")
+
+        output_list = []
+        self.request_handler.schedule()
+
+        # Uncomment if the development of RequestHandler is completed.
+        # logits = self.model(batch)
+        # self.request_handler.search_tokens(logits, self.generation_config)
+
+        finished_sequences = self.request_handler.update()
+
+        # Decode completed sentences.
+        for seq in finished_sequences:
+            if seq.prompt:
+                output_str = self.tokenizer.decode(seq.output_token_id, skip_special_tokens=True)
+                output_list.append(seq.prompt + output_str)
+            else:
+                output_str = self.tokenizer.decode(seq.input_token_id + seq.output_token_id, skip_special_tokens=True)
+                output_list.append(output_str)
+
+        return output_list
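The `_shardformer` helper above is what makes tensor parallelism available at inference time: it reuses ShardFormer with the training-oriented optimizations switched off and only enables tensor parallelism when `tp_size > 1`. A condensed sketch of that same call, with a hypothetical `shard_for_inference` helper and the policy and process group assumed to be prepared by the caller:

# Condensed sketch of the sharding step performed by InferenceEngine._shardformer above.
# `policy` and `tp_process_group` are assumed to be prepared by the caller.
import torch.nn as nn

from colossalai.shardformer import ShardConfig, ShardFormer


def shard_for_inference(model: nn.Module, policy, tp_size: int, tp_process_group=None) -> nn.Module:
    shard_config = ShardConfig(
        tensor_parallel_process_group=tp_process_group,
        enable_tensor_parallelism=(tp_size > 1),
        # Training-time fusions and optimizations stay off for inference, mirroring the diff above.
        enable_fused_normalization=False,
        enable_all_optimization=False,
    )
    shard_model, _ = ShardFormer(shard_config=shard_config).optimize(model, policy)
    return shard_model.cuda()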
colossalai/inference/core/request_handler.py
@@ -1,5 +1,7 @@
 from typing import List

+from colossalai.inference.struct import BatchInfo, Sequence
+

 class RequestHandler:
     """
@@ -7,14 +9,17 @@ class RequestHandler:
     During generation process, we call schedule function each iteration to update current batch.

     Args:
-        cache_config: Configuration for initialize and manage kv cache.
+        inference_config: Store the configuration information related to inference.
+        model_config: The huggingface model config.
     """

-    def __init__(self, cache_config) -> None:
-        self.cache_config = cache_config
+    def __init__(self, inference_config, model_config) -> None:
+        self.inference_config = inference_config
+        self.model_config = model_config
         self._init_cache()
-        self.waiting_list: List["Reqseq"] = []
-        self.running_list: List["Reqseq"] = []
+        self.waiting_list: List["Sequence"] = []
+        self.running_list: List["Sequence"] = []
+        self.batch = BatchInfo.init_batch()

     def _init_cache(self):
         """
@@ -25,12 +30,17 @@ class RequestHandler:
         """
         The main logic of request handler.
         """
+        # The code below is only used for testing engine and will be modified.
+        if self.waiting_list:
+            self.running_list = self.waiting_list
+            self.batch.add_seqs(self.running_list)
+        return self.batch

-    def add_sequence(self, reqseq: "Reqseq"):
+    def add_sequence(self, req_seq: "Sequence"):
         """
         Add the request to waiting list.
         """
-        self.waiting_list.append(reqseq)
+        self.waiting_list.append(req_seq)

     def abort_sequence(self, seq_id: str):
         """
@@ -39,10 +49,23 @@ class RequestHandler:
         self._find_sequence(seq_id)
         return

-    def _find_sequence(self, seq_id: str) -> "Reqseq":
+    def _find_sequence(self, seq_id: str) -> "Sequence":
         """
         Find the request by seq_id.
         """

     def check_unfinished_seqs(self) -> bool:
-        return self.waiting_list or self.running_list
+        return len(self.waiting_list) != 0 or len(self.running_list) != 0
+
+    def update(self):
+        """
+        Update the waiting list and running list.
+        """
+
+        # The code below is only used for testing engine and will be modified.
+        self.waiting_list = []
+        self.running_list = []
+        finished_sequences = list(self.batch.sequences_set)
+
+        self.batch.clear_batch()
+        return finished_sequences
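Putting the two files together, the per-step flow is: the engine queues `Sequence` objects on the handler, `schedule()` promotes the waiting list into the running batch, and `update()` drains the batch and hands back finished sequences for decoding. A minimal sketch of that flow follows; the import path and the `inference_config`, `model_config`, and `sequence` objects are assumed to exist, and the model forward plus token search are still stubbed out at this point in the branch.

# Sketch of the RequestHandler lifecycle driven by InferenceEngine.step().
from colossalai.inference.core.request_handler import RequestHandler  # module path assumed

handler = RequestHandler(inference_config, model_config)

handler.add_sequence(sequence)      # what engine.add_request() does for each prompt
batch = handler.schedule()          # moves waiting sequences into the running batch

# ... the model forward and RequestHandler.search_tokens() would run here once implemented ...

finished = handler.update()         # clears the batch and returns finished sequences
for seq in finished:
    print(seq.output_token_id)      # the engine decodes these token ids with the tokenizer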