feat(model): support openai embedding model

This commit is contained in:
FangYin Cheng 2023-09-15 16:20:24 +08:00
parent e331936bff
commit 5aa9cb455e
15 changed files with 247 additions and 41 deletions

View File

@ -61,6 +61,12 @@ KNOWLEDGE_SEARCH_TOP_SIZE=5
# EMBEDDING_TOKENIZER=all-MiniLM-L6-v2
# EMBEDDING_TOKEN_LIMIT=8191
## Openai embedding model, See /pilot/model/parameter.py
# EMBEDDING_MODEL=proxy_openai
# proxy_openai_proxy_server_url=https://api.openai.com/v1
# proxy_openai_proxy_api_key={your-openai-sk}
# proxy_openai_proxy_backend=text-embedding-ada-002
#*******************************************************************#
#** DATABASE SETTINGS **#

View File

@ -49,7 +49,7 @@ services:
capabilities: [gpu]
webserver:
image: eosphorosai/dbgpt:latest
command: dbgpt start webserver --light
command: dbgpt start webserver --light --remote_embedding
environment:
- DBGPT_LOG_LEVEL=DEBUG
- LOCAL_DB_PATH=data/default_sqlite.db

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import sys
from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
from enum import Enum
import logging
@ -152,8 +153,15 @@ class SystemApp(LifeCycle):
@self.app.on_event("startup")
async def startup_event():
"""ASGI app startup event handler."""
# TODO catch exception and shutdown if worker manager start failed
asyncio.create_task(self.async_after_start())
async def _startup_func():
try:
await self.async_after_start()
except Exception as e:
logger.error(f"Error starting system app: {e}")
sys.exit(1)
asyncio.create_task(_startup_func())
self.after_start()
@self.app.on_event("shutdown")

View File

@ -128,8 +128,8 @@ class Config(metaclass=Singleton):
### default Local database connection configuration
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "")
self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "mysql")
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "data/default_sqlite.db")
self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "sqlite")
if self.LOCAL_DB_HOST is None and self.LOCAL_DB_PATH == "":
self.LOCAL_DB_HOST = "127.0.0.1"
@ -141,7 +141,7 @@ class Config(metaclass=Singleton):
self.LOCAL_DB_MANAGE = None
### LLM Model Service Configuration
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"
### When we use the rest API provided by deployment frameworks like fastchat as a proxyllm, "PROXYLLM_BACKEND" is the model they actually deploy.
### We need to use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.

View File

@ -88,6 +88,8 @@ EMBEDDING_MODEL_CONFIG = {
"bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
"bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"proxy_openai": "proxy_openai",
"proxy_azure": "proxy_azure",
}
# Load model config

View File

@ -1,3 +1,4 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Type, TYPE_CHECKING

View File

@ -96,7 +96,11 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
def _dynamic_model_parser() -> Callable[[None], List[Type]]:
from pilot.utils.parameter_utils import _SimpleArgParser
from pilot.model.parameter import EmbeddingModelParameters, WorkerType
from pilot.model.parameter import (
EmbeddingModelParameters,
WorkerType,
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
pre_args = _SimpleArgParser("model_name", "model_path", "worker_type")
pre_args.parse()
@ -106,7 +110,11 @@ def _dynamic_model_parser() -> Callable[[None], List[Type]]:
if model_name is None:
return None
if worker_type == WorkerType.TEXT2VEC:
return [EmbeddingModelParameters]
return [
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
model_name, EmbeddingModelParameters
)
]
llm_adapter = get_llm_model_adapter(model_name, model_path)
param_class = llm_adapter.model_param_class()

View File

@ -0,0 +1,27 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from pilot.model.parameter import BaseEmbeddingModelParameters
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
class EmbeddingLoader:
def __init__(self) -> None:
pass
def load(
self, model_name: str, param: BaseEmbeddingModelParameters
) -> "Embeddings":
# add more models
if model_name in ["proxy_openai", "proxy_azure"]:
from langchain.embeddings import OpenAIEmbeddings
return OpenAIEmbeddings(**param.build_kwargs())
else:
from langchain.embeddings import HuggingFaceEmbeddings
kwargs = param.build_kwargs(model_name=param.model_path)
return HuggingFaceEmbeddings(**kwargs)

View File

@ -5,13 +5,16 @@ from pilot.configs.model_config import get_device
from pilot.model.loader import _get_model_real_path
from pilot.model.parameter import (
EmbeddingModelParameters,
BaseEmbeddingModelParameters,
WorkerType,
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.embedding.loader import EmbeddingLoader
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.parameter_utils import EnvArgumentParser
logger = logging.getLogger("model_worker")
logger = logging.getLogger(__name__)
class EmbeddingsModelWorker(ModelWorker):
@ -26,6 +29,9 @@ class EmbeddingsModelWorker(ModelWorker):
) from exc
self._embeddings_impl: Embeddings = None
self._model_params = None
self.model_name = None
self.model_path = None
self._loader = EmbeddingLoader()
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
if model_path.endswith("/"):
@ -39,11 +45,13 @@ class EmbeddingsModelWorker(ModelWorker):
return WorkerType.TEXT2VEC
def model_param_class(self) -> Type:
return EmbeddingModelParameters
return EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
self.model_name, EmbeddingModelParameters
)
def parse_parameters(
self, command_args: List[str] = None
) -> EmbeddingModelParameters:
) -> BaseEmbeddingModelParameters:
param_cls = self.model_param_class()
return _parse_embedding_params(
model_name=self.model_name,
@ -58,15 +66,10 @@ class EmbeddingsModelWorker(ModelWorker):
command_args: List[str] = None,
) -> None:
"""Start model worker"""
from langchain.embeddings import HuggingFaceEmbeddings
if not model_params:
model_params = self.parse_parameters(command_args)
self._model_params = model_params
kwargs = model_params.build_kwargs(model_name=model_params.model_path)
logger.info(f"Start HuggingFaceEmbeddings with kwargs: {kwargs}")
self._embeddings_impl = HuggingFaceEmbeddings(**kwargs)
self._embeddings_impl = self._loader.load(self.model_name, model_params)
def __del__(self):
self.stop()
@ -101,7 +104,7 @@ def _parse_embedding_params(
):
model_args = EnvArgumentParser()
env_prefix = EnvArgumentParser.get_env_prefix(model_name)
model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass(
model_params: BaseEmbeddingModelParameters = model_args.parse_args_into_dataclass(
param_cls,
env_prefix=env_prefix,
command_args=command_args,

View File

@ -2,6 +2,7 @@ import asyncio
import itertools
import json
import os
import sys
import random
import time
from concurrent.futures import ThreadPoolExecutor
@ -129,8 +130,6 @@ class LocalWorkerManager(WorkerManager):
command_args: List[str] = None,
) -> bool:
if not command_args:
import sys
command_args = sys.argv[1:]
worker.load_worker(**asdict(worker_params))
@ -635,8 +634,15 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
@app.on_event("startup")
async def startup_event():
# TODO catch exception and shutdown if worker manager start failed
asyncio.create_task(worker_manager.start())
async def start_worker_manager():
try:
await worker_manager.start()
except Exception as e:
logger.error(f"Error starting worker manager: {e}")
sys.exit(1)
# It cannot be blocked here because the startup of worker_manager depends on the fastapi app (registered to the controller)
asyncio.create_task(start_worker_manager())
@app.on_event("shutdown")
async def startup_event():

View File

@ -6,6 +6,9 @@ from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker
logger = logging.getLogger(__name__)
class RemoteModelWorker(ModelWorker):
def __init__(self) -> None:
self.headers = {}
@ -46,13 +49,14 @@ class RemoteModelWorker(ModelWorker):
"""Asynchronous generate stream"""
import httpx
logging.debug(f"Send async_generate_stream, params: {params}")
async with httpx.AsyncClient() as client:
delimiter = b"\0"
buffer = b""
url = self.worker_addr + "/generate_stream"
logger.debug(f"Send async_generate_stream to url {url}, params: {params}")
async with client.stream(
"POST",
self.worker_addr + "/generate_stream",
url,
headers=self.headers,
json=params,
timeout=self.timeout,
@ -75,10 +79,11 @@ class RemoteModelWorker(ModelWorker):
"""Asynchronous generate non stream"""
import httpx
logging.debug(f"Send async_generate_stream, params: {params}")
async with httpx.AsyncClient() as client:
url = self.worker_addr + "/generate"
logger.debug(f"Send async_generate to url {url}, params: {params}")
response = await client.post(
self.worker_addr + "/generate",
url,
headers=self.headers,
json=params,
timeout=self.timeout,
@ -89,8 +94,10 @@ class RemoteModelWorker(ModelWorker):
"""Get embeddings for input"""
import requests
url = self.worker_addr + "/embeddings"
logger.debug(f"Send embeddings to url {url}, params: {params}")
response = requests.post(
self.worker_addr + "/embeddings",
url,
headers=self.headers,
json=params,
timeout=self.timeout,
@ -102,8 +109,10 @@ class RemoteModelWorker(ModelWorker):
import httpx
async with httpx.AsyncClient() as client:
url = self.worker_addr + "/embeddings"
logger.debug(f"Send async_embeddings to url {url}")
response = await client.post(
self.worker_addr + "/embeddings",
url,
headers=self.headers,
json=params,
timeout=self.timeout,

View File

@ -86,7 +86,13 @@ class ModelWorkerParameters(BaseModelParameters):
@dataclass
class EmbeddingModelParameters(BaseModelParameters):
class BaseEmbeddingModelParameters(BaseModelParameters):
def build_kwargs(self, **kwargs) -> Dict:
pass
@dataclass
class EmbeddingModelParameters(BaseEmbeddingModelParameters):
device: Optional[str] = field(
default=None,
metadata={
@ -268,3 +274,81 @@ class ProxyModelParameters(BaseModelParameters):
max_context_size: Optional[int] = field(
default=4096, metadata={"help": "Maximum context size"}
)
@dataclass
class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
proxy_server_url: str = field(
metadata={
"help": "Proxy base url(OPENAI_API_BASE), such as https://api.openai.com/v1"
},
)
proxy_api_key: str = field(
metadata={
"tags": "privacy",
"help": "The api key of the current embedding model(OPENAI_API_KEY)",
},
)
device: Optional[str] = field(
default=None,
metadata={"help": "Device to run model. Not working for proxy embedding model"},
)
proxy_api_type: Optional[str] = field(
default=None,
metadata={
"help": "The api type of current proxy the current embedding model(OPENAI_API_TYPE), if you use Azure, it can be: azure"
},
)
proxy_api_version: Optional[str] = field(
default=None,
metadata={
"help": "The api version of current proxy the current embedding model(OPENAI_API_VERSION)"
},
)
proxy_backend: Optional[str] = field(
default="text-embedding-ada-002",
metadata={
"help": "The model name actually pass to current proxy server url, such as text-embedding-ada-002"
},
)
proxy_deployment: Optional[str] = field(
default="text-embedding-ada-002",
metadata={"help": "Tto support Azure OpenAI Service custom deployment names"},
)
def build_kwargs(self, **kwargs) -> Dict:
params = {
"openai_api_base": self.proxy_server_url,
"openai_api_key": self.proxy_api_key,
"openai_api_type": self.proxy_api_type if self.proxy_api_type else None,
"openai_api_version": self.proxy_api_version
if self.proxy_api_version
else None,
"model": self.proxy_backend,
"deployment": self.proxy_deployment
if self.proxy_deployment
else self.proxy_backend,
}
for k, v in kwargs:
params[k] = v
return params
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
ProxyEmbeddingParameters: "proxy_openai,proxy_azure"
}
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
def _update_embedding_config():
global EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG
for param_cls, models in _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG.items():
models = [m.strip() for m in models.split(",")]
for model in models:
if model not in EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG:
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG[model] = param_cls
_update_embedding_config()

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from typing import Any, Type, TYPE_CHECKING
from pilot.componet import SystemApp
@ -8,7 +10,6 @@ from pilot.embedding_engine.embedding_factory import (
DefaultEmbeddingFactory,
)
from pilot.server.base import WebWerverParameters
from pilot.utils.parameter_utils import EnvArgumentParser
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
@ -46,18 +47,11 @@ def _initialize_embedding_model(
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
)
else:
from pilot.model.parameter import EmbeddingModelParameters
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
model_params: EmbeddingModelParameters = _parse_embedding_params(
model_name=embedding_model_name,
model_path=embedding_model_path,
param_cls=EmbeddingModelParameters,
)
kwargs = model_params.build_kwargs(model_name=embedding_model_path)
logger.info(f"Register local DefaultEmbeddingFactory with kwargs: {kwargs}")
logger.info(f"Register local LocalEmbeddingFactory")
system_app.register(
DefaultEmbeddingFactory, default_model_name=embedding_model_path, **kwargs
LocalEmbeddingFactory,
default_model_name=embedding_model_name,
default_model_path=embedding_model_path,
)
@ -82,3 +76,51 @@ class RemoteEmbeddingFactory(EmbeddingFactory):
raise NotImplementedError
# Ignore model_name args
return RemoteEmbeddings(self._default_model_name, self._worker_manager)
class LocalEmbeddingFactory(EmbeddingFactory):
def __init__(
self,
system_app,
default_model_name: str = None,
default_model_path: str = None,
**kwargs: Any,
) -> None:
super().__init__(system_app=system_app)
self._default_model_name = default_model_name
self._default_model_path = default_model_path
self._kwargs = kwargs
self._model = self._load_model()
def init_app(self, system_app):
pass
def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings":
if embedding_cls:
raise NotImplementedError
return self._model
def _load_model(self) -> "Embeddings":
from pilot.model.parameter import (
EmbeddingModelParameters,
BaseEmbeddingModelParameters,
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
from pilot.model.cluster.embedding.loader import EmbeddingLoader
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
self._default_model_name, EmbeddingModelParameters
)
model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
model_name=self._default_model_name,
model_path=self._default_model_path,
param_cls=param_cls,
**self._kwargs,
)
logger.info(model_params)
loader = EmbeddingLoader()
# Ignore model_name args
return loader.load(self._default_model_name, model_params)

View File

@ -37,6 +37,7 @@ opencv-python==4.7.0.72
iopath==0.1.10
tenacity==8.2.2
peft
# TODO remove pycocoevalcap
pycocoevalcap
cpm_kernels
umap-learn

View File

@ -319,6 +319,14 @@ def all_datasource_requires():
setup_spec.extras["datasource"] = ["pymssql", "pymysql"]
def openai_requires():
"""
pip install "db-gpt[openai]"
"""
setup_spec.extras["openai"] = ["openai", "tiktoken"]
llama_cpp_python_cuda_requires()
def all_requires():
requires = set()
for _, pkgs in setup_spec.extras.items():
@ -339,6 +347,7 @@ llama_cpp_requires()
quantization_requires()
all_vector_store_requires()
all_datasource_requires()
openai_requires()
# must be last
all_requires()