diff --git a/.env.template b/.env.template
index 77c94ad66..47b7c3b99 100644
--- a/.env.template
+++ b/.env.template
@@ -61,6 +61,12 @@ KNOWLEDGE_SEARCH_TOP_SIZE=5
 # EMBEDDING_TOKENIZER=all-MiniLM-L6-v2
 # EMBEDDING_TOKEN_LIMIT=8191
 
+## OpenAI embedding model, see /pilot/model/parameter.py
+# EMBEDDING_MODEL=proxy_openai
+# proxy_openai_proxy_server_url=https://api.openai.com/v1
+# proxy_openai_proxy_api_key={your-openai-sk}
+# proxy_openai_proxy_backend=text-embedding-ada-002
+
 
 #*******************************************************************#
 #**                      DATABASE SETTINGS                        **#
diff --git a/docker/compose_examples/cluster-docker-compose.yml b/docker/compose_examples/cluster-docker-compose.yml
index 8d4763532..df2faed91 100644
--- a/docker/compose_examples/cluster-docker-compose.yml
+++ b/docker/compose_examples/cluster-docker-compose.yml
@@ -49,7 +49,7 @@ services:
           capabilities: [gpu]
   webserver:
     image: eosphorosai/dbgpt:latest
-    command: dbgpt start webserver --light
+    command: dbgpt start webserver --light --remote_embedding
     environment:
       - DBGPT_LOG_LEVEL=DEBUG
       - LOCAL_DB_PATH=data/default_sqlite.db
diff --git a/pilot/componet.py b/pilot/componet.py
index 0897b3365..8697d9560 100644
--- a/pilot/componet.py
+++ b/pilot/componet.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
+import sys
 from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
 from enum import Enum
 import logging
@@ -152,8 +153,15 @@ class SystemApp(LifeCycle):
         @self.app.on_event("startup")
         async def startup_event():
             """ASGI app startup event handler."""
-            # TODO catch exception and shutdown if worker manager start failed
-            asyncio.create_task(self.async_after_start())
+
+            async def _startup_func():
+                try:
+                    await self.async_after_start()
+                except Exception as e:
+                    logger.error(f"Error starting system app: {e}")
+                    sys.exit(1)
+
+            asyncio.create_task(_startup_func())
             self.after_start()
 
         @self.app.on_event("shutdown")
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index 0276c2a17..e068613c1 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -128,8 +128,8 @@ class Config(metaclass=Singleton):
 
         ### default Local database connection configuration
         self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
-        self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "")
-        self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "mysql")
+        self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "data/default_sqlite.db")
+        self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "sqlite")
         if self.LOCAL_DB_HOST is None and self.LOCAL_DB_PATH == "":
             self.LOCAL_DB_HOST = "127.0.0.1"
@@ -141,7 +141,7 @@ class Config(metaclass=Singleton):
         self.LOCAL_DB_MANAGE = None
 
         ### LLM Model Service Configuration
-        self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
+        self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
         ### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"
         ### When we use the rest API provided by deployment frameworks like fastchat as a proxyllm, "PROXYLLM_BACKEND" is the model they actually deploy.
         ### We need to use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
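Note: the new .env entries above rely on the worker's environment-prefix convention (see EnvArgumentParser.get_env_prefix later in this diff): a variable named "{model_name}_{field}" is parsed into the matching field of that model's parameter dataclass, so proxy_openai_proxy_api_key feeds ProxyEmbeddingParameters.proxy_api_key. A minimal sketch of that convention, not the project's actual parser:

import os
from dataclasses import dataclass, fields

@dataclass
class ProxyEmbeddingParams:  # illustrative stand-in for ProxyEmbeddingParameters
    proxy_server_url: str = ""
    proxy_api_key: str = ""
    proxy_backend: str = "text-embedding-ada-002"

def parse_from_env(model_name: str, cls=ProxyEmbeddingParams):
    # "proxy_openai" + "_" + "proxy_api_key" -> env var "proxy_openai_proxy_api_key"
    prefix = f"{model_name}_"
    values = {
        f.name: os.environ[prefix + f.name]
        for f in fields(cls)
        if prefix + f.name in os.environ
    }
    return cls(**values)

params = parse_from_env("proxy_openai")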
diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py
index e80f9c8e8..e513cf898 100644
--- a/pilot/configs/model_config.py
+++ b/pilot/configs/model_config.py
@@ -88,6 +88,8 @@ EMBEDDING_MODEL_CONFIG = {
     "bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
     "bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
+    "proxy_openai": "proxy_openai",
+    "proxy_azure": "proxy_azure",
 }
 
 # Load model config
diff --git a/pilot/embedding_engine/embedding_factory.py b/pilot/embedding_engine/embedding_factory.py
index b9d17ad83..c6f51f4c9 100644
--- a/pilot/embedding_engine/embedding_factory.py
+++ b/pilot/embedding_engine/embedding_factory.py
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import Any, Type, TYPE_CHECKING
diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py
index 8fe4a9057..c29427b1d 100644
--- a/pilot/model/adapter.py
+++ b/pilot/model/adapter.py
@@ -96,7 +96,11 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
 
 def _dynamic_model_parser() -> Callable[[None], List[Type]]:
     from pilot.utils.parameter_utils import _SimpleArgParser
-    from pilot.model.parameter import EmbeddingModelParameters, WorkerType
+    from pilot.model.parameter import (
+        EmbeddingModelParameters,
+        WorkerType,
+        EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
+    )
 
     pre_args = _SimpleArgParser("model_name", "model_path", "worker_type")
     pre_args.parse()
@@ -106,7 +110,11 @@ def _dynamic_model_parser() -> Callable[[None], List[Type]]:
     if model_name is None:
         return None
     if worker_type == WorkerType.TEXT2VEC:
-        return [EmbeddingModelParameters]
+        return [
+            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
+                model_name, EmbeddingModelParameters
+            )
+        ]
 
     llm_adapter = get_llm_model_adapter(model_name, model_path)
     param_class = llm_adapter.model_param_class()
diff --git a/pilot/model/cluster/embedding/loader.py b/pilot/model/cluster/embedding/loader.py
new file mode 100644
index 000000000..63f6c452d
--- /dev/null
+++ b/pilot/model/cluster/embedding/loader.py
@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from pilot.model.parameter import BaseEmbeddingModelParameters
+
+if TYPE_CHECKING:
+    from langchain.embeddings.base import Embeddings
+
+
+class EmbeddingLoader:
+    def __init__(self) -> None:
+        pass
+
+    def load(
+        self, model_name: str, param: BaseEmbeddingModelParameters
+    ) -> "Embeddings":
+        # TODO: add more models
+        if model_name in ["proxy_openai", "proxy_azure"]:
+            from langchain.embeddings import OpenAIEmbeddings
+
+            return OpenAIEmbeddings(**param.build_kwargs())
+        else:
+            from langchain.embeddings import HuggingFaceEmbeddings
+
+            kwargs = param.build_kwargs(model_name=param.model_path)
+            return HuggingFaceEmbeddings(**kwargs)
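For context, a sketch of driving the new EmbeddingLoader directly for the proxy path. It assumes BaseModelParameters contributes the model_name/model_path fields; all values are placeholders:

from pilot.model.cluster.embedding.loader import EmbeddingLoader
from pilot.model.parameter import ProxyEmbeddingParameters

params = ProxyEmbeddingParameters(
    model_name="proxy_openai",
    model_path="proxy_openai",  # mirrors EMBEDDING_MODEL_CONFIG above
    proxy_server_url="https://api.openai.com/v1",
    proxy_api_key="{your-openai-sk}",  # placeholder, not a real key
)
embeddings = EmbeddingLoader().load("proxy_openai", params)
vector = embeddings.embed_query("hello world")  # langchain Embeddings API -> list of floats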
diff --git a/pilot/model/cluster/worker/embedding_worker.py b/pilot/model/cluster/worker/embedding_worker.py
index a0bfd66bc..a9f934a1c 100644
--- a/pilot/model/cluster/worker/embedding_worker.py
+++ b/pilot/model/cluster/worker/embedding_worker.py
@@ -5,13 +5,16 @@
 from pilot.configs.model_config import get_device
 from pilot.model.loader import _get_model_real_path
 from pilot.model.parameter import (
     EmbeddingModelParameters,
+    BaseEmbeddingModelParameters,
     WorkerType,
+    EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
 )
 from pilot.model.cluster.worker_base import ModelWorker
+from pilot.model.cluster.embedding.loader import EmbeddingLoader
 from pilot.utils.model_utils import _clear_torch_cache
 from pilot.utils.parameter_utils import EnvArgumentParser
 
-logger = logging.getLogger("model_worker")
+logger = logging.getLogger(__name__)
 
 
 class EmbeddingsModelWorker(ModelWorker):
@@ -26,6 +29,9 @@ class EmbeddingsModelWorker(ModelWorker):
             ) from exc
         self._embeddings_impl: Embeddings = None
         self._model_params = None
+        self.model_name = None
+        self.model_path = None
+        self._loader = EmbeddingLoader()
 
     def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
         if model_path.endswith("/"):
@@ -39,11 +45,13 @@ class EmbeddingsModelWorker(ModelWorker):
         return WorkerType.TEXT2VEC
 
     def model_param_class(self) -> Type:
-        return EmbeddingModelParameters
+        return EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
+            self.model_name, EmbeddingModelParameters
+        )
 
     def parse_parameters(
         self, command_args: List[str] = None
-    ) -> EmbeddingModelParameters:
+    ) -> BaseEmbeddingModelParameters:
         param_cls = self.model_param_class()
         return _parse_embedding_params(
             model_name=self.model_name,
@@ -58,15 +66,10 @@ class EmbeddingsModelWorker(ModelWorker):
         command_args: List[str] = None,
     ) -> None:
         """Start model worker"""
-        from langchain.embeddings import HuggingFaceEmbeddings
-
         if not model_params:
             model_params = self.parse_parameters(command_args)
         self._model_params = model_params
-
-        kwargs = model_params.build_kwargs(model_name=model_params.model_path)
-        logger.info(f"Start HuggingFaceEmbeddings with kwargs: {kwargs}")
-        self._embeddings_impl = HuggingFaceEmbeddings(**kwargs)
+        self._embeddings_impl = self._loader.load(self.model_name, model_params)
 
     def __del__(self):
         self.stop()
@@ -101,7 +104,7 @@ def _parse_embedding_params(
 ):
     model_args = EnvArgumentParser()
     env_prefix = EnvArgumentParser.get_env_prefix(model_name)
-    model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass(
+    model_params: BaseEmbeddingModelParameters = model_args.parse_args_into_dataclass(
         param_cls,
         env_prefix=env_prefix,
         command_args=command_args,
diff --git a/pilot/model/cluster/worker/manager.py b/pilot/model/cluster/worker/manager.py
index 4791d8caa..ff8fd7767 100644
--- a/pilot/model/cluster/worker/manager.py
+++ b/pilot/model/cluster/worker/manager.py
@@ -2,6 +2,7 @@ import asyncio
 import itertools
 import json
 import os
+import sys
 import random
 import time
 from concurrent.futures import ThreadPoolExecutor
@@ -129,8 +130,6 @@ class LocalWorkerManager(WorkerManager):
         command_args: List[str] = None,
     ) -> bool:
         if not command_args:
-            import sys
-
             command_args = sys.argv[1:]
 
         worker.load_worker(**asdict(worker_params))
@@ -635,8 +634,15 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
 
     @app.on_event("startup")
     async def startup_event():
-        # TODO catch exception and shutdown if worker manager start failed
-        asyncio.create_task(worker_manager.start())
+        async def start_worker_manager():
+            try:
+                await worker_manager.start()
+            except Exception as e:
+                logger.error(f"Error starting worker manager: {e}")
+                sys.exit(1)
+
+        # Don't block here: worker_manager startup depends on the FastAPI app
+        # being up (it registers with the controller).
+        asyncio.create_task(start_worker_manager())
 
     @app.on_event("shutdown")
     async def startup_event():
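Both startup handlers in this diff (here and in pilot/componet.py) wrap the long-running start in a background task that exits the process on failure. The pattern in isolation, as a generic FastAPI sketch rather than project code:

import asyncio
import logging
import sys

from fastapi import FastAPI

app = FastAPI()
logger = logging.getLogger(__name__)

async def start_service():
    """Long-running startup that may raise (e.g. registering with a controller)."""

@app.on_event("startup")
async def startup_event():
    async def _guarded_start():
        try:
            await start_service()
        except Exception as e:
            logger.error(f"Error starting service: {e}")
            sys.exit(1)  # fail fast rather than serve a half-initialized app

    # Scheduled as a task, not awaited: startup may itself depend on the
    # app accepting requests, so blocking here could deadlock.
    asyncio.create_task(_guarded_start())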
diff --git a/pilot/model/cluster/worker/remote_worker.py b/pilot/model/cluster/worker/remote_worker.py
index a0c306bbf..f974ba714 100644
--- a/pilot/model/cluster/worker/remote_worker.py
+++ b/pilot/model/cluster/worker/remote_worker.py
@@ -6,6 +6,9 @@
 from pilot.model.parameter import ModelParameters
 from pilot.model.cluster.worker_base import ModelWorker
 
+logger = logging.getLogger(__name__)
+
+
 class RemoteModelWorker(ModelWorker):
     def __init__(self) -> None:
         self.headers = {}
@@ -46,13 +49,14 @@ class RemoteModelWorker(ModelWorker):
         """Asynchronous generate stream"""
         import httpx
 
-        logging.debug(f"Send async_generate_stream, params: {params}")
         async with httpx.AsyncClient() as client:
             delimiter = b"\0"
             buffer = b""
+            url = self.worker_addr + "/generate_stream"
+            logger.debug(f"Send async_generate_stream to url {url}, params: {params}")
             async with client.stream(
                 "POST",
-                self.worker_addr + "/generate_stream",
+                url,
                 headers=self.headers,
                 json=params,
                 timeout=self.timeout,
@@ -75,10 +79,11 @@ class RemoteModelWorker(ModelWorker):
         """Asynchronous generate non stream"""
         import httpx
 
-        logging.debug(f"Send async_generate_stream, params: {params}")
         async with httpx.AsyncClient() as client:
+            url = self.worker_addr + "/generate"
+            logger.debug(f"Send async_generate to url {url}, params: {params}")
             response = await client.post(
-                self.worker_addr + "/generate",
+                url,
                 headers=self.headers,
                 json=params,
                 timeout=self.timeout,
@@ -89,8 +94,10 @@ class RemoteModelWorker(ModelWorker):
         """Get embeddings for input"""
         import requests
 
+        url = self.worker_addr + "/embeddings"
+        logger.debug(f"Send embeddings to url {url}, params: {params}")
         response = requests.post(
-            self.worker_addr + "/embeddings",
+            url,
             headers=self.headers,
             json=params,
             timeout=self.timeout,
@@ -102,8 +109,10 @@ class RemoteModelWorker(ModelWorker):
         import httpx
 
         async with httpx.AsyncClient() as client:
+            url = self.worker_addr + "/embeddings"
+            logger.debug(f"Send async_embeddings to url {url}")
             response = await client.post(
-                self.worker_addr + "/embeddings",
+                url,
                 headers=self.headers,
                 json=params,
                 timeout=self.timeout,
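The generate_stream endpoint consumed above frames messages with a b"\0" delimiter on a chunked response body. A standalone sketch of the consumer side, assuming each frame is UTF-8 text (the framing loop itself sits outside this hunk's context):

import httpx

async def consume_stream(url: str, params: dict, headers: dict = None, timeout: int = 60):
    delimiter = b"\0"
    buffer = b""
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST", url, headers=headers, json=params, timeout=timeout
        ) as response:
            async for chunk in response.aiter_bytes():
                buffer += chunk
                while delimiter in buffer:
                    frame, buffer = buffer.split(delimiter, 1)
                    if frame:
                        yield frame.decode()  # assumed UTF-8 payload per frame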
diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py
index 400ba0396..c7917a95f 100644
--- a/pilot/model/parameter.py
+++ b/pilot/model/parameter.py
@@ -86,7 +86,13 @@ class ModelWorkerParameters(BaseModelParameters):
 
 
 @dataclass
-class EmbeddingModelParameters(BaseModelParameters):
+class BaseEmbeddingModelParameters(BaseModelParameters):
+    def build_kwargs(self, **kwargs) -> Dict:
+        pass
+
+
+@dataclass
+class EmbeddingModelParameters(BaseEmbeddingModelParameters):
     device: Optional[str] = field(
         default=None,
         metadata={
@@ -268,3 +274,81 @@ class ProxyModelParameters(BaseModelParameters):
     max_context_size: Optional[int] = field(
         default=4096, metadata={"help": "Maximum context size"}
     )
+
+
+@dataclass
+class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
+    proxy_server_url: str = field(
+        metadata={
+            "help": "Proxy base url (OPENAI_API_BASE), such as https://api.openai.com/v1"
+        },
+    )
+    proxy_api_key: str = field(
+        metadata={
+            "tags": "privacy",
+            "help": "The api key of the current embedding model (OPENAI_API_KEY)",
+        },
+    )
+    device: Optional[str] = field(
+        default=None,
+        metadata={"help": "Device to run model. Not working for proxy embedding model"},
+    )
+    proxy_api_type: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The api type of the current proxy embedding model (OPENAI_API_TYPE); if you use Azure, set it to: azure"
+        },
+    )
+    proxy_api_version: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The api version of the current proxy embedding model (OPENAI_API_VERSION)"
+        },
+    )
+    proxy_backend: Optional[str] = field(
+        default="text-embedding-ada-002",
+        metadata={
+            "help": "The model name actually passed to the current proxy server url, such as text-embedding-ada-002"
+        },
+    )
+
+    proxy_deployment: Optional[str] = field(
+        default="text-embedding-ada-002",
+        metadata={"help": "To support Azure OpenAI Service custom deployment names"},
+    )
+
+    def build_kwargs(self, **kwargs) -> Dict:
+        params = {
+            "openai_api_base": self.proxy_server_url,
+            "openai_api_key": self.proxy_api_key,
+            "openai_api_type": self.proxy_api_type if self.proxy_api_type else None,
+            "openai_api_version": self.proxy_api_version
+            if self.proxy_api_version
+            else None,
+            "model": self.proxy_backend,
+            "deployment": self.proxy_deployment
+            if self.proxy_deployment
+            else self.proxy_backend,
+        }
+        for k, v in kwargs.items():
+            params[k] = v
+        return params
+
+
+_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
+    ProxyEmbeddingParameters: "proxy_openai,proxy_azure"
+}
+
+EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
+
+
+def _update_embedding_config():
+    global EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG
+    for param_cls, models in _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG.items():
+        models = [m.strip() for m in models.split(",")]
+        for model in models:
+            if model not in EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG:
+                EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG[model] = param_cls
+
+
+_update_embedding_config()
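How the new registry resolves parameter classes; build_kwargs() then translates the proxy fields into OpenAIEmbeddings constructor arguments (openai_api_base, openai_api_key, model, deployment). Here "text2vec" stands in for any local HuggingFace model name:

from pilot.model.parameter import (
    EmbeddingModelParameters,
    ProxyEmbeddingParameters,
    EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)

for name in ("proxy_openai", "proxy_azure", "text2vec"):
    cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(name, EmbeddingModelParameters)
    print(name, "->", cls.__name__)
# proxy_openai -> ProxyEmbeddingParameters
# proxy_azure  -> ProxyEmbeddingParameters
# text2vec     -> EmbeddingModelParameters  (local HuggingFace path)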
diff --git a/pilot/server/componet_configs.py b/pilot/server/componet_configs.py
index 755f13b21..d46b626ca 100644
--- a/pilot/server/componet_configs.py
+++ b/pilot/server/componet_configs.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Any, Type, TYPE_CHECKING
 
 from pilot.componet import SystemApp
@@ -8,7 +10,6 @@ from pilot.embedding_engine.embedding_factory import (
     DefaultEmbeddingFactory,
 )
 from pilot.server.base import WebWerverParameters
-from pilot.utils.parameter_utils import EnvArgumentParser
 
 if TYPE_CHECKING:
     from langchain.embeddings.base import Embeddings
@@ -46,18 +47,11 @@ def _initialize_embedding_model(
             RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
         )
     else:
-        from pilot.model.parameter import EmbeddingModelParameters
-        from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
-
-        model_params: EmbeddingModelParameters = _parse_embedding_params(
-            model_name=embedding_model_name,
-            model_path=embedding_model_path,
-            param_cls=EmbeddingModelParameters,
-        )
-        kwargs = model_params.build_kwargs(model_name=embedding_model_path)
-        logger.info(f"Register local DefaultEmbeddingFactory with kwargs: {kwargs}")
+        logger.info("Register local LocalEmbeddingFactory")
         system_app.register(
-            DefaultEmbeddingFactory, default_model_name=embedding_model_path, **kwargs
+            LocalEmbeddingFactory,
+            default_model_name=embedding_model_name,
+            default_model_path=embedding_model_path,
         )
@@ -82,3 +76,51 @@ class RemoteEmbeddingFactory(EmbeddingFactory):
             raise NotImplementedError
         # Ignore model_name args
         return RemoteEmbeddings(self._default_model_name, self._worker_manager)
+
+
+class LocalEmbeddingFactory(EmbeddingFactory):
+    def __init__(
+        self,
+        system_app,
+        default_model_name: str = None,
+        default_model_path: str = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(system_app=system_app)
+        self._default_model_name = default_model_name
+        self._default_model_path = default_model_path
+        self._kwargs = kwargs
+        self._model = self._load_model()
+
+    def init_app(self, system_app):
+        pass
+
+    def create(
+        self, model_name: str = None, embedding_cls: Type = None
+    ) -> "Embeddings":
+        if embedding_cls:
+            raise NotImplementedError
+        return self._model
+
+    def _load_model(self) -> "Embeddings":
+        from pilot.model.parameter import (
+            EmbeddingModelParameters,
+            BaseEmbeddingModelParameters,
+            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
+        )
+        from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
+        from pilot.model.cluster.embedding.loader import EmbeddingLoader
+
+        param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
+            self._default_model_name, EmbeddingModelParameters
+        )
+        model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
+            model_name=self._default_model_name,
+            model_path=self._default_model_path,
+            param_cls=param_cls,
+            **self._kwargs,
+        )
+        logger.info(model_params)
+        loader = EmbeddingLoader()
+        # Ignore model_name args
+        return loader.load(self._default_model_name, model_params)
diff --git a/requirements.txt b/requirements.txt
index 130d7ccd2..9a768ec6d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -37,6 +37,7 @@ opencv-python==4.7.0.72
 iopath==0.1.10
 tenacity==8.2.2
 peft
+# TODO remove pycocoevalcap
 pycocoevalcap
 cpm_kernels
 umap-learn
diff --git a/setup.py b/setup.py
index 4c5bd0902..60e2643a6 100644
--- a/setup.py
+++ b/setup.py
@@ -319,6 +319,14 @@ def all_datasource_requires():
     setup_spec.extras["datasource"] = ["pymssql", "pymysql"]
 
 
+def openai_requires():
+    """
+    pip install "db-gpt[openai]"
+    """
+    setup_spec.extras["openai"] = ["openai", "tiktoken"]
+    llama_cpp_python_cuda_requires()
+
+
 def all_requires():
     requires = set()
     for _, pkgs in setup_spec.extras.items():
@@ -339,6 +347,7 @@ llama_cpp_requires()
 quantization_requires()
 all_vector_store_requires()
 all_datasource_requires()
+openai_requires()
 # must be last
 all_requires()
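The extras wiring follows setup.py's existing registration pattern; a self-contained sketch of the composition (names are illustrative, not setup.py itself):

setup_extras = {}

def openai_requires():
    # pip install "db-gpt[openai]"
    setup_extras["openai"] = ["openai", "tiktoken"]

def all_requires():
    requires = set()
    for pkgs in setup_extras.values():
        requires.update(pkgs)
    setup_extras["all"] = list(requires)

openai_requires()
all_requires()  # must run last so the "all" extra sees every other extra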