Merge remote-tracking branch 'origin/main' into feat_rag_graph

aries_ckt
2023-10-31 13:43:23 +08:00
18 changed files with 256 additions and 107 deletions

View File

@@ -194,6 +194,8 @@ class Config(metaclass=Singleton):
### LLM Model Service Configuration
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
self.LLM_MODEL_PATH = os.getenv("LLM_MODEL_PATH")
### Proxy LLM backend; this configuration is only valid when "LLM_MODEL=proxyllm"
### When we use the REST API provided by a deployment framework such as fastchat as a proxyllm, "PROXYLLM_BACKEND" is the model that framework actually deploys.
### We use "PROXYLLM_BACKEND" to load the prompt template of the corresponding scene.
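
A hedged sketch of how these settings interact (the environment values below are hypothetical):

import os

# Hypothetical environment:
#   LLM_MODEL=proxyllm
#   PROXYLLM_BACKEND=vicuna-13b-v1.5   # the model actually served behind the REST API
llm_model = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
proxyllm_backend = os.getenv("PROXYLLM_BACKEND")
# Load the prompt template for the real backend model, not for the "proxyllm" placeholder.
prompt_model = proxyllm_backend if llm_model == "proxyllm" and proxyllm_backend else llm_model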

View File

@@ -13,7 +13,7 @@ from pilot.utils.api_utils import (
_api_remote as api_remote,
_sync_api_remote as sync_api_remote,
)
from pilot.utils.utils import setup_logging
from pilot.utils.utils import setup_logging, setup_http_service_logging
logger = logging.getLogger(__name__)
@@ -149,6 +149,7 @@ def initialize_controller(
else:
import uvicorn
setup_http_service_logging()
app = FastAPI()
app.include_router(router, prefix="/api", tags=["Model"])
uvicorn.run(app, host=host, port=port, log_level="info")
@@ -179,7 +180,8 @@ def run_model_controller():
parser = EnvArgumentParser()
env_prefix = "controller_"
controller_params: ModelControllerParameters = parser.parse_args_into_dataclass(
ModelControllerParameters, env_prefix=env_prefix
ModelControllerParameters,
env_prefixes=[env_prefix],
)
setup_logging(

View File

@@ -76,7 +76,7 @@ class DefaultModelWorker(ModelWorker):
model_type = self.llm_adapter.model_type()
model_params: ModelParameters = model_args.parse_args_into_dataclass(
param_cls,
env_prefix=env_prefix,
env_prefixes=[env_prefix, "LLM_"],
command_args=command_args,
model_name=self.model_name,
model_path=self.model_path,
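
The extra "LLM_" entry means a shared LLM_-prefixed variable can now configure any worker, while the model-specific prefix still wins. A hedged example (path hypothetical; the prefix derivation from get_env_prefix is assumed):

import os

os.environ["LLM_MODEL_PATH"] = "/data/models/vicuna-13b-v1.5"  # hypothetical shared value
# Assumed resolution order for the dataclass field `model_path` when
# env_prefixes=["vicuna_13b_v1_5_", "LLM_"] (lookups ignore key case):
#   1. VICUNA_13B_V1_5_MODEL_PATH   - model-specific, wins if set
#   2. LLM_MODEL_PATH               - shared fallback, used here
#   3. MODEL_PATH                   - bare field name, final fallback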

View File

@@ -106,7 +106,7 @@ def _parse_embedding_params(
env_prefix = EnvArgumentParser.get_env_prefix(model_name)
model_params: BaseEmbeddingModelParameters = model_args.parse_args_into_dataclass(
param_cls,
env_prefix=env_prefix,
env_prefixes=[env_prefix],
command_args=command_args,
model_name=model_name,
model_path=model_path,

View File

@@ -38,7 +38,7 @@ from pilot.utils.parameter_utils import (
_dict_to_command_args,
_get_dict_from_obj,
)
from pilot.utils.utils import setup_logging
from pilot.utils.utils import setup_logging, setup_http_service_logging
from pilot.utils.tracer import initialize_tracer, root_tracer, SpanType, SpanTypeRunName
from pilot.utils.system_utils import get_system_info
@@ -735,6 +735,8 @@ def _setup_fastapi(
):
if not app:
app = FastAPI()
setup_http_service_logging()
if worker_params.standalone:
from pilot.model.cluster.controller.controller import initialize_controller
from pilot.model.cluster.controller.controller import (
@@ -781,7 +783,7 @@ def _parse_worker_params(
env_prefix = EnvArgumentParser.get_env_prefix(model_name)
worker_params: ModelWorkerParameters = worker_args.parse_args_into_dataclass(
ModelWorkerParameters,
env_prefix=env_prefix,
env_prefixes=[env_prefix],
model_name=model_name,
model_path=model_path,
**kwargs,
@@ -790,7 +792,7 @@ def _parse_worker_params(
# Read parameters again, using the prefix of the model name.
new_worker_params = worker_args.parse_args_into_dataclass(
ModelWorkerParameters,
env_prefix=env_prefix,
env_prefixes=[env_prefix],
model_name=worker_params.model_name,
model_path=worker_params.model_path,
**kwargs,
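
Taken together, these two hunks perform a two-pass parse: the first pass may discover the real model name, after which the environment is re-read with the prefix derived from that name. A hedged paraphrase (names reused from this file, not runnable on its own):

# Pass 1: prefix derived from the model name passed in (possibly generic).
env_prefix = EnvArgumentParser.get_env_prefix(model_name)
worker_params = worker_args.parse_args_into_dataclass(
    ModelWorkerParameters,
    env_prefixes=[env_prefix],
    model_name=model_name,
    model_path=model_path,
)
# Pass 2: pass 1 may have resolved the concrete model name, so re-read the
# environment with the prefix derived from it, letting model-specific
# variables override the generic ones.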

View File

@@ -95,7 +95,7 @@ class ModelLoader:
env_prefix = env_prefix.replace("-", "_")
model_params = args_parser.parse_args_into_dataclass(
param_cls,
env_prefix=env_prefix,
env_prefixes=[env_prefix],
device=self.device,
model_path=self.model_path,
model_name=self.model_name,

View File

@@ -445,17 +445,47 @@ class VLLMModelAdaperWrapper(LLMModelAdaper):
# Here we override fastchat's configuration; we will regularly feed these changes back to fastchat.
# We also recommend that you modify it directly in the fastchat repository.
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L212
register_conv_template(
Conversation(
name="internlm-chat",
system_message="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
roles=("<|User|>", "<|Bot|>"),
sep_style=SeparatorStyle.CHATINTERN,
sep="<eoh>",
sep2="<eoa>",
stop_token_ids=[1, 103028],
# TODO feedback stop_str to fastchat
stop_str="<eoa>",
),
override=True,
name="aquila-legacy",
system_message="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
roles=("### Human: ", "### Assistant: ", "System"),
messages=(),
offset=0,
sep_style=SeparatorStyle.NO_COLON_TWO,
sep="\n",
sep2="</s>",
stop_str=["</s>", "[UNK]"],
)
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L227
register_conv_template(
Conversation(
name="aquila",
system_message="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant", "System"),
messages=(),
offset=0,
sep_style=SeparatorStyle.ADD_COLON_TWO,
sep="###",
sep2="</s>",
stop_str=["</s>", "[UNK]"],
)
)
# source: https://huggingface.co/BAAI/AquilaChat2-34B/blob/4608b75855334b93329a771aee03869dbf7d88cc/predict.py#L242
register_conv_template(
Conversation(
name="aquila-v1",
roles=("<|startofpiece|>", "<|endofpiece|>", ""),
messages=(),
offset=0,
sep_style=SeparatorStyle.NO_COLON_TWO,
sep="",
sep2="</s>",
stop_str=["</s>", "<|endoftext|>"],
)
)
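
A hedged usage sketch of the templates registered above (assumes this module has been imported so the registrations have run; the rendered shape is inferred from fastchat's ADD_COLON_TWO style):

from fastchat.conversation import get_conv_template

conv = get_conv_template("aquila")
conv.append_message(conv.roles[0], "Hello")  # conv.roles[0] == "Human"
conv.append_message(conv.roles[1], None)     # leave the Assistant turn open
prompt = conv.get_prompt()
# With sep="###" and sep2="</s>", this renders roughly as:
#   "A chat between a curious human and ... questions.###Human: Hello###Assistant:"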

View File

@@ -5,8 +5,6 @@ import os
from typing import List
import logging
import openai
from pilot.model.proxy.llms.proxy_model import ProxyModel
from pilot.model.parameter import ProxyModelParameters
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
@@ -15,6 +13,14 @@ logger = logging.getLogger(__name__)
def _initialize_openai(params: ProxyModelParameters):
try:
import openai
except ImportError as exc:
raise ValueError(
"Could not import python package: openai "
"Please install openai by command `pip install openai` "
) from exc
api_type = params.proxy_api_type or os.getenv("OPENAI_API_TYPE", "open_ai")
api_base = params.proxy_api_base or os.getenv(
@@ -106,6 +112,8 @@ def _build_request(model: ProxyModel, params):
def chatgpt_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
import openai
history, payloads = _build_request(model, params)
res = openai.ChatCompletion.create(messages=history, **payloads)
@@ -121,6 +129,8 @@ def chatgpt_generate_stream(
async def async_chatgpt_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
import openai
history, payloads = _build_request(model, params)
res = await openai.ChatCompletion.acreate(messages=history, **payloads)
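
With the module-level import removed, openai is only required when a proxy model is actually used. For reference, a hedged sketch of consuming the stream returned above (pre-1.0 openai API, as used here; assumes the payload built by _build_request sets stream=True):

def _consume_stream(res):
    # Each chunk is dict-like; incremental text arrives in
    # choices[0].delta.content (absent on role or keep-alive chunks).
    for chunk in res:
        delta = chunk["choices"][0].get("delta", {})
        if delta.get("content"):
            yield delta["content"]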

View File

@@ -2,6 +2,7 @@ import os
import argparse
import sys
from typing import List
import logging
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@@ -39,6 +40,7 @@ from pilot.utils.utils import (
setup_logging,
_get_logging_level,
logging_str_to_uvicorn_level,
setup_http_service_logging,
)
from pilot.utils.tracer import root_tracer, initialize_tracer, SpanType, SpanTypeRunName
from pilot.utils.parameter_utils import _get_dict_from_obj
@@ -127,6 +129,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
setup_logging(
"pilot", logging_level=param.log_level, logger_filename=param.log_file
)
# Before start
system_app.before_start()
@@ -141,7 +144,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
model_name = param.model_name or CFG.LLM_MODEL
model_path = LLM_MODEL_CONFIG.get(model_name)
model_path = CFG.LLM_MODEL_PATH or LLM_MODEL_CONFIG.get(model_name)
if not param.light:
print("Model Unified Deployment Mode!")
if not param.remote_embedding:
@@ -180,6 +183,7 @@ def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
def run_uvicorn(param: WebWerverParameters):
import uvicorn
setup_http_service_logging()
uvicorn.run(
app,
host=param.host,
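
Together with the Config hunk at the top, this lets an explicit LLM_MODEL_PATH override the built-in model-to-path mapping. A hedged illustration (path hypothetical):

import os

os.environ["LLM_MODEL_PATH"] = "/data/models/my-finetuned-vicuna"  # hypothetical
# CFG.LLM_MODEL_PATH is read from this variable, so
#   model_path = CFG.LLM_MODEL_PATH or LLM_MODEL_CONFIG.get(model_name)
# now resolves to the explicit path instead of the LLM_MODEL_CONFIG entry.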

View File

@@ -190,6 +190,17 @@ def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_valu
)
def _genenv_ignoring_key_case_with_prefixes(
env_key: str, env_prefixes: List[str] = None, default_value=None
) -> str:
if env_prefixes:
for env_prefix in env_prefixes:
env_var_value = _genenv_ignoring_key_case(env_key, env_prefix)
if env_var_value:
return env_var_value
return _genenv_ignoring_key_case(env_key, default_value=default_value)
class EnvArgumentParser:
@staticmethod
def get_env_prefix(env_key: str) -> str:
@@ -201,18 +212,16 @@ class EnvArgumentParser:
def parse_args_into_dataclass(
self,
dataclass_type: Type,
env_prefix: str = None,
env_prefixes: List[str] = None,
command_args: List[str] = None,
**kwargs,
) -> Any:
"""Parse parameters from environment variables and command lines and populate them into data class"""
parser = argparse.ArgumentParser()
for field in fields(dataclass_type):
env_var_value = _genenv_ignoring_key_case(field.name, env_prefix)
if not env_var_value:
# Read without env prefix
env_var_value = _genenv_ignoring_key_case(field.name)
env_var_value = _genenv_ignoring_key_case_with_prefixes(
field.name, env_prefixes
)
if env_var_value:
env_var_value = env_var_value.strip()
if field.type is int or field.type == Optional[int]:
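
A hedged, self-contained sketch of the lookup order the new helper gives parse_args_into_dataclass (the exact case handling inside _genenv_ignoring_key_case is assumed):

import os
from typing import List, Optional

def resolve(field_name: str, env_prefixes: Optional[List[str]] = None) -> Optional[str]:
    """Try each prefix in order, then fall back to the bare field name."""
    candidates = [f"{p}{field_name}" for p in (env_prefixes or [])] + [field_name]
    for name in candidates:
        # The real helper ignores key case; checking both variants approximates that.
        value = os.environ.get(name.upper()) or os.environ.get(name.lower())
        if value:
            return value
    return None

os.environ["CONTROLLER_PORT"] = "8000"                         # hypothetical
assert resolve("port", ["controller_"]) == "8000"              # prefixed lookup wins
assert resolve("heartbeat_interval", ["controller_"]) is None  # nothing set anywhere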

View File

@@ -3,6 +3,8 @@
import logging
import logging.handlers
from typing import Any, List
import os
import sys
import asyncio
@@ -186,3 +188,42 @@ def logging_str_to_uvicorn_level(log_level_str):
"NOTSET": "info",
}
return level_str_mapping.get(log_level_str.upper(), "info")
class EndpointFilter(logging.Filter):
"""Disable access log on certain endpoint
source: https://github.com/encode/starlette/issues/864#issuecomment-1254987630
"""
def __init__(
self,
path: str,
*args: Any,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
self._path = path
def filter(self, record: logging.LogRecord) -> bool:
return record.getMessage().find(self._path) == -1
def setup_http_service_logging(exclude_paths: List[str] = None):
"""Setup http service logging
Now just disable some logs
Args:
exclude_paths (List[str]): The paths to disable log
"""
if not exclude_paths:
# Do not show heartbeat logs
exclude_paths = ["/api/controller/heartbeat"]
uvicorn_logger = logging.getLogger("uvicorn.access")
if uvicorn_logger:
for path in exclude_paths:
uvicorn_logger.addFilter(EndpointFilter(path=path))
httpx_logger = logging.getLogger("httpx")
if httpx_logger:
httpx_logger.setLevel(logging.WARNING)
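
A hedged usage sketch (the default excluded path comes from the code above; "/healthz" is hypothetical):

from pilot.utils.utils import setup_http_service_logging

# Call once before starting uvicorn: by default the controller heartbeat
# access-log lines are hidden and httpx is quieted to WARNING.
setup_http_service_logging()
# To suppress additional noisy endpoints instead:
# setup_http_service_logging(exclude_paths=["/api/controller/heartbeat", "/healthz"])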