mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-30 15:21:02 +00:00
feat(model): support openai embedding model
This commit is contained in:
parent
e331936bff
commit
5aa9cb455e
@ -61,6 +61,12 @@ KNOWLEDGE_SEARCH_TOP_SIZE=5
|
||||
# EMBEDDING_TOKENIZER=all-MiniLM-L6-v2
|
||||
# EMBEDDING_TOKEN_LIMIT=8191
|
||||
|
||||
## Openai embedding model, See /pilot/model/parameter.py
|
||||
# EMBEDDING_MODEL=proxy_openai
|
||||
# proxy_openai_proxy_server_url=https://api.openai.com/v1
|
||||
# proxy_openai_proxy_api_key={your-openai-sk}
|
||||
# proxy_openai_proxy_backend=text-embedding-ada-002
|
||||
|
||||
|
||||
#*******************************************************************#
|
||||
#** DATABASE SETTINGS **#
|
||||
|
@ -49,7 +49,7 @@ services:
|
||||
capabilities: [gpu]
|
||||
webserver:
|
||||
image: eosphorosai/dbgpt:latest
|
||||
command: dbgpt start webserver --light
|
||||
command: dbgpt start webserver --light --remote_embedding
|
||||
environment:
|
||||
- DBGPT_LOG_LEVEL=DEBUG
|
||||
- LOCAL_DB_PATH=data/default_sqlite.db
|
||||
|
@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import sys
|
||||
from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
|
||||
from enum import Enum
|
||||
import logging
|
||||
@ -152,8 +153,15 @@ class SystemApp(LifeCycle):
|
||||
@self.app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""ASGI app startup event handler."""
|
||||
# TODO catch exception and shutdown if worker manager start failed
|
||||
asyncio.create_task(self.async_after_start())
|
||||
|
||||
async def _startup_func():
|
||||
try:
|
||||
await self.async_after_start()
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting system app: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
asyncio.create_task(_startup_func())
|
||||
self.after_start()
|
||||
|
||||
@self.app.on_event("shutdown")
|
||||
|
@ -128,8 +128,8 @@ class Config(metaclass=Singleton):
|
||||
|
||||
### default Local database connection configuration
|
||||
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
|
||||
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "")
|
||||
self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "mysql")
|
||||
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "data/default_sqlite.db")
|
||||
self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "sqlite")
|
||||
if self.LOCAL_DB_HOST is None and self.LOCAL_DB_PATH == "":
|
||||
self.LOCAL_DB_HOST = "127.0.0.1"
|
||||
|
||||
@ -141,7 +141,7 @@ class Config(metaclass=Singleton):
|
||||
self.LOCAL_DB_MANAGE = None
|
||||
|
||||
### LLM Model Service Configuration
|
||||
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
|
||||
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
|
||||
### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"
|
||||
### When we use the rest API provided by deployment frameworks like fastchat as a proxyllm, "PROXYLLM_BACKEND" is the model they actually deploy.
|
||||
### We need to use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
|
||||
|
@ -88,6 +88,8 @@ EMBEDDING_MODEL_CONFIG = {
|
||||
"bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
|
||||
"bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
|
||||
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
|
||||
"proxy_openai": "proxy_openai",
|
||||
"proxy_azure": "proxy_azure",
|
||||
}
|
||||
|
||||
# Load model config
|
||||
|
@ -1,3 +1,4 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Type, TYPE_CHECKING
|
||||
|
||||
|
@ -96,7 +96,11 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
|
||||
|
||||
def _dynamic_model_parser() -> Callable[[None], List[Type]]:
|
||||
from pilot.utils.parameter_utils import _SimpleArgParser
|
||||
from pilot.model.parameter import EmbeddingModelParameters, WorkerType
|
||||
from pilot.model.parameter import (
|
||||
EmbeddingModelParameters,
|
||||
WorkerType,
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
|
||||
)
|
||||
|
||||
pre_args = _SimpleArgParser("model_name", "model_path", "worker_type")
|
||||
pre_args.parse()
|
||||
@ -106,7 +110,11 @@ def _dynamic_model_parser() -> Callable[[None], List[Type]]:
|
||||
if model_name is None:
|
||||
return None
|
||||
if worker_type == WorkerType.TEXT2VEC:
|
||||
return [EmbeddingModelParameters]
|
||||
return [
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
|
||||
model_name, EmbeddingModelParameters
|
||||
)
|
||||
]
|
||||
|
||||
llm_adapter = get_llm_model_adapter(model_name, model_path)
|
||||
param_class = llm_adapter.model_param_class()
|
||||
|
27
pilot/model/cluster/embedding/loader.py
Normal file
27
pilot/model/cluster/embedding/loader.py
Normal file
@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pilot.model.parameter import BaseEmbeddingModelParameters
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
class EmbeddingLoader:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def load(
|
||||
self, model_name: str, param: BaseEmbeddingModelParameters
|
||||
) -> "Embeddings":
|
||||
# add more models
|
||||
if model_name in ["proxy_openai", "proxy_azure"]:
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
return OpenAIEmbeddings(**param.build_kwargs())
|
||||
else:
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
kwargs = param.build_kwargs(model_name=param.model_path)
|
||||
return HuggingFaceEmbeddings(**kwargs)
|
@ -5,13 +5,16 @@ from pilot.configs.model_config import get_device
|
||||
from pilot.model.loader import _get_model_real_path
|
||||
from pilot.model.parameter import (
|
||||
EmbeddingModelParameters,
|
||||
BaseEmbeddingModelParameters,
|
||||
WorkerType,
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
|
||||
)
|
||||
from pilot.model.cluster.worker_base import ModelWorker
|
||||
from pilot.model.cluster.embedding.loader import EmbeddingLoader
|
||||
from pilot.utils.model_utils import _clear_torch_cache
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
|
||||
logger = logging.getLogger("model_worker")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingsModelWorker(ModelWorker):
|
||||
@ -26,6 +29,9 @@ class EmbeddingsModelWorker(ModelWorker):
|
||||
) from exc
|
||||
self._embeddings_impl: Embeddings = None
|
||||
self._model_params = None
|
||||
self.model_name = None
|
||||
self.model_path = None
|
||||
self._loader = EmbeddingLoader()
|
||||
|
||||
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
|
||||
if model_path.endswith("/"):
|
||||
@ -39,11 +45,13 @@ class EmbeddingsModelWorker(ModelWorker):
|
||||
return WorkerType.TEXT2VEC
|
||||
|
||||
def model_param_class(self) -> Type:
|
||||
return EmbeddingModelParameters
|
||||
return EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
|
||||
self.model_name, EmbeddingModelParameters
|
||||
)
|
||||
|
||||
def parse_parameters(
|
||||
self, command_args: List[str] = None
|
||||
) -> EmbeddingModelParameters:
|
||||
) -> BaseEmbeddingModelParameters:
|
||||
param_cls = self.model_param_class()
|
||||
return _parse_embedding_params(
|
||||
model_name=self.model_name,
|
||||
@ -58,15 +66,10 @@ class EmbeddingsModelWorker(ModelWorker):
|
||||
command_args: List[str] = None,
|
||||
) -> None:
|
||||
"""Start model worker"""
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
|
||||
if not model_params:
|
||||
model_params = self.parse_parameters(command_args)
|
||||
self._model_params = model_params
|
||||
|
||||
kwargs = model_params.build_kwargs(model_name=model_params.model_path)
|
||||
logger.info(f"Start HuggingFaceEmbeddings with kwargs: {kwargs}")
|
||||
self._embeddings_impl = HuggingFaceEmbeddings(**kwargs)
|
||||
self._embeddings_impl = self._loader.load(self.model_name, model_params)
|
||||
|
||||
def __del__(self):
|
||||
self.stop()
|
||||
@ -101,7 +104,7 @@ def _parse_embedding_params(
|
||||
):
|
||||
model_args = EnvArgumentParser()
|
||||
env_prefix = EnvArgumentParser.get_env_prefix(model_name)
|
||||
model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass(
|
||||
model_params: BaseEmbeddingModelParameters = model_args.parse_args_into_dataclass(
|
||||
param_cls,
|
||||
env_prefix=env_prefix,
|
||||
command_args=command_args,
|
||||
|
@ -2,6 +2,7 @@ import asyncio
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
@ -129,8 +130,6 @@ class LocalWorkerManager(WorkerManager):
|
||||
command_args: List[str] = None,
|
||||
) -> bool:
|
||||
if not command_args:
|
||||
import sys
|
||||
|
||||
command_args = sys.argv[1:]
|
||||
worker.load_worker(**asdict(worker_params))
|
||||
|
||||
@ -635,8 +634,15 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
# TODO catch exception and shutdown if worker manager start failed
|
||||
asyncio.create_task(worker_manager.start())
|
||||
async def start_worker_manager():
|
||||
try:
|
||||
await worker_manager.start()
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting worker manager: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# It cannot be blocked here because the startup of worker_manager depends on the fastapi app (registered to the controller)
|
||||
asyncio.create_task(start_worker_manager())
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def startup_event():
|
||||
|
@ -6,6 +6,9 @@ from pilot.model.parameter import ModelParameters
|
||||
from pilot.model.cluster.worker_base import ModelWorker
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RemoteModelWorker(ModelWorker):
|
||||
def __init__(self) -> None:
|
||||
self.headers = {}
|
||||
@ -46,13 +49,14 @@ class RemoteModelWorker(ModelWorker):
|
||||
"""Asynchronous generate stream"""
|
||||
import httpx
|
||||
|
||||
logging.debug(f"Send async_generate_stream, params: {params}")
|
||||
async with httpx.AsyncClient() as client:
|
||||
delimiter = b"\0"
|
||||
buffer = b""
|
||||
url = self.worker_addr + "/generate_stream"
|
||||
logger.debug(f"Send async_generate_stream to url {url}, params: {params}")
|
||||
async with client.stream(
|
||||
"POST",
|
||||
self.worker_addr + "/generate_stream",
|
||||
url,
|
||||
headers=self.headers,
|
||||
json=params,
|
||||
timeout=self.timeout,
|
||||
@ -75,10 +79,11 @@ class RemoteModelWorker(ModelWorker):
|
||||
"""Asynchronous generate non stream"""
|
||||
import httpx
|
||||
|
||||
logging.debug(f"Send async_generate_stream, params: {params}")
|
||||
async with httpx.AsyncClient() as client:
|
||||
url = self.worker_addr + "/generate"
|
||||
logger.debug(f"Send async_generate to url {url}, params: {params}")
|
||||
response = await client.post(
|
||||
self.worker_addr + "/generate",
|
||||
url,
|
||||
headers=self.headers,
|
||||
json=params,
|
||||
timeout=self.timeout,
|
||||
@ -89,8 +94,10 @@ class RemoteModelWorker(ModelWorker):
|
||||
"""Get embeddings for input"""
|
||||
import requests
|
||||
|
||||
url = self.worker_addr + "/embeddings"
|
||||
logger.debug(f"Send embeddings to url {url}, params: {params}")
|
||||
response = requests.post(
|
||||
self.worker_addr + "/embeddings",
|
||||
url,
|
||||
headers=self.headers,
|
||||
json=params,
|
||||
timeout=self.timeout,
|
||||
@ -102,8 +109,10 @@ class RemoteModelWorker(ModelWorker):
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
url = self.worker_addr + "/embeddings"
|
||||
logger.debug(f"Send async_embeddings to url {url}")
|
||||
response = await client.post(
|
||||
self.worker_addr + "/embeddings",
|
||||
url,
|
||||
headers=self.headers,
|
||||
json=params,
|
||||
timeout=self.timeout,
|
||||
|
@ -86,7 +86,13 @@ class ModelWorkerParameters(BaseModelParameters):
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingModelParameters(BaseModelParameters):
|
||||
class BaseEmbeddingModelParameters(BaseModelParameters):
|
||||
def build_kwargs(self, **kwargs) -> Dict:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingModelParameters(BaseEmbeddingModelParameters):
|
||||
device: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
@ -268,3 +274,81 @@ class ProxyModelParameters(BaseModelParameters):
|
||||
max_context_size: Optional[int] = field(
|
||||
default=4096, metadata={"help": "Maximum context size"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
|
||||
proxy_server_url: str = field(
|
||||
metadata={
|
||||
"help": "Proxy base url(OPENAI_API_BASE), such as https://api.openai.com/v1"
|
||||
},
|
||||
)
|
||||
proxy_api_key: str = field(
|
||||
metadata={
|
||||
"tags": "privacy",
|
||||
"help": "The api key of the current embedding model(OPENAI_API_KEY)",
|
||||
},
|
||||
)
|
||||
device: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Device to run model. Not working for proxy embedding model"},
|
||||
)
|
||||
proxy_api_type: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The api type of current proxy the current embedding model(OPENAI_API_TYPE), if you use Azure, it can be: azure"
|
||||
},
|
||||
)
|
||||
proxy_api_version: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The api version of current proxy the current embedding model(OPENAI_API_VERSION)"
|
||||
},
|
||||
)
|
||||
proxy_backend: Optional[str] = field(
|
||||
default="text-embedding-ada-002",
|
||||
metadata={
|
||||
"help": "The model name actually pass to current proxy server url, such as text-embedding-ada-002"
|
||||
},
|
||||
)
|
||||
|
||||
proxy_deployment: Optional[str] = field(
|
||||
default="text-embedding-ada-002",
|
||||
metadata={"help": "Tto support Azure OpenAI Service custom deployment names"},
|
||||
)
|
||||
|
||||
def build_kwargs(self, **kwargs) -> Dict:
|
||||
params = {
|
||||
"openai_api_base": self.proxy_server_url,
|
||||
"openai_api_key": self.proxy_api_key,
|
||||
"openai_api_type": self.proxy_api_type if self.proxy_api_type else None,
|
||||
"openai_api_version": self.proxy_api_version
|
||||
if self.proxy_api_version
|
||||
else None,
|
||||
"model": self.proxy_backend,
|
||||
"deployment": self.proxy_deployment
|
||||
if self.proxy_deployment
|
||||
else self.proxy_backend,
|
||||
}
|
||||
for k, v in kwargs:
|
||||
params[k] = v
|
||||
return params
|
||||
|
||||
|
||||
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
|
||||
ProxyEmbeddingParameters: "proxy_openai,proxy_azure"
|
||||
}
|
||||
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
|
||||
|
||||
|
||||
def _update_embedding_config():
|
||||
global EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG
|
||||
for param_cls, models in _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG.items():
|
||||
models = [m.strip() for m in models.split(",")]
|
||||
for model in models:
|
||||
if model not in EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG:
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG[model] = param_cls
|
||||
|
||||
|
||||
_update_embedding_config()
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Type, TYPE_CHECKING
|
||||
|
||||
from pilot.componet import SystemApp
|
||||
@ -8,7 +10,6 @@ from pilot.embedding_engine.embedding_factory import (
|
||||
DefaultEmbeddingFactory,
|
||||
)
|
||||
from pilot.server.base import WebWerverParameters
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.embeddings.base import Embeddings
|
||||
@ -46,18 +47,11 @@ def _initialize_embedding_model(
|
||||
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
|
||||
)
|
||||
else:
|
||||
from pilot.model.parameter import EmbeddingModelParameters
|
||||
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
|
||||
|
||||
model_params: EmbeddingModelParameters = _parse_embedding_params(
|
||||
model_name=embedding_model_name,
|
||||
model_path=embedding_model_path,
|
||||
param_cls=EmbeddingModelParameters,
|
||||
)
|
||||
kwargs = model_params.build_kwargs(model_name=embedding_model_path)
|
||||
logger.info(f"Register local DefaultEmbeddingFactory with kwargs: {kwargs}")
|
||||
logger.info(f"Register local LocalEmbeddingFactory")
|
||||
system_app.register(
|
||||
DefaultEmbeddingFactory, default_model_name=embedding_model_path, **kwargs
|
||||
LocalEmbeddingFactory,
|
||||
default_model_name=embedding_model_name,
|
||||
default_model_path=embedding_model_path,
|
||||
)
|
||||
|
||||
|
||||
@ -82,3 +76,51 @@ class RemoteEmbeddingFactory(EmbeddingFactory):
|
||||
raise NotImplementedError
|
||||
# Ignore model_name args
|
||||
return RemoteEmbeddings(self._default_model_name, self._worker_manager)
|
||||
|
||||
|
||||
class LocalEmbeddingFactory(EmbeddingFactory):
|
||||
def __init__(
|
||||
self,
|
||||
system_app,
|
||||
default_model_name: str = None,
|
||||
default_model_path: str = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(system_app=system_app)
|
||||
self._default_model_name = default_model_name
|
||||
self._default_model_path = default_model_path
|
||||
self._kwargs = kwargs
|
||||
self._model = self._load_model()
|
||||
|
||||
def init_app(self, system_app):
|
||||
pass
|
||||
|
||||
def create(
|
||||
self, model_name: str = None, embedding_cls: Type = None
|
||||
) -> "Embeddings":
|
||||
if embedding_cls:
|
||||
raise NotImplementedError
|
||||
return self._model
|
||||
|
||||
def _load_model(self) -> "Embeddings":
|
||||
from pilot.model.parameter import (
|
||||
EmbeddingModelParameters,
|
||||
BaseEmbeddingModelParameters,
|
||||
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
|
||||
)
|
||||
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
|
||||
from pilot.model.cluster.embedding.loader import EmbeddingLoader
|
||||
|
||||
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
|
||||
self._default_model_name, EmbeddingModelParameters
|
||||
)
|
||||
model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
|
||||
model_name=self._default_model_name,
|
||||
model_path=self._default_model_path,
|
||||
param_cls=param_cls,
|
||||
**self._kwargs,
|
||||
)
|
||||
logger.info(model_params)
|
||||
loader = EmbeddingLoader()
|
||||
# Ignore model_name args
|
||||
return loader.load(self._default_model_name, model_params)
|
||||
|
@ -37,6 +37,7 @@ opencv-python==4.7.0.72
|
||||
iopath==0.1.10
|
||||
tenacity==8.2.2
|
||||
peft
|
||||
# TODO remove pycocoevalcap
|
||||
pycocoevalcap
|
||||
cpm_kernels
|
||||
umap-learn
|
||||
|
9
setup.py
9
setup.py
@ -319,6 +319,14 @@ def all_datasource_requires():
|
||||
setup_spec.extras["datasource"] = ["pymssql", "pymysql"]
|
||||
|
||||
|
||||
def openai_requires():
|
||||
"""
|
||||
pip install "db-gpt[openai]"
|
||||
"""
|
||||
setup_spec.extras["openai"] = ["openai", "tiktoken"]
|
||||
llama_cpp_python_cuda_requires()
|
||||
|
||||
|
||||
def all_requires():
|
||||
requires = set()
|
||||
for _, pkgs in setup_spec.extras.items():
|
||||
@ -339,6 +347,7 @@ llama_cpp_requires()
|
||||
quantization_requires()
|
||||
all_vector_store_requires()
|
||||
all_datasource_requires()
|
||||
openai_requires()
|
||||
|
||||
# must be last
|
||||
all_requires()
|
||||
|
Loading…
Reference in New Issue
Block a user