feat(model): support openai embedding model (#591)

Close #578
Aries-ckt 2023-09-15 21:34:41 +08:00 committed by GitHub
commit 62646d143d
19 changed files with 269 additions and 42 deletions


@@ -61,6 +61,12 @@ KNOWLEDGE_SEARCH_TOP_SIZE=5
# EMBEDDING_TOKENIZER=all-MiniLM-L6-v2
# EMBEDDING_TOKEN_LIMIT=8191
## OpenAI embedding model, see /pilot/model/parameter.py
# EMBEDDING_MODEL=proxy_openai
# proxy_openai_proxy_server_url=https://api.openai.com/v1
# proxy_openai_proxy_api_key={your-openai-sk}
# proxy_openai_proxy_backend=text-embedding-ada-002
#*******************************************************************#
#** DATABASE SETTINGS **#
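For reference, a minimal sketch of how these variables are consumed. It assumes the proxy_openai_ prefix is derived from the model name via EnvArgumentParser.get_env_prefix (consistent with the parsing code in the embedding worker changes below); all values are placeholders:

import os

# Placeholder values; read by the worker when EMBEDDING_MODEL=proxy_openai
os.environ["proxy_openai_proxy_server_url"] = "https://api.openai.com/v1"
os.environ["proxy_openai_proxy_api_key"] = "sk-..."  # placeholder, not a real key
os.environ["proxy_openai_proxy_backend"] = "text-embedding-ada-002"

# _parse_embedding_params should then yield roughly:
# ProxyEmbeddingParameters(proxy_server_url="https://api.openai.com/v1",
#                          proxy_api_key="sk-...",
#                          proxy_backend="text-embedding-ada-002")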


@@ -86,6 +86,7 @@ Currently, we have released multiple key features, which are listed below to dem
- Unified vector storage/indexing of knowledge base
- Support for unstructured data such as PDF, TXT, Markdown, CSV, DOC, PPT, and WebURL
- Multi LLMs Support: supports multiple large language models, currently including
- 🔥 InternLM(7b)
- 🔥 Baichuan2(7b,13b)
- 🔥 Vicuna-v1.5(7b,13b)
- 🔥 llama-2(7b,13b,70b)


@@ -119,6 +119,7 @@ DB-GPT is an open-source GPT experimental project based on databases, using local
- Unstructured data support, including PDF, Markdown, CSV, and WebURL
- Multi-model support
- Supports multiple large language models; currently supported models:
- 🔥 InternLM(7b)
- 🔥 Baichuan2(7b,13b)
- 🔥 Vicuna-v1.5(7b,13b)
- 🔥 llama-2(7b,13b,70b)


@@ -49,7 +49,7 @@ services:
capabilities: [gpu]
webserver:
image: eosphorosai/dbgpt:latest
command: dbgpt start webserver --light
command: dbgpt start webserver --light --remote_embedding
environment:
- DBGPT_LOG_LEVEL=DEBUG
- LOCAL_DB_PATH=data/default_sqlite.db
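(Presumably the new --remote_embedding flag tells the lightweight webserver to obtain embeddings from the model worker cluster via RemoteEmbeddingFactory, instead of loading a local embedding model; see the embedding factory changes further down.)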


@@ -6,6 +6,7 @@ DB-GPT provides a management and deployment solution for multiple models. This c
Multi LLMs Support: supports multiple large language models, currently including
- 🔥 InternLM(7b)
- 🔥 Baichuan2(7b,13b)
- 🔥 Vicuna-v1.5(7b,13b)
- 🔥 llama-2(7b,13b,70b)


@@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
import sys
from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
from enum import Enum
import logging
@@ -152,8 +153,15 @@ class SystemApp(LifeCycle):
@self.app.on_event("startup")
async def startup_event():
"""ASGI app startup event handler."""
# TODO catch exception and shutdown if worker manager start failed
asyncio.create_task(self.async_after_start())
async def _startup_func():
try:
await self.async_after_start()
except Exception as e:
logger.error(f"Error starting system app: {e}")
sys.exit(1)
asyncio.create_task(_startup_func())
self.after_start()
@self.app.on_event("shutdown")


@@ -128,8 +128,8 @@ class Config(metaclass=Singleton):
### default Local database connection configuration
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "")
self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "mysql")
self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "data/default_sqlite.db")
self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "sqlite")
if self.LOCAL_DB_HOST is None and self.LOCAL_DB_PATH == "":
self.LOCAL_DB_HOST = "127.0.0.1"
@@ -141,7 +141,7 @@ class Config(metaclass=Singleton):
self.LOCAL_DB_MANAGE = None
### LLM Model Service Configuration
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"
### When we use the rest API provided by deployment frameworks like fastchat as a proxyllm, "PROXYLLM_BACKEND" is the model they actually deploy.
### We need to use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.


@@ -88,6 +88,8 @@ EMBEDDING_MODEL_CONFIG = {
"bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
"bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"proxy_openai": "proxy_openai",
"proxy_azure": "proxy_azure",
}
# Load model config
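Note that the two proxy entries map the model name to itself: the "path" is a sentinel rather than a filesystem location, and the embedding loader branches on it (a reading of this diff; see the loader below):

from pilot.configs.model_config import EMBEDDING_MODEL_CONFIG

EMBEDDING_MODEL_CONFIG["proxy_openai"]  # -> "proxy_openai": a sentinel, not a path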


@@ -1,3 +1,4 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Type, TYPE_CHECKING


@@ -96,7 +96,7 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
def _dynamic_model_parser() -> Callable[[None], List[Type]]:
from pilot.utils.parameter_utils import _SimpleArgParser
from pilot.model.parameter import EmbeddingModelParameters, WorkerType
from pilot.model.parameter import (
EmbeddingModelParameters,
WorkerType,
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
pre_args = _SimpleArgParser("model_name", "model_path", "worker_type")
pre_args.parse()
@@ -106,7 +110,11 @@ def _dynamic_model_parser() -> Callable[[None], List[Type]]:
if model_name is None:
return None
if worker_type == WorkerType.TEXT2VEC:
return [EmbeddingModelParameters]
return [
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
model_name, EmbeddingModelParameters
)
]
llm_adapter = get_llm_model_adapter(model_name, model_path)
param_class = llm_adapter.model_param_class()
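A short sketch of the resulting lookup for a text2vec worker (the CLI invocation is illustrative; the class and config names are from this diff):

from pilot.model.parameter import (
    EmbeddingModelParameters,
    EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)

# e.g. started as: dbgpt start worker --worker_type text2vec --model_name proxy_openai
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
    "proxy_openai", EmbeddingModelParameters
)
# -> ProxyEmbeddingParameters; unknown model names fall back to EmbeddingModelParameters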


@@ -0,0 +1,27 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from pilot.model.parameter import BaseEmbeddingModelParameters
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
class EmbeddingLoader:
def __init__(self) -> None:
pass
def load(
self, model_name: str, param: BaseEmbeddingModelParameters
) -> "Embeddings":
# add more models
if model_name in ["proxy_openai", "proxy_azure"]:
from langchain.embeddings import OpenAIEmbeddings
return OpenAIEmbeddings(**param.build_kwargs())
else:
from langchain.embeddings import HuggingFaceEmbeddings
kwargs = param.build_kwargs(model_name=param.model_path)
return HuggingFaceEmbeddings(**kwargs)
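A hedged usage sketch of the loader's local branch (model name and path are illustrative; model_name/model_path are assumed to be inherited dataclass fields, as suggested by param.model_path above):

from pilot.model.cluster.embedding.loader import EmbeddingLoader
from pilot.model.parameter import EmbeddingModelParameters

params = EmbeddingModelParameters(
    model_name="text2vec",                             # illustrative
    model_path="/app/models/text2vec-large-chinese",   # illustrative
)
embeddings = EmbeddingLoader().load("text2vec", params)  # -> HuggingFaceEmbeddings
# embeddings.embed_documents(["DB-GPT"]) -> list of embedding vectors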


@@ -5,13 +5,16 @@ from pilot.configs.model_config import get_device
from pilot.model.loader import _get_model_real_path
from pilot.model.parameter import (
EmbeddingModelParameters,
BaseEmbeddingModelParameters,
WorkerType,
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.embedding.loader import EmbeddingLoader
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.parameter_utils import EnvArgumentParser
logger = logging.getLogger("model_worker")
logger = logging.getLogger(__name__)
class EmbeddingsModelWorker(ModelWorker):
@@ -26,6 +29,9 @@ class EmbeddingsModelWorker(ModelWorker):
) from exc
self._embeddings_impl: Embeddings = None
self._model_params = None
self.model_name = None
self.model_path = None
self._loader = EmbeddingLoader()
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
if model_path.endswith("/"):
@@ -39,11 +45,13 @@
return WorkerType.TEXT2VEC
def model_param_class(self) -> Type:
return EmbeddingModelParameters
return EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
self.model_name, EmbeddingModelParameters
)
def parse_parameters(
self, command_args: List[str] = None
) -> EmbeddingModelParameters:
) -> BaseEmbeddingModelParameters:
param_cls = self.model_param_class()
return _parse_embedding_params(
model_name=self.model_name,
@@ -58,15 +66,10 @@
command_args: List[str] = None,
) -> None:
"""Start model worker"""
from langchain.embeddings import HuggingFaceEmbeddings
if not model_params:
model_params = self.parse_parameters(command_args)
self._model_params = model_params
kwargs = model_params.build_kwargs(model_name=model_params.model_path)
logger.info(f"Start HuggingFaceEmbeddings with kwargs: {kwargs}")
self._embeddings_impl = HuggingFaceEmbeddings(**kwargs)
self._embeddings_impl = self._loader.load(self.model_name, model_params)
def __del__(self):
self.stop()
@@ -101,7 +104,7 @@
):
model_args = EnvArgumentParser()
env_prefix = EnvArgumentParser.get_env_prefix(model_name)
model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass(
model_params: BaseEmbeddingModelParameters = model_args.parse_args_into_dataclass(
param_cls,
env_prefix=env_prefix,
command_args=command_args,


@@ -2,6 +2,7 @@ import asyncio
import itertools
import json
import os
import sys
import random
import time
from concurrent.futures import ThreadPoolExecutor
@@ -129,8 +130,6 @@ class LocalWorkerManager(WorkerManager):
command_args: List[str] = None,
) -> bool:
if not command_args:
import sys
command_args = sys.argv[1:]
worker.load_worker(**asdict(worker_params))
@@ -635,8 +634,15 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
@app.on_event("startup")
async def startup_event():
# TODO catch exception and shutdown if worker manager start failed
asyncio.create_task(worker_manager.start())
async def start_worker_manager():
try:
await worker_manager.start()
except Exception as e:
logger.error(f"Error starting worker manager: {e}")
sys.exit(1)
# We can't block here: worker_manager startup depends on the FastAPI app already running (the worker registers itself with the controller)
asyncio.create_task(start_worker_manager())
@app.on_event("shutdown")
async def startup_event():
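The guarded-startup pattern above (and the matching change in pilot/componet.py) as a standalone sketch; the FastAPI wiring is real, but the service names here are illustrative:

import asyncio
import logging
import sys

from fastapi import FastAPI

app = FastAPI()
logger = logging.getLogger(__name__)

async def start_services():
    """Stand-in for worker_manager.start(); may call back into this app."""

@app.on_event("startup")
async def startup_event():
    async def _guarded_start():
        try:
            await start_services()
        except Exception as e:
            logger.error(f"Error during startup: {e}")
            sys.exit(1)
    # Not awaited: the started service needs the HTTP server to already be serving.
    asyncio.create_task(_guarded_start())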


@@ -6,6 +6,9 @@ from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker
logger = logging.getLogger(__name__)
class RemoteModelWorker(ModelWorker):
def __init__(self) -> None:
self.headers = {}
@@ -46,13 +49,14 @@ class RemoteModelWorker(ModelWorker):
"""Asynchronous generate stream"""
import httpx
logging.debug(f"Send async_generate_stream, params: {params}")
async with httpx.AsyncClient() as client:
delimiter = b"\0"
buffer = b""
url = self.worker_addr + "/generate_stream"
logger.debug(f"Send async_generate_stream to url {url}, params: {params}")
async with client.stream(
"POST",
self.worker_addr + "/generate_stream",
url,
headers=self.headers,
json=params,
timeout=self.timeout,
@@ -75,10 +79,11 @@
"""Asynchronous generate non stream"""
import httpx
logging.debug(f"Send async_generate_stream, params: {params}")
async with httpx.AsyncClient() as client:
url = self.worker_addr + "/generate"
logger.debug(f"Send async_generate to url {url}, params: {params}")
response = await client.post(
self.worker_addr + "/generate",
url,
headers=self.headers,
json=params,
timeout=self.timeout,
@@ -89,8 +94,10 @@
"""Get embeddings for input"""
import requests
url = self.worker_addr + "/embeddings"
logger.debug(f"Send embeddings to url {url}, params: {params}")
response = requests.post(
self.worker_addr + "/embeddings",
url,
headers=self.headers,
json=params,
timeout=self.timeout,
@@ -102,8 +109,10 @@
import httpx
async with httpx.AsyncClient() as client:
url = self.worker_addr + "/embeddings"
logger.debug(f"Send async_embeddings to url {url}")
response = await client.post(
self.worker_addr + "/embeddings",
url,
headers=self.headers,
json=params,
timeout=self.timeout,
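What the async embeddings call boils down to, as a self-contained sketch (the /embeddings path and httpx usage come from this diff; the payload and response shapes are assumptions):

import httpx

async def fetch_embeddings(worker_addr: str, params: dict, timeout: float = 60.0):
    async with httpx.AsyncClient() as client:
        url = worker_addr + "/embeddings"
        response = await client.post(url, json=params, timeout=timeout)
        response.raise_for_status()
        return response.json()  # assumed: a list of embedding vectors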


@@ -86,7 +86,13 @@ class ModelWorkerParameters(BaseModelParameters):
@dataclass
class EmbeddingModelParameters(BaseModelParameters):
class BaseEmbeddingModelParameters(BaseModelParameters):
def build_kwargs(self, **kwargs) -> Dict:
pass
@dataclass
class EmbeddingModelParameters(BaseEmbeddingModelParameters):
device: Optional[str] = field(
default=None,
metadata={
@@ -268,3 +274,81 @@ class ProxyModelParameters(BaseModelParameters):
max_context_size: Optional[int] = field(
default=4096, metadata={"help": "Maximum context size"}
)
@dataclass
class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
proxy_server_url: str = field(
metadata={
"help": "Proxy base url(OPENAI_API_BASE), such as https://api.openai.com/v1"
},
)
proxy_api_key: str = field(
metadata={
"tags": "privacy",
"help": "The api key of the current embedding model(OPENAI_API_KEY)",
},
)
device: Optional[str] = field(
default=None,
metadata={"help": "Device to run model. Not working for proxy embedding model"},
)
proxy_api_type: Optional[str] = field(
default=None,
metadata={
"help": "The api type of current proxy the current embedding model(OPENAI_API_TYPE), if you use Azure, it can be: azure"
},
)
proxy_api_version: Optional[str] = field(
default=None,
metadata={
"help": "The api version of current proxy the current embedding model(OPENAI_API_VERSION)"
},
)
proxy_backend: Optional[str] = field(
default="text-embedding-ada-002",
metadata={
"help": "The model name actually pass to current proxy server url, such as text-embedding-ada-002"
},
)
proxy_deployment: Optional[str] = field(
default="text-embedding-ada-002",
metadata={"help": "Tto support Azure OpenAI Service custom deployment names"},
)
def build_kwargs(self, **kwargs) -> Dict:
params = {
"openai_api_base": self.proxy_server_url,
"openai_api_key": self.proxy_api_key,
"openai_api_type": self.proxy_api_type if self.proxy_api_type else None,
"openai_api_version": self.proxy_api_version
if self.proxy_api_version
else None,
"model": self.proxy_backend,
"deployment": self.proxy_deployment
if self.proxy_deployment
else self.proxy_backend,
}
for k, v in kwargs.items():
params[k] = v
return params
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
ProxyEmbeddingParameters: "proxy_openai,proxy_azure"
}
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
def _update_embedding_config():
global EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG
for param_cls, models in _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG.items():
models = [m.strip() for m in models.split(",")]
for model in models:
if model not in EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG:
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG[model] = param_cls
_update_embedding_config()
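A sketch of what build_kwargs yields for the OpenAI proxy (model_name/model_path are assumed inherited fields of the base dataclass; the values are placeholders):

from pilot.model.parameter import ProxyEmbeddingParameters

params = ProxyEmbeddingParameters(
    model_name="proxy_openai",   # assumed inherited field
    model_path="proxy_openai",   # assumed inherited field
    proxy_server_url="https://api.openai.com/v1",
    proxy_api_key="sk-...",      # placeholder
)
print(params.build_kwargs())
# roughly: {"openai_api_base": "https://api.openai.com/v1", "openai_api_key": "sk-...",
#           "openai_api_type": None, "openai_api_version": None,
#           "model": "text-embedding-ada-002", "deployment": "text-embedding-ada-002"}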


@@ -1,3 +1,5 @@
from __future__ import annotations
from typing import Any, Type, TYPE_CHECKING
from pilot.componet import SystemApp
@@ -8,7 +10,6 @@ from pilot.embedding_engine.embedding_factory import (
DefaultEmbeddingFactory,
)
from pilot.server.base import WebWerverParameters
from pilot.utils.parameter_utils import EnvArgumentParser
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
@@ -46,18 +47,11 @@ def _initialize_embedding_model(
RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
)
else:
from pilot.model.parameter import EmbeddingModelParameters
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
model_params: EmbeddingModelParameters = _parse_embedding_params(
model_name=embedding_model_name,
model_path=embedding_model_path,
param_cls=EmbeddingModelParameters,
)
kwargs = model_params.build_kwargs(model_name=embedding_model_path)
logger.info(f"Register local DefaultEmbeddingFactory with kwargs: {kwargs}")
logger.info(f"Register local LocalEmbeddingFactory")
system_app.register(
DefaultEmbeddingFactory, default_model_name=embedding_model_path, **kwargs
LocalEmbeddingFactory,
default_model_name=embedding_model_name,
default_model_path=embedding_model_path,
)
@@ -82,3 +76,51 @@ class RemoteEmbeddingFactory(EmbeddingFactory):
raise NotImplementedError
# Ignore model_name args
return RemoteEmbeddings(self._default_model_name, self._worker_manager)
class LocalEmbeddingFactory(EmbeddingFactory):
def __init__(
self,
system_app,
default_model_name: str = None,
default_model_path: str = None,
**kwargs: Any,
) -> None:
super().__init__(system_app=system_app)
self._default_model_name = default_model_name
self._default_model_path = default_model_path
self._kwargs = kwargs
self._model = self._load_model()
def init_app(self, system_app):
pass
def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "Embeddings":
if embedding_cls:
raise NotImplementedError
return self._model
def _load_model(self) -> "Embeddings":
from pilot.model.parameter import (
EmbeddingModelParameters,
BaseEmbeddingModelParameters,
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
from pilot.model.cluster.embedding.loader import EmbeddingLoader
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
self._default_model_name, EmbeddingModelParameters
)
model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
model_name=self._default_model_name,
model_path=self._default_model_path,
param_cls=param_cls,
**self._kwargs,
)
logger.info(model_params)
loader = EmbeddingLoader()
# Ignore model_name args
return loader.load(self._default_model_name, model_params)
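How the factory might be exercised once registered (a sketch: SystemApp normally injects system_app and the model name/path come from configuration; the import path and values here are assumptions based on this PR's file layout):

from pilot.server.componet_configs import LocalEmbeddingFactory  # module path assumed

factory = LocalEmbeddingFactory(
    system_app=None,  # injected by system_app.register in real use
    default_model_name="text2vec",                             # illustrative
    default_model_path="/app/models/text2vec-large-chinese",   # illustrative
)
embeddings = factory.create()  # returns the model loaded at construction time
# passing embedding_cls raises NotImplementedError, as shown above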


@@ -37,6 +37,7 @@ opencv-python==4.7.0.72
iopath==0.1.10
tenacity==8.2.2
peft
# TODO remove pycocoevalcap
pycocoevalcap
cpm_kernels
umap-learn
@@ -57,7 +58,7 @@ seaborn
auto-gpt-plugin-template
pymdown-extensions
gTTS==2.3.1
langchain
langchain>=0.0.286
nltk
python-dotenv==1.0.0


@@ -0,0 +1,19 @@
#!/bin/bash
eval "$(conda shell.bash hook)"
source ~/.bashrc
# source /etc/network_turbo
# unset http_proxy && unset https_proxy
conda create -n dbgpt python=3.10 -y
conda activate dbgpt
apt-get update -y && apt-get install git-lfs -y
cd /root && git clone https://github.com/eosphoros-ai/DB-GPT.git
mkdir -p /root/DB-GPT/models && cd /root/DB-GPT/models
git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese
git clone https://huggingface.co/THUDM/chatglm2-6b-int4


@@ -319,6 +319,13 @@ def all_datasource_requires():
setup_spec.extras["datasource"] = ["pymssql", "pymysql"]
def openai_requires():
"""
pip install "db-gpt[openai]"
"""
setup_spec.extras["openai"] = ["openai", "tiktoken"]
def all_requires():
requires = set()
for _, pkgs in setup_spec.extras.items():
@@ -339,6 +346,7 @@ llama_cpp_requires()
quantization_requires()
all_vector_store_requires()
all_datasource_requires()
openai_requires()
# must be last
all_requires()