feat(model): support OpenAI embedding model

FangYin Cheng 2023-09-15 16:20:24 +08:00
parent e331936bff
commit 5aa9cb455e
15 changed files with 247 additions and 41 deletions

View File

@@ -61,6 +61,12 @@ KNOWLEDGE_SEARCH_TOP_SIZE=5
 # EMBEDDING_TOKENIZER=all-MiniLM-L6-v2
 # EMBEDDING_TOKEN_LIMIT=8191
+## OpenAI embedding model, see /pilot/model/parameter.py
+# EMBEDDING_MODEL=proxy_openai
+# proxy_openai_proxy_server_url=https://api.openai.com/v1
+# proxy_openai_proxy_api_key={your-openai-sk}
+# proxy_openai_proxy_backend=text-embedding-ada-002
 #*******************************************************************#
 #** DATABASE SETTINGS **#
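
For anyone wiring this up programmatically, here is a minimal sketch of the same configuration set from Python. The `proxy_openai_` prefix convention is assumed from the template above and from `EnvArgumentParser.get_env_prefix` later in this commit; the key value is a placeholder.

    import os

    # Select the OpenAI proxy embedding model and configure it via the
    # per-model env-var prefix shown in the template above.
    os.environ["EMBEDDING_MODEL"] = "proxy_openai"
    os.environ["proxy_openai_proxy_server_url"] = "https://api.openai.com/v1"
    os.environ["proxy_openai_proxy_api_key"] = "sk-..."  # placeholder, not a real key
    os.environ["proxy_openai_proxy_backend"] = "text-embedding-ada-002"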

View File

@@ -49,7 +49,7 @@ services:
           capabilities: [gpu]
   webserver:
     image: eosphorosai/dbgpt:latest
-    command: dbgpt start webserver --light
+    command: dbgpt start webserver --light --remote_embedding
     environment:
       - DBGPT_LOG_LEVEL=DEBUG
       - LOCAL_DB_PATH=data/default_sqlite.db
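
The new `--remote_embedding` flag makes the webserver fetch embeddings from the model worker cluster instead of loading the model in-process. A rough sketch of the branch it toggles, using the factory names from `_initialize_embedding_model` later in this commit (the `remote_embedding` attribute name on the parsed parameters is an assumption):

    # Sketch only: mirrors the register calls shown later in this commit.
    def _register_embedding_factory(system_app, param, worker_manager, model_name, model_path):
        if param.remote_embedding:  # assumed flag wired from --remote_embedding
            # The webserver proxies embedding requests to the worker cluster.
            system_app.register(
                RemoteEmbeddingFactory, worker_manager, model_name=model_name
            )
        else:
            # The webserver loads and serves the embedding model itself.
            system_app.register(
                LocalEmbeddingFactory,
                default_model_name=model_name,
                default_model_path=model_path,
            )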

View File

@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
+import sys
 from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
 from enum import Enum
 import logging
@@ -152,8 +153,15 @@ class SystemApp(LifeCycle):
@self.app.on_event("startup") @self.app.on_event("startup")
async def startup_event(): async def startup_event():
"""ASGI app startup event handler.""" """ASGI app startup event handler."""
# TODO catch exception and shutdown if worker manager start failed
asyncio.create_task(self.async_after_start()) async def _startup_func():
try:
await self.async_after_start()
except Exception as e:
logger.error(f"Error starting system app: {e}")
sys.exit(1)
asyncio.create_task(_startup_func())
self.after_start() self.after_start()
@self.app.on_event("shutdown") @self.app.on_event("shutdown")

View File

@@ -128,8 +128,8 @@ class Config(metaclass=Singleton):
         ### default Local database connection configuration
         self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
-        self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "")
-        self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "mysql")
+        self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "data/default_sqlite.db")
+        self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "sqlite")
         if self.LOCAL_DB_HOST is None and self.LOCAL_DB_PATH == "":
             self.LOCAL_DB_HOST = "127.0.0.1"
@@ -141,7 +141,7 @@ class Config(metaclass=Singleton):
         self.LOCAL_DB_MANAGE = None
 
         ### LLM Model Service Configuration
-        self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
+        self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
         ### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"
         ### When we use the rest API provided by deployment frameworks like fastchat as a proxyllm, "PROXYLLM_BACKEND" is the model they actually deploy.
         ### We need to use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.

View File

@@ -88,6 +88,8 @@ EMBEDDING_MODEL_CONFIG = {
"bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"), "bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
"bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"), "bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"), "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"proxy_openai": "proxy_openai",
"proxy_azure": "proxy_azure",
} }
# Load model config # Load model config
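
Note that the two proxy entries map the model name to itself rather than to a local path; downstream code treats the "path" as a sentinel. A tiny sketch of the resulting lookup:

    # Illustrative excerpt: local models resolve to a filesystem path,
    # proxy models resolve to a sentinel equal to the model name.
    EMBEDDING_MODEL_CONFIG = {
        "sentence-transforms": "/app/models/all-MiniLM-L6-v2",  # illustrative path
        "proxy_openai": "proxy_openai",
        "proxy_azure": "proxy_azure",
    }

    model_name = "proxy_openai"
    model_path = EMBEDDING_MODEL_CONFIG[model_name]
    assert model_path == model_name  # nothing to load from disk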

View File

@@ -1,3 +1,4 @@
+from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import Any, Type, TYPE_CHECKING

View File

@@ -96,7 +96,11 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
 def _dynamic_model_parser() -> Callable[[None], List[Type]]:
     from pilot.utils.parameter_utils import _SimpleArgParser
-    from pilot.model.parameter import EmbeddingModelParameters, WorkerType
+    from pilot.model.parameter import (
+        EmbeddingModelParameters,
+        WorkerType,
+        EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
+    )
 
     pre_args = _SimpleArgParser("model_name", "model_path", "worker_type")
     pre_args.parse()
@@ -106,7 +110,11 @@ def _dynamic_model_parser() -> Callable[[None], List[Type]]:
     if model_name is None:
         return None
     if worker_type == WorkerType.TEXT2VEC:
-        return [EmbeddingModelParameters]
+        return [
+            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
+                model_name, EmbeddingModelParameters
+            )
+        ]
     llm_adapter = get_llm_model_adapter(model_name, model_path)
     param_class = llm_adapter.model_param_class()
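
In short, text2vec workers now resolve their parameter dataclass by model name. A condensed, self-contained sketch of that lookup (stub classes stand in for the real dataclasses):

    # Stubs for illustration; the real classes live in pilot/model/parameter.py.
    class EmbeddingModelParameters: ...
    class ProxyEmbeddingParameters(EmbeddingModelParameters): ...

    EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {
        "proxy_openai": ProxyEmbeddingParameters,
        "proxy_azure": ProxyEmbeddingParameters,
    }

    def param_class_for(model_name: str) -> type:
        # Unknown names fall back to the local-model parameter class.
        return EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
            model_name, EmbeddingModelParameters
        )

    assert param_class_for("proxy_openai") is ProxyEmbeddingParameters
    assert param_class_for("text2vec-large-chinese") is EmbeddingModelParameters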

View File

@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from pilot.model.parameter import BaseEmbeddingModelParameters
+
+if TYPE_CHECKING:
+    from langchain.embeddings.base import Embeddings
+
+
+class EmbeddingLoader:
+    def __init__(self) -> None:
+        pass
+
+    def load(
+        self, model_name: str, param: BaseEmbeddingModelParameters
+    ) -> "Embeddings":
+        # add more models
+        if model_name in ["proxy_openai", "proxy_azure"]:
+            from langchain.embeddings import OpenAIEmbeddings
+
+            return OpenAIEmbeddings(**param.build_kwargs())
+        else:
+            from langchain.embeddings import HuggingFaceEmbeddings
+
+            kwargs = param.build_kwargs(model_name=param.model_path)
+            return HuggingFaceEmbeddings(**kwargs)
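
Usage sketch for the new loader (placeholder credentials; `model_name`/`model_path` are inherited base-dataclass fields assumed to be required):

    from pilot.model.cluster.embedding.loader import EmbeddingLoader
    from pilot.model.parameter import ProxyEmbeddingParameters

    params = ProxyEmbeddingParameters(
        model_name="proxy_openai",  # inherited base fields, assumed required
        model_path="proxy_openai",
        proxy_server_url="https://api.openai.com/v1",
        proxy_api_key="sk-...",  # placeholder, not a real key
    )
    loader = EmbeddingLoader()
    # Resolves to langchain's OpenAIEmbeddings for proxy_openai/proxy_azure.
    embeddings = loader.load("proxy_openai", params)
    vectors = embeddings.embed_documents(["hello world"])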

View File

@@ -5,13 +5,16 @@ from pilot.configs.model_config import get_device
 from pilot.model.loader import _get_model_real_path
 from pilot.model.parameter import (
     EmbeddingModelParameters,
+    BaseEmbeddingModelParameters,
     WorkerType,
+    EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
 )
 from pilot.model.cluster.worker_base import ModelWorker
+from pilot.model.cluster.embedding.loader import EmbeddingLoader
 from pilot.utils.model_utils import _clear_torch_cache
 from pilot.utils.parameter_utils import EnvArgumentParser
 
-logger = logging.getLogger("model_worker")
+logger = logging.getLogger(__name__)
 
 
 class EmbeddingsModelWorker(ModelWorker):
@@ -26,6 +29,9 @@ class EmbeddingsModelWorker(ModelWorker):
             ) from exc
         self._embeddings_impl: Embeddings = None
         self._model_params = None
+        self.model_name = None
+        self.model_path = None
+        self._loader = EmbeddingLoader()
 
     def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
         if model_path.endswith("/"):
@@ -39,11 +45,13 @@ class EmbeddingsModelWorker(ModelWorker):
         return WorkerType.TEXT2VEC
 
     def model_param_class(self) -> Type:
-        return EmbeddingModelParameters
+        return EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
+            self.model_name, EmbeddingModelParameters
+        )
 
     def parse_parameters(
         self, command_args: List[str] = None
-    ) -> EmbeddingModelParameters:
+    ) -> BaseEmbeddingModelParameters:
         param_cls = self.model_param_class()
         return _parse_embedding_params(
             model_name=self.model_name,
@@ -58,15 +66,10 @@ class EmbeddingsModelWorker(ModelWorker):
         command_args: List[str] = None,
     ) -> None:
         """Start model worker"""
-        from langchain.embeddings import HuggingFaceEmbeddings
-
         if not model_params:
             model_params = self.parse_parameters(command_args)
         self._model_params = model_params
-
-        kwargs = model_params.build_kwargs(model_name=model_params.model_path)
-        logger.info(f"Start HuggingFaceEmbeddings with kwargs: {kwargs}")
-        self._embeddings_impl = HuggingFaceEmbeddings(**kwargs)
+        self._embeddings_impl = self._loader.load(self.model_name, model_params)
 
     def __del__(self):
         self.stop()
@@ -101,7 +104,7 @@ def _parse_embedding_params(
 ):
     model_args = EnvArgumentParser()
     env_prefix = EnvArgumentParser.get_env_prefix(model_name)
-    model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass(
+    model_params: BaseEmbeddingModelParameters = model_args.parse_args_into_dataclass(
         param_cls,
         env_prefix=env_prefix,
         command_args=command_args,
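
For reference, a sketch of how the per-model env prefix feeds this parser (prefix behavior inferred from the `.env` template above; any further `parse_args_into_dataclass` arguments are elided):

    import os

    from pilot.model.parameter import ProxyEmbeddingParameters
    from pilot.utils.parameter_utils import EnvArgumentParser

    os.environ["proxy_openai_proxy_api_key"] = "sk-..."  # placeholder

    model_args = EnvArgumentParser()
    env_prefix = EnvArgumentParser.get_env_prefix("proxy_openai")  # presumably "proxy_openai_"
    params = model_args.parse_args_into_dataclass(
        ProxyEmbeddingParameters,
        env_prefix=env_prefix,
        command_args=[],
    )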

View File

@@ -2,6 +2,7 @@ import asyncio
 import itertools
 import json
 import os
+import sys
 import random
 import time
 from concurrent.futures import ThreadPoolExecutor
@@ -129,8 +130,6 @@ class LocalWorkerManager(WorkerManager):
         command_args: List[str] = None,
     ) -> bool:
         if not command_args:
-            import sys
-
             command_args = sys.argv[1:]
         worker.load_worker(**asdict(worker_params))
@@ -635,8 +634,15 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
     @app.on_event("startup")
     async def startup_event():
-        # TODO catch exception and shutdown if worker manager start failed
-        asyncio.create_task(worker_manager.start())
+        async def start_worker_manager():
+            try:
+                await worker_manager.start()
+            except Exception as e:
+                logger.error(f"Error starting worker manager: {e}")
+                sys.exit(1)
+
+        # Cannot block here: worker_manager startup depends on the FastAPI app
+        # being up (it registers itself with the controller).
+        asyncio.create_task(start_worker_manager())
 
     @app.on_event("shutdown")
     async def startup_event():

View File

@@ -6,6 +6,9 @@ from pilot.model.parameter import ModelParameters
 from pilot.model.cluster.worker_base import ModelWorker
 
+logger = logging.getLogger(__name__)
+
 
 class RemoteModelWorker(ModelWorker):
     def __init__(self) -> None:
         self.headers = {}
@@ -46,13 +49,14 @@ class RemoteModelWorker(ModelWorker):
"""Asynchronous generate stream""" """Asynchronous generate stream"""
import httpx import httpx
logging.debug(f"Send async_generate_stream, params: {params}")
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
delimiter = b"\0" delimiter = b"\0"
buffer = b"" buffer = b""
url = self.worker_addr + "/generate_stream"
logger.debug(f"Send async_generate_stream to url {url}, params: {params}")
async with client.stream( async with client.stream(
"POST", "POST",
self.worker_addr + "/generate_stream", url,
headers=self.headers, headers=self.headers,
json=params, json=params,
timeout=self.timeout, timeout=self.timeout,
@@ -75,10 +79,11 @@ class RemoteModelWorker(ModelWorker):
"""Asynchronous generate non stream""" """Asynchronous generate non stream"""
import httpx import httpx
logging.debug(f"Send async_generate_stream, params: {params}")
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
url = self.worker_addr + "/generate"
logger.debug(f"Send async_generate to url {url}, params: {params}")
response = await client.post( response = await client.post(
self.worker_addr + "/generate", url,
headers=self.headers, headers=self.headers,
json=params, json=params,
timeout=self.timeout, timeout=self.timeout,
@@ -89,8 +94,10 @@ class RemoteModelWorker(ModelWorker):
"""Get embeddings for input""" """Get embeddings for input"""
import requests import requests
url = self.worker_addr + "/embeddings"
logger.debug(f"Send embeddings to url {url}, params: {params}")
response = requests.post( response = requests.post(
self.worker_addr + "/embeddings", url,
headers=self.headers, headers=self.headers,
json=params, json=params,
timeout=self.timeout, timeout=self.timeout,
@@ -102,8 +109,10 @@ class RemoteModelWorker(ModelWorker):
         import httpx
 
         async with httpx.AsyncClient() as client:
+            url = self.worker_addr + "/embeddings"
+            logger.debug(f"Send async_embeddings to url {url}")
             response = await client.post(
-                self.worker_addr + "/embeddings",
+                url,
                 headers=self.headers,
                 json=params,
                 timeout=self.timeout,
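
For context, `delimiter`/`buffer` in `async_generate_stream` implement NUL-framed messages over the raw byte stream. A self-contained sketch of that framing logic (the frame payloads are illustrative, not the actual worker protocol):

    def split_frames(chunks):
        """Yield complete frames split on the NUL delimiter, buffering partial ones."""
        delimiter = b"\0"
        buffer = b""
        for raw_chunk in chunks:
            buffer += raw_chunk
            while delimiter in buffer:
                frame, buffer = buffer.split(delimiter, 1)
                if frame:
                    yield frame

    frames = list(split_frames([b'{"text": "he', b'llo"}\0{"text', b'": "world"}\0']))
    assert frames == [b'{"text": "hello"}', b'{"text": "world"}']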

View File

@@ -86,7 +86,13 @@ class ModelWorkerParameters(BaseModelParameters):
 @dataclass
-class EmbeddingModelParameters(BaseModelParameters):
+class BaseEmbeddingModelParameters(BaseModelParameters):
+    def build_kwargs(self, **kwargs) -> Dict:
+        pass
+
+
+@dataclass
+class EmbeddingModelParameters(BaseEmbeddingModelParameters):
     device: Optional[str] = field(
         default=None,
         metadata={
@@ -268,3 +274,81 @@ class ProxyModelParameters(BaseModelParameters):
     max_context_size: Optional[int] = field(
         default=4096, metadata={"help": "Maximum context size"}
     )
+
+
+@dataclass
+class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
+    proxy_server_url: str = field(
+        metadata={
+            "help": "Proxy base url(OPENAI_API_BASE), such as https://api.openai.com/v1"
+        },
+    )
+    proxy_api_key: str = field(
+        metadata={
+            "tags": "privacy",
+            "help": "The api key of the current embedding model(OPENAI_API_KEY)",
+        },
+    )
+    device: Optional[str] = field(
+        default=None,
+        metadata={"help": "Device to run model. Not working for proxy embedding model"},
+    )
+    proxy_api_type: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The api type of the current embedding model's proxy (OPENAI_API_TYPE); if you use Azure, it can be: azure"
+        },
+    )
+    proxy_api_version: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The api version of the current embedding model's proxy (OPENAI_API_VERSION)"
+        },
+    )
+    proxy_backend: Optional[str] = field(
+        default="text-embedding-ada-002",
+        metadata={
+            "help": "The model name actually passed to the current proxy server url, such as text-embedding-ada-002"
+        },
+    )
+    proxy_deployment: Optional[str] = field(
+        default="text-embedding-ada-002",
+        metadata={"help": "To support Azure OpenAI Service custom deployment names"},
+    )
+
+    def build_kwargs(self, **kwargs) -> Dict:
+        params = {
+            "openai_api_base": self.proxy_server_url,
+            "openai_api_key": self.proxy_api_key,
+            "openai_api_type": self.proxy_api_type if self.proxy_api_type else None,
+            "openai_api_version": self.proxy_api_version
+            if self.proxy_api_version
+            else None,
+            "model": self.proxy_backend,
+            "deployment": self.proxy_deployment
+            if self.proxy_deployment
+            else self.proxy_backend,
+        }
+        for k, v in kwargs.items():
+            params[k] = v
+        return params
+
+
+_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
+    ProxyEmbeddingParameters: "proxy_openai,proxy_azure"
+}
+
+EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
+
+
+def _update_embedding_config():
+    global EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG
+    for param_cls, models in _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG.items():
+        models = [m.strip() for m in models.split(",")]
+        for model in models:
+            if model not in EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG:
+                EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG[model] = param_cls
+
+
+_update_embedding_config()
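
To see what `build_kwargs` hands to langchain's `OpenAIEmbeddings`, a sketch for the Azure flavor (all values are placeholders; `model_name`/`model_path` are inherited base fields assumed to be required):

    from pilot.model.parameter import ProxyEmbeddingParameters

    params = ProxyEmbeddingParameters(
        model_name="proxy_azure",  # inherited base fields, assumed required
        model_path="proxy_azure",
        proxy_server_url="https://your-resource.openai.azure.com",  # placeholder
        proxy_api_key="...",  # placeholder
        proxy_api_type="azure",
        proxy_api_version="2023-05-15",
        proxy_backend="text-embedding-ada-002",
        proxy_deployment="my-ada-deployment",  # hypothetical deployment name
    )
    kwargs = params.build_kwargs()
    # -> {"openai_api_base": ..., "openai_api_key": ..., "openai_api_type": "azure",
    #     "openai_api_version": "2023-05-15", "model": "text-embedding-ada-002",
    #     "deployment": "my-ada-deployment"}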

View File

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Any, Type, TYPE_CHECKING
 
 from pilot.componet import SystemApp
@@ -8,7 +10,6 @@ from pilot.embedding_engine.embedding_factory import (
     DefaultEmbeddingFactory,
 )
 from pilot.server.base import WebWerverParameters
-from pilot.utils.parameter_utils import EnvArgumentParser
 
 if TYPE_CHECKING:
     from langchain.embeddings.base import Embeddings
@@ -46,18 +47,11 @@ def _initialize_embedding_model(
             RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
         )
     else:
-        from pilot.model.parameter import EmbeddingModelParameters
-        from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
-
-        model_params: EmbeddingModelParameters = _parse_embedding_params(
-            model_name=embedding_model_name,
-            model_path=embedding_model_path,
-            param_cls=EmbeddingModelParameters,
-        )
-        kwargs = model_params.build_kwargs(model_name=embedding_model_path)
-        logger.info(f"Register local DefaultEmbeddingFactory with kwargs: {kwargs}")
+        logger.info(f"Register local LocalEmbeddingFactory")
         system_app.register(
-            DefaultEmbeddingFactory, default_model_name=embedding_model_path, **kwargs
+            LocalEmbeddingFactory,
+            default_model_name=embedding_model_name,
+            default_model_path=embedding_model_path,
         )
@@ -82,3 +76,51 @@ class RemoteEmbeddingFactory(EmbeddingFactory):
             raise NotImplementedError
         # Ignore model_name args
         return RemoteEmbeddings(self._default_model_name, self._worker_manager)
+
+
+class LocalEmbeddingFactory(EmbeddingFactory):
+    def __init__(
+        self,
+        system_app,
+        default_model_name: str = None,
+        default_model_path: str = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(system_app=system_app)
+        self._default_model_name = default_model_name
+        self._default_model_path = default_model_path
+        self._kwargs = kwargs
+        self._model = self._load_model()
+
+    def init_app(self, system_app):
+        pass
+
+    def create(
+        self, model_name: str = None, embedding_cls: Type = None
+    ) -> "Embeddings":
+        if embedding_cls:
+            raise NotImplementedError
+        return self._model
+
+    def _load_model(self) -> "Embeddings":
+        from pilot.model.parameter import (
+            EmbeddingModelParameters,
+            BaseEmbeddingModelParameters,
+            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
+        )
+        from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
+        from pilot.model.cluster.embedding.loader import EmbeddingLoader
+
+        param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
+            self._default_model_name, EmbeddingModelParameters
+        )
+        model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
+            model_name=self._default_model_name,
+            model_path=self._default_model_path,
+            param_cls=param_cls,
+            **self._kwargs,
+        )
+        logger.info(model_params)
+        loader = EmbeddingLoader()
+        # Ignore model_name args
+        return loader.load(self._default_model_name, model_params)
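
Registration sketch for the new factory (a `SystemApp` instance and the component-lookup call are assumptions; `componet` spelling follows the `pilot.componet` module):

    # Sketch: register once at startup, then reuse the shared Embeddings
    # instance everywhere. The model is loaded eagerly in __init__.
    system_app.register(
        LocalEmbeddingFactory,
        default_model_name="proxy_openai",
        default_model_path="proxy_openai",  # sentinel from EMBEDDING_MODEL_CONFIG
    )
    factory = system_app.get_componet(
        "embedding_factory", EmbeddingFactory  # lookup API and name are assumptions
    )
    embeddings = factory.create()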

View File

@@ -37,6 +37,7 @@ opencv-python==4.7.0.72
 iopath==0.1.10
 tenacity==8.2.2
 peft
+# TODO remove pycocoevalcap
 pycocoevalcap
 cpm_kernels
 umap-learn

View File

@@ -319,6 +319,14 @@ def all_datasource_requires():
     setup_spec.extras["datasource"] = ["pymssql", "pymysql"]
 
 
+def openai_requires():
+    """
+    pip install "db-gpt[openai]"
+    """
+    setup_spec.extras["openai"] = ["openai", "tiktoken"]
+    llama_cpp_python_cuda_requires()
+
+
 def all_requires():
     requires = set()
     for _, pkgs in setup_spec.extras.items():
@@ -339,6 +347,7 @@ llama_cpp_requires()
 quantization_requires()
 all_vector_store_requires()
 all_datasource_requires()
+openai_requires()
 
 # must be last
 all_requires()