ColossalAI/colossalai/legacy/inference/dynamic_batching/ray_init_config.py
Xu Kai fd6482ad8c
[inference] Refactor inference architecture (#5057)
* [inference] support only TP (#4998)

* support only tp

* enable tp

* add support for bloom (#5008)

* [refactor] refactor gptq and smoothquant llama (#5012)

* refactor gptq and smoothquant llama

* fix import error

* fix linear import torch-int

* fix smoothquant llama import error

* fix import accelerate error

* fix bug

* fix import smooth cuda

* fix smoothcuda

* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)

merge chatglm with pp and tp

* [Refactor] remove useless inference code (#5022)

* remove useless code

* fix quant model

* fix test import bug

* mv original inference legacy

* fix chatglm2

* [Refactor] refactor policy search and quant type controlling in inference (#5035)

* [Refactor] refactor policy search and quant type controlling in inference

* [inference] update readme (#5051)

* update readme

* update readme

* fix architecture

* fix table

* fix table

* [inference] update example (#5053)

* update example

* fix run.sh

* fix rebase bug

* fix some errors

* update readme

* add some features

* update interface

* update readme

* update benchmark

* add requirements-infer

---------

Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
2023-11-19 21:05:05 +08:00

59 lines
1.6 KiB
Python

import logging

import yaml
from pydantic import BaseModel

logger = logging.getLogger(__name__)


class EngineArgsClass(BaseModel):
    """Config for Engine"""

    model: str
    tensor_parallel_size: int = 2
    max_batch_size: int = 4
    max_input_len: int = 128
    max_output_len: int = 32


class RooterArgsClass(BaseModel):
    """Config for Rooter"""

    max_total_token_num: int = 42
    batch_max_tokens: int = 42
    eos_id: int = 0
    disable_log_stats: bool = False
    log_stats_interval: int = 10
    model: str


class RayInitConfig(BaseModel):
    """All-together configs without app router config"""

    engine_config_data: EngineArgsClass
    router_config_data: RooterArgsClass

    @classmethod
    def from_yaml_path(cls, path: str):
        try:
            with open(path, "r") as yaml_file:
                try:
                    config = yaml.safe_load(yaml_file)
                    # serve deployment config
                    engine_config = config.get("engine_config", {})
                    router_config = config.get("router_config", {})
                    return cls(
                        engine_config_data=engine_config,
                        router_config_data=router_config,
                    )
                except yaml.YAMLError as e:
                    logger.error(f"An Error occurred when parsing yaml: {e}")
                    raise
        except FileNotFoundError:
            logger.error(f"The file '{path}' does not exist!")
            raise
        except OSError as e:
            logger.error(f"An Error occurred: {e}")
            raise
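
`from_yaml_path` reads two top-level mappings from the YAML file, `engine_config` and `router_config`, and each must supply `model` because that field has no default. The sketch below illustrates that mapping. To stay runnable with the standard library alone, it makes two substitutions that are assumptions rather than this file's code: plain dataclasses (`EngineArgs` and `RouterArgs`, both hypothetical names) stand in for the pydantic models, and a dict literal stands in for the result of `yaml.safe_load`. The model path and the overridden values are placeholders.

```python
from dataclasses import dataclass

# Equivalent YAML layout (hypothetical values):
#   engine_config:
#     model: /path/to/model
#     tensor_parallel_size: 1
#   router_config:
#     model: /path/to/model
#     max_total_token_num: 4096
parsed = {
    "engine_config": {"model": "/path/to/model", "tensor_parallel_size": 1},
    "router_config": {"model": "/path/to/model", "max_total_token_num": 4096},
}


@dataclass
class EngineArgs:  # stand-in for EngineArgsClass
    model: str  # required: no default
    tensor_parallel_size: int = 2
    max_batch_size: int = 4
    max_input_len: int = 128
    max_output_len: int = 32


@dataclass
class RouterArgs:  # stand-in for RooterArgsClass
    model: str  # required; listed first because dataclasses need non-default fields first
    max_total_token_num: int = 42
    batch_max_tokens: int = 42
    eos_id: int = 0
    disable_log_stats: bool = False
    log_stats_interval: int = 10


def build_configs(config: dict):
    """Mirror the section lookup in from_yaml_path: a missing section becomes {}."""
    engine = EngineArgs(**config.get("engine_config", {}))
    router = RouterArgs(**config.get("router_config", {}))
    return engine, router


engine, router = build_configs(parsed)
# First value is overridden by the config, second falls back to the default.
print(engine.tensor_parallel_size, engine.max_batch_size)
print(router.max_total_token_num, router.batch_max_tokens)
```

Because a missing section falls back to `{}`, leaving out `engine_config` or `router_config` means the required `model` field is absent. In this stand-in that raises `TypeError`; the pydantic models would raise a validation error instead.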