Merge remote-tracking branch 'origin/main' into feat_rag_graph
@@ -194,6 +194,8 @@ class Config(metaclass=Singleton):
         ### LLM Model Service Configuration
         self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
+        self.LLM_MODEL_PATH = os.getenv("LLM_MODEL_PATH")
+
         ### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"
         ### When we use the rest API provided by deployment frameworks like fastchat as a proxyllm, "PROXYLLM_BACKEND" is the model they actually deploy.
         ### We need to use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
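To make the PROXYLLM_BACKEND comment concrete, here is a small illustration; the model names and environment values below are examples, not part of the commit:

import os

# With LLM_MODEL=proxyllm, PROXYLLM_BACKEND names the model actually served
# behind the REST API (e.g. by fastchat), and prompt templates are selected
# for that backend model rather than for the literal name "proxyllm".
os.environ["LLM_MODEL"] = "proxyllm"
os.environ["PROXYLLM_BACKEND"] = "vicuna-13b-v1.5"

llm_model = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
backend = os.getenv("PROXYLLM_BACKEND")
prompt_model = backend if llm_model == "proxyllm" and backend else llm_model
print(prompt_model)  # vicuna-13b-v1.5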
@@ -13,7 +13,7 @@ from pilot.utils.api_utils import (
     _api_remote as api_remote,
     _sync_api_remote as sync_api_remote,
 )
-from pilot.utils.utils import setup_logging
+from pilot.utils.utils import setup_logging, setup_http_service_logging
 
 logger = logging.getLogger(__name__)
 
@@ -149,6 +149,7 @@ def initialize_controller(
     else:
         import uvicorn
 
+        setup_http_service_logging()
         app = FastAPI()
         app.include_router(router, prefix="/api", tags=["Model"])
         uvicorn.run(app, host=host, port=port, log_level="info")
@@ -179,7 +180,8 @@ def run_model_controller():
     parser = EnvArgumentParser()
     env_prefix = "controller_"
     controller_params: ModelControllerParameters = parser.parse_args_into_dataclass(
-        ModelControllerParameters, env_prefix=env_prefix
+        ModelControllerParameters,
+        env_prefixes=[env_prefix],
     )
 
     setup_logging(
@@ -76,7 +76,7 @@ class DefaultModelWorker(ModelWorker):
         model_type = self.llm_adapter.model_type()
         model_params: ModelParameters = model_args.parse_args_into_dataclass(
             param_cls,
-            env_prefix=env_prefix,
+            env_prefixes=[env_prefix, "LLM_"],
             command_args=command_args,
             model_name=self.model_name,
             model_path=self.model_path,
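Passing env_prefixes=[env_prefix, "LLM_"] means a worker first honors model-specific variables and then shared LLM_* ones. A minimal sketch of that lookup order (the variable names are illustrative; the real resolution is done by the helper added in parameter_utils further down):

import os

os.environ["LLM_MAX_GPU_MEMORY"] = "20GiB"               # shared default
os.environ["VICUNA_13B_V1_5_MAX_GPU_MEMORY"] = "24GiB"   # per-model override

value = None
for prefix in ["vicuna_13b_v1_5_", "LLM_", ""]:  # the new lookup order
    value = os.getenv((prefix + "max_gpu_memory").upper())
    if value:
        break
print(value)  # 24GiB -- the model-specific prefix wins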
@@ -106,7 +106,7 @@ def _parse_embedding_params(
     env_prefix = EnvArgumentParser.get_env_prefix(model_name)
     model_params: BaseEmbeddingModelParameters = model_args.parse_args_into_dataclass(
         param_cls,
-        env_prefix=env_prefix,
+        env_prefixes=[env_prefix],
         command_args=command_args,
         model_name=model_name,
         model_path=model_path,
@@ -38,7 +38,7 @@ from pilot.utils.parameter_utils import (
     _dict_to_command_args,
     _get_dict_from_obj,
 )
-from pilot.utils.utils import setup_logging
+from pilot.utils.utils import setup_logging, setup_http_service_logging
 from pilot.utils.tracer import initialize_tracer, root_tracer, SpanType, SpanTypeRunName
 from pilot.utils.system_utils import get_system_info
 
@@ -735,6 +735,8 @@ def _setup_fastapi(
 ):
     if not app:
         app = FastAPI()
+    setup_http_service_logging()
 
     if worker_params.standalone:
-        from pilot.model.cluster.controller.controller import initialize_controller
+        from pilot.model.cluster.controller.controller import (
@@ -781,7 +783,7 @@ def _parse_worker_params(
     env_prefix = EnvArgumentParser.get_env_prefix(model_name)
     worker_params: ModelWorkerParameters = worker_args.parse_args_into_dataclass(
         ModelWorkerParameters,
-        env_prefix=env_prefix,
+        env_prefixes=[env_prefix],
         model_name=model_name,
         model_path=model_path,
         **kwargs,
@@ -790,7 +792,7 @@ def _parse_worker_params(
     # Read parameters agein with prefix of model name.
     new_worker_params = worker_args.parse_args_into_dataclass(
         ModelWorkerParameters,
-        env_prefix=env_prefix,
+        env_prefixes=[env_prefix],
         model_name=worker_params.model_name,
         model_path=worker_params.model_path,
         **kwargs,
@@ -95,7 +95,7 @@ class ModelLoader:
         env_prefix = env_prefix.replace("-", "_")
         model_params = args_parser.parse_args_into_dataclass(
             param_cls,
-            env_prefix=env_prefix,
+            env_prefixes=[env_prefix],
             device=self.device,
             model_path=self.model_path,
             model_name=self.model_name,
@@ -445,17 +445,47 @@ class VLLMModelAdaperWrapper(LLMModelAdaper):
 
 # Covering the configuration of fastcaht, we will regularly feedback the code here to fastchat.
 # We also recommend that you modify it directly in the fastchat repository.
 
-register_conv_template(
-    Conversation(
-        name="internlm-chat",
-        system_message="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
-        roles=("<|User|>", "<|Bot|>"),
-        sep_style=SeparatorStyle.CHATINTERN,
-        sep="<eoh>",
-        sep2="<eoa>",
-        stop_token_ids=[1, 103028],
-        # TODO feedback stop_str to fastchat
-        stop_str="<eoa>",
-    ),
-    override=True,
-)
+# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212
+register_conv_template(
+    Conversation(
+        name="aquila-legacy",
+        system_message="A chat between a curious human and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
+        roles=("### Human: ", "### Assistant: ", "System"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.NO_COLON_TWO,
+        sep="\n",
+        sep2="</s>",
+        stop_str=["</s>", "[UNK]"],
+    )
+)
+# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
+register_conv_template(
+    Conversation(
+        name="aquila",
+        system_message="A chat between a curious human and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the human's questions.",
+        roles=("Human", "Assistant", "System"),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.ADD_COLON_TWO,
+        sep="###",
+        sep2="</s>",
+        stop_str=["</s>", "[UNK]"],
+    )
+)
+# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
+register_conv_template(
+    Conversation(
+        name="aquila-v1",
+        roles=("<|startofpiece|>", "<|endofpiece|>", ""),
+        messages=(),
+        offset=0,
+        sep_style=SeparatorStyle.NO_COLON_TWO,
+        sep="",
+        sep2="</s>",
+        stop_str=["</s>", "<|endoftext|>"],
+    )
+)
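As a usage sketch, this is roughly how one of the new templates renders a prompt. It assumes fschat is installed and that the "aquila" template is registered, either by importing this module or by a fastchat version that ships it:

from fastchat.conversation import get_conv_template

conv = get_conv_template("aquila")
conv.append_message(conv.roles[0], "What is DB-GPT?")
conv.append_message(conv.roles[1], None)
# With ADD_COLON_TWO, sep="###" and sep2="</s>", this renders roughly:
#   <system_message>###Human: What is DB-GPT?###Assistant:
print(conv.get_prompt())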
@@ -5,8 +5,6 @@ import os
 from typing import List
 import logging
 
-import openai
-
 from pilot.model.proxy.llms.proxy_model import ProxyModel
 from pilot.model.parameter import ProxyModelParameters
 from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
@@ -15,6 +13,14 @@ logger = logging.getLogger(__name__)
 
 
 def _initialize_openai(params: ProxyModelParameters):
+    try:
+        import openai
+    except ImportError as exc:
+        raise ValueError(
+            "Could not import python package: openai "
+            "Please install openai by command `pip install openai` "
+        ) from exc
+
     api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
 
     api_base = params.proxy_api_base or os.getenv(
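Moving import openai out of module scope lets the proxy module be imported even when the package is missing; the dependency is only demanded when a proxy model is actually initialized. The same pattern in isolation (the function name here is illustrative):

def _require_openai():
    # Import lazily so the package is only needed when this code path runs.
    try:
        import openai
    except ImportError as exc:
        raise ValueError(
            "Could not import python package: openai "
            "Please install openai by command `pip install openai` "
        ) from exc
    return openai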
@@ -106,6 +112,8 @@ def _build_request(model: ProxyModel, params):
 def chatgpt_generate_stream(
     model: ProxyModel, tokenizer, params, device, context_len=2048
 ):
+    import openai
+
     history, payloads = _build_request(model, params)
 
     res = openai.ChatCompletion.create(messages=history, **payloads)
@@ -121,6 +129,8 @@ def chatgpt_generate_stream(
 async def async_chatgpt_generate_stream(
     model: ProxyModel, tokenizer, params, device, context_len=2048
 ):
+    import openai
+
     history, payloads = _build_request(model, params)
 
     res = await openai.ChatCompletion.acreate(messages=history, **payloads)
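For context, a sketch of how such a stream is consumed with the pre-1.0 openai API used here; the model name and payload are assumptions:

import asyncio
import openai  # openai<1.0, matching the ChatCompletion API above

async def main():
    res = await openai.ChatCompletion.acreate(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    text = ""
    async for chunk in res:
        delta = chunk["choices"][0]["delta"]
        text += delta.get("content", "")
    print(text)

asyncio.run(main())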
@@ -2,6 +2,7 @@ import os
 import argparse
 import sys
 from typing import List
+import logging
 
 ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 sys.path.append(ROOT_PATH)
@@ -39,6 +40,7 @@ from pilot.utils.utils import (
     setup_logging,
     _get_logging_level,
     logging_str_to_uvicorn_level,
+    setup_http_service_logging,
 )
 from pilot.utils.tracer import root_tracer, initialize_tracer, SpanType, SpanTypeRunName
 from pilot.utils.parameter_utils import _get_dict_from_obj
@@ -127,6 +129,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
     setup_logging(
         "pilot", logging_level=param.log_level, logger_filename=param.log_file
     )
 
     # Before start
     system_app.before_start()
 
@@ -141,7 +144,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
 
     model_name = param.model_name or CFG.LLM_MODEL
 
-    model_path = LLM_MODEL_CONFIG.get(model_name)
+    model_path = CFG.LLM_MODEL_PATH or LLM_MODEL_CONFIG.get(model_name)
     if not param.light:
         print("Model Unified Deployment Mode!")
         if not param.remote_embedding:
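The new fallback means an explicit LLM_MODEL_PATH always wins over the name-based lookup. In isolation, with an illustrative path table:

import os

LLM_MODEL_CONFIG = {"vicuna-13b-v1.5": "/app/models/vicuna-13b-v1.5"}  # example entry

model_name = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
# If LLM_MODEL_PATH is set, it overrides the path derived from the model name.
model_path = os.getenv("LLM_MODEL_PATH") or LLM_MODEL_CONFIG.get(model_name)
print(model_path)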
@@ -180,6 +183,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
 def run_uvicorn(param: WebWerverParameters):
     import uvicorn
 
+    setup_http_service_logging()
     uvicorn.run(
         app,
         host=param.host,
@@ -190,6 +190,17 @@ def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_valu
     )
 
 
+def _genenv_ignoring_key_case_with_prefixes(
+    env_key: str, env_prefixes: List[str] = None, default_value=None
+) -> str:
+    if env_prefixes:
+        for env_prefix in env_prefixes:
+            env_var_value = _genenv_ignoring_key_case(env_key, env_prefix)
+            if env_var_value:
+                return env_var_value
+    return _genenv_ignoring_key_case(env_key, default_value=default_value)
+
+
 class EnvArgumentParser:
     @staticmethod
     def get_env_prefix(env_key: str) -> str:
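A quick check of the resolution order. The stand-in for _genenv_ignoring_key_case below assumes it performs a case-insensitive os.getenv with an optional prefix, which is what its call sites suggest:

import os
from typing import List

def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None):
    # Assumed behavior: prefix the key if given, then try upper- and lower-case.
    if env_prefix:
        env_key = env_prefix + env_key
    return os.getenv(env_key.upper(), os.getenv(env_key.lower(), default_value))

def _genenv_ignoring_key_case_with_prefixes(
    env_key: str, env_prefixes: List[str] = None, default_value=None
) -> str:
    if env_prefixes:
        for env_prefix in env_prefixes:
            env_var_value = _genenv_ignoring_key_case(env_key, env_prefix)
            if env_var_value:
                return env_var_value
    return _genenv_ignoring_key_case(env_key, default_value=default_value)

os.environ["CONTROLLER_PORT"] = "8000"
os.environ["PORT"] = "9000"
# The controller_ prefix is tried first; the bare PORT is only a fallback.
print(_genenv_ignoring_key_case_with_prefixes("port", ["controller_"]))  # 8000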
@@ -201,18 +212,16 @@ class EnvArgumentParser:
     def parse_args_into_dataclass(
         self,
         dataclass_type: Type,
-        env_prefix: str = None,
+        env_prefixes: List[str] = None,
         command_args: List[str] = None,
         **kwargs,
     ) -> Any:
         """Parse parameters from environment variables and command lines and populate them into data class"""
         parser = argparse.ArgumentParser()
         for field in fields(dataclass_type):
-            env_var_value = _genenv_ignoring_key_case(field.name, env_prefix)
-            if not env_var_value:
-                # Read without env prefix
-                env_var_value = _genenv_ignoring_key_case(field.name)
-
+            env_var_value = _genenv_ignoring_key_case_with_prefixes(
+                field.name, env_prefixes
+            )
             if env_var_value:
                 env_var_value = env_var_value.strip()
                 if field.type is int or field.type == Optional[int]:
@@ -3,6 +3,8 @@
 
 import logging
+import logging.handlers
+from typing import Any, List
 
 import os
 import sys
 import asyncio
@@ -186,3 +188,42 @@ def logging_str_to_uvicorn_level(log_level_str):
         "NOTSET": "info",
     }
     return level_str_mapping.get(log_level_str.upper(), "info")
+
+
+class EndpointFilter(logging.Filter):
+    """Disable access log on certain endpoint
+
+    source: https://github.com/encode/starlette/issues/864#issuecomment-1254987630
+    """
+
+    def __init__(
+        self,
+        path: str,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        super().__init__(*args, **kwargs)
+        self._path = path
+
+    def filter(self, record: logging.LogRecord) -> bool:
+        return record.getMessage().find(self._path) == -1
+
+
+def setup_http_service_logging(exclude_paths: List[str] = None):
+    """Setup http service logging
+
+    Now just disable some logs
+
+    Args:
+        exclude_paths (List[str]): The paths to disable log
+    """
+    if not exclude_paths:
+        # Not show heartbeat log
+        exclude_paths = ["/api/controller/heartbeat"]
+    uvicorn_logger = logging.getLogger("uvicorn.access")
+    if uvicorn_logger:
+        for path in exclude_paths:
+            uvicorn_logger.addFilter(EndpointFilter(path=path))
+    httpx_logger = logging.getLogger("httpx")
+    if httpx_logger:
+        httpx_logger.setLevel(logging.WARNING)
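A usage sketch of the filter on its own; the path is illustrative:

import logging

class EndpointFilter(logging.Filter):
    # Same idea as above: drop access-log records that mention the excluded path.
    def __init__(self, path: str):
        super().__init__()
        self._path = path

    def filter(self, record: logging.LogRecord) -> bool:
        return record.getMessage().find(self._path) == -1

access_logger = logging.getLogger("uvicorn.access")
access_logger.addFilter(EndpointFilter(path="/healthz"))

record = logging.LogRecord(
    "uvicorn.access", logging.INFO, __file__, 0, "GET /healthz HTTP/1.1 200", None, None
)
print(access_logger.filter(record))  # False -> this request would not be logged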