ColossalAI/colossalai/inference/core/engine.py
yuehuayingxueluo fab9b931d9 [Inference]Add BatchInferState, Sequence and InferConfig (#5149)
* add infer_struct and infer_config

* update codes

* change InferConfig

* Add hf_model_config to the engine

* rm _get_hf_model_config

* update codes

* made adjustments according to the feedback from the reviewer.

* update codes

* add ci test for config and struct
2024-01-11 13:39:29 +00:00


from logging import Logger
from typing import Optional

from transformers import AutoConfig

from .config import InferenceConfig


class InferenceEngine:
    """
    InferenceEngine is the core component for inference.

    It is responsible for launching the inference process, including:
        - Initializing the model and the distributed environment (if needed)
        - Launching the request_handler and the corresponding kv cache manager
        - Receiving requests and generating texts
        - Logging the generation process

    Args:
        tokenizer: Path of the tokenizer to use.
        inference_config: A unified config that wraps all the individual configs; it can be passed instead of specifying each config separately.
        verbose (bool): Whether or not to log the generation process.
    """

    def __init__(
        self,
        tokenizer: Optional[str] = None,
        inference_config: Optional["InferenceConfig"] = None,
        verbose: bool = False,
    ) -> None:
        assert inference_config, "Please provide inference_config."

        self.inference_config = inference_config
        self._init_model()
        # cache_config may need to be modified later.
        # self.request_handler = RequestHandler(cache_config)
        self.tokenizer = tokenizer
        self.hf_model_config = AutoConfig.from_pretrained(
            self.model, trust_remote_code=self.trust_remote_code, revision=self.revision
        )
        if verbose:
            self.logger = Logger(__name__)

    def _init_model(self):
        """
        Initialize the model and the distributed environment (if needed).
        May need to provide two different initialization methods:
            1. User-defined model (from a local path)
            2. Load from a checkpoint (Hugging Face)
        """
    def _verify_config(self):
        """
        Verify the configuration to avoid potential bugs.
        """
    def generate(self):
        pass

    def step(self):
        """
        In each step, do the following:
            1. Run the request_handler to update the kv cache and the running input_ids
            2. Run the model to generate the next token
            3. Check whether there are finished requests and decode them
        """