feat(model): support openai embedding model (#591)

Close #578
Authored by Aries-ckt on 2023-09-15 21:34:41 +08:00; committed by GitHub
commit 62646d143d
19 changed files with 269 additions and 42 deletions

View File

@ -61,6 +61,12 @@ KNOWLEDGE_SEARCH_TOP_SIZE=5
# EMBEDDING_TOKENIZER=all-MiniLM-L6-v2
# EMBEDDING_TOKEN_LIMIT=8191
## OpenAI embedding model, see /pilot/model/parameter.py
# EMBEDDING_MODEL=proxy_openai
# proxy_openai_proxy_server_url=https://api.openai.com/v1
# proxy_openai_proxy_api_key={your-openai-sk}
# proxy_openai_proxy_backend=text-embedding-ada-002
#*******************************************************************#
#** DATABASE SETTINGS **#
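For reference, the proxy_openai settings above end up as LangChain OpenAIEmbeddings keyword arguments (see the ProxyEmbeddingParameters.build_kwargs and EmbeddingLoader changes further down in this commit). A minimal sketch of the resulting call, assuming langchain>=0.0.286; the literal values are placeholders:

# Hypothetical sketch (not part of this commit): what EMBEDDING_MODEL=proxy_openai resolves to.
from langchain.embeddings import OpenAIEmbeddings

embeddings = OpenAIEmbeddings(
    openai_api_base="https://api.openai.com/v1",  # proxy_openai_proxy_server_url
    openai_api_key="{your-openai-sk}",            # proxy_openai_proxy_api_key
    model="text-embedding-ada-002",               # proxy_openai_proxy_backend
    deployment="text-embedding-ada-002",          # proxy_deployment defaults to the backend name
)
vector = embeddings.embed_query("hello")          # list of floats, e.g. 1536 dimensions for ada-002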

View File

@ -86,6 +86,7 @@ Currently, we have released multiple key features, which are listed below to dem
- Unified vector storage/indexing of knowledge base
- Support for unstructured data such as PDF, TXT, Markdown, CSV, DOC, PPT, and WebURL
- Multi LLMs Support, Supports multiple large language models, currently supporting
- 🔥 InternLM(7b)
- 🔥 Baichuan2(7b,13b)
- 🔥 Vicuna-v1.5(7b,13b)
- 🔥 llama-2(7b,13b,70b)

View File

@ -119,6 +119,7 @@ DB-GPT 是一个开源的以数据库为基础的GPT实验项目使用本地
- Support for unstructured data, including PDF, Markdown, CSV, and WebURL
- Multi-model support
- Supports multiple large language models; currently supported models:
- 🔥 InternLM(7b)
- 🔥 Baichuan2(7b,13b)
- 🔥 Vicuna-v1.5(7b,13b)
- 🔥 llama-2(7b,13b,70b)

View File

@ -49,7 +49,7 @@ services:
capabilities: [gpu]
webserver:
image: eosphorosai/dbgpt:latest
command: dbgpt start webserver --light --remote_embedding  # was: dbgpt start webserver --light
environment:
- DBGPT_LOG_LEVEL=DEBUG
- LOCAL_DB_PATH=data/default_sqlite.db

View File

@ -6,6 +6,7 @@ DB-GPT provides a management and deployment solution for multiple models. This c
Multi LLMs Support, Supports multiple large language models, currently supporting
- 🔥 InternLM(7b)
- 🔥 Baichuan2(7b,13b)
- 🔥 Vicuna-v1.5(7b,13b)
- 🔥 llama-2(7b,13b,70b)

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
+ import sys
from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
from enum import Enum
import logging
@ -152,8 +153,15 @@ class SystemApp(LifeCycle):
@self.app.on_event("startup")
async def startup_event():
    """ASGI app startup event handler."""
-   # TODO catch exception and shutdown if worker manager start failed
-   asyncio.create_task(self.async_after_start())
+   async def _startup_func():
+       try:
+           await self.async_after_start()
+       except Exception as e:
+           logger.error(f"Error starting system app: {e}")
+           sys.exit(1)
+
+   asyncio.create_task(_startup_func())
    self.after_start()

@self.app.on_event("shutdown")

View File

@ -128,8 +128,8 @@ class Config(metaclass=Singleton):
### default Local database connection configuration
self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
- self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "")
- self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "mysql")
+ self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "data/default_sqlite.db")
+ self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "sqlite")
if self.LOCAL_DB_HOST is None and self.LOCAL_DB_PATH == "":
    self.LOCAL_DB_HOST = "127.0.0.1"
@ -141,7 +141,7 @@ class Config(metaclass=Singleton):
self.LOCAL_DB_MANAGE = None

### LLM Model Service Configuration
- self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
+ self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"
### When we use the rest API provided by deployment frameworks like fastchat as a proxyllm, "PROXYLLM_BACKEND" is the model they actually deploy.
### We need to use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.

View File

@ -88,6 +88,8 @@ EMBEDDING_MODEL_CONFIG = {
"bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"), "bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
"bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"), "bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"), "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"proxy_openai": "proxy_openai",
"proxy_azure": "proxy_azure",
} }
# Load model config # Load model config

View File

@ -1,3 +1,4 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Type, TYPE_CHECKING

View File

@ -96,7 +96,11 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
def _dynamic_model_parser() -> Callable[[None], List[Type]]:
    from pilot.utils.parameter_utils import _SimpleArgParser
-   from pilot.model.parameter import EmbeddingModelParameters, WorkerType
+   from pilot.model.parameter import (
+       EmbeddingModelParameters,
+       WorkerType,
+       EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
+   )

    pre_args = _SimpleArgParser("model_name", "model_path", "worker_type")
    pre_args.parse()
@ -106,7 +110,11 @@ def _dynamic_model_parser() -> Callable[[None], List[Type]]:
    if model_name is None:
        return None
    if worker_type == WorkerType.TEXT2VEC:
-       return [EmbeddingModelParameters]
+       return [
+           EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
+               model_name, EmbeddingModelParameters
+           )
+       ]
    llm_adapter = get_llm_model_adapter(model_name, model_path)
    param_class = llm_adapter.model_param_class()

View File

@ -0,0 +1,27 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from pilot.model.parameter import BaseEmbeddingModelParameters

if TYPE_CHECKING:
    from langchain.embeddings.base import Embeddings


class EmbeddingLoader:
    def __init__(self) -> None:
        pass

    def load(
        self, model_name: str, param: BaseEmbeddingModelParameters
    ) -> "Embeddings":
        # add more models
        if model_name in ["proxy_openai", "proxy_azure"]:
            from langchain.embeddings import OpenAIEmbeddings

            return OpenAIEmbeddings(**param.build_kwargs())
        else:
            from langchain.embeddings import HuggingFaceEmbeddings

            kwargs = param.build_kwargs(model_name=param.model_path)
            return HuggingFaceEmbeddings(**kwargs)
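
A hedged usage sketch of the new loader (not part of the diff), pairing EmbeddingLoader with the ProxyEmbeddingParameters dataclass added later in this commit; the URL, key, and the inherited model_name/model_path fields are placeholders/assumptions:

# Hypothetical usage sketch of EmbeddingLoader.
from pilot.model.parameter import ProxyEmbeddingParameters
from pilot.model.cluster.embedding.loader import EmbeddingLoader

params = ProxyEmbeddingParameters(
    model_name="proxy_openai",   # inherited from BaseModelParameters (assumed required)
    model_path="proxy_openai",
    proxy_server_url="https://api.openai.com/v1",
    proxy_api_key="{your-openai-sk}",
    proxy_backend="text-embedding-ada-002",
)
loader = EmbeddingLoader()
embeddings = loader.load("proxy_openai", params)  # returns a LangChain Embeddings object
print(embeddings.embed_query("DB-GPT")[:3])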

View File

@ -5,13 +5,16 @@ from pilot.configs.model_config import get_device
from pilot.model.loader import _get_model_real_path
from pilot.model.parameter import (
    EmbeddingModelParameters,
+   BaseEmbeddingModelParameters,
    WorkerType,
+   EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
from pilot.model.cluster.worker_base import ModelWorker
+ from pilot.model.cluster.embedding.loader import EmbeddingLoader
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.parameter_utils import EnvArgumentParser

- logger = logging.getLogger("model_worker")
+ logger = logging.getLogger(__name__)


class EmbeddingsModelWorker(ModelWorker):
@ -26,6 +29,9 @@ class EmbeddingsModelWorker(ModelWorker):
            ) from exc
        self._embeddings_impl: Embeddings = None
        self._model_params = None
+       self.model_name = None
+       self.model_path = None
+       self._loader = EmbeddingLoader()

    def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
        if model_path.endswith("/"):
@ -39,11 +45,13 @@ class EmbeddingsModelWorker(ModelWorker):
        return WorkerType.TEXT2VEC

    def model_param_class(self) -> Type:
-       return EmbeddingModelParameters
+       return EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
+           self.model_name, EmbeddingModelParameters
+       )

    def parse_parameters(
        self, command_args: List[str] = None
-   ) -> EmbeddingModelParameters:
+   ) -> BaseEmbeddingModelParameters:
        param_cls = self.model_param_class()
        return _parse_embedding_params(
            model_name=self.model_name,
@ -58,15 +66,10 @@ class EmbeddingsModelWorker(ModelWorker):
        command_args: List[str] = None,
    ) -> None:
        """Start model worker"""
-       from langchain.embeddings import HuggingFaceEmbeddings
        if not model_params:
            model_params = self.parse_parameters(command_args)
        self._model_params = model_params
-       kwargs = model_params.build_kwargs(model_name=model_params.model_path)
-       logger.info(f"Start HuggingFaceEmbeddings with kwargs: {kwargs}")
-       self._embeddings_impl = HuggingFaceEmbeddings(**kwargs)
+       self._embeddings_impl = self._loader.load(self.model_name, model_params)

    def __del__(self):
        self.stop()
@ -101,7 +104,7 @@ def _parse_embedding_params(
):
    model_args = EnvArgumentParser()
    env_prefix = EnvArgumentParser.get_env_prefix(model_name)
-   model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass(
+   model_params: BaseEmbeddingModelParameters = model_args.parse_args_into_dataclass(
        param_cls,
        env_prefix=env_prefix,
        command_args=command_args,
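
A hedged sketch (not in the diff) of how the embedding worker is expected to be driven for the proxy model; constructor side effects and defaults are assumptions:

# Hypothetical flow for EmbeddingsModelWorker with the OpenAI proxy embedding model.
worker = EmbeddingsModelWorker()
worker.load_worker(model_name="proxy_openai", model_path="proxy_openai")
params = worker.parse_parameters()   # resolves ProxyEmbeddingParameters from env vars / CLI args
worker.start(model_params=params)    # delegates to EmbeddingLoader -> OpenAIEmbeddings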

View File

@ -2,6 +2,7 @@ import asyncio
import itertools
import json
import os
+ import sys
import random
import time
from concurrent.futures import ThreadPoolExecutor
@ -129,8 +130,6 @@ class LocalWorkerManager(WorkerManager):
        command_args: List[str] = None,
    ) -> bool:
        if not command_args:
-           import sys
            command_args = sys.argv[1:]
        worker.load_worker(**asdict(worker_params))
@ -635,8 +634,15 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):
    @app.on_event("startup")
    async def startup_event():
-       # TODO catch exception and shutdown if worker manager start failed
-       asyncio.create_task(worker_manager.start())
+       async def start_worker_manager():
+           try:
+               await worker_manager.start()
+           except Exception as e:
+               logger.error(f"Error starting worker manager: {e}")
+               sys.exit(1)
+
+       # It cannot be blocked here because the startup of worker_manager depends on the fastapi app (registered to the controller)
+       asyncio.create_task(start_worker_manager())

    @app.on_event("shutdown")
    async def startup_event():

View File

@ -6,6 +6,9 @@ from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker

+ logger = logging.getLogger(__name__)


class RemoteModelWorker(ModelWorker):
    def __init__(self) -> None:
        self.headers = {}
@ -46,13 +49,14 @@ class RemoteModelWorker(ModelWorker):
"""Asynchronous generate stream""" """Asynchronous generate stream"""
import httpx import httpx
logging.debug(f"Send async_generate_stream, params: {params}")
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
delimiter = b"\0" delimiter = b"\0"
buffer = b"" buffer = b""
url = self.worker_addr + "/generate_stream"
logger.debug(f"Send async_generate_stream to url {url}, params: {params}")
async with client.stream( async with client.stream(
"POST", "POST",
self.worker_addr + "/generate_stream", url,
headers=self.headers, headers=self.headers,
json=params, json=params,
timeout=self.timeout, timeout=self.timeout,
@ -75,10 +79,11 @@ class RemoteModelWorker(ModelWorker):
"""Asynchronous generate non stream""" """Asynchronous generate non stream"""
import httpx import httpx
logging.debug(f"Send async_generate_stream, params: {params}")
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
url = self.worker_addr + "/generate"
logger.debug(f"Send async_generate to url {url}, params: {params}")
response = await client.post( response = await client.post(
self.worker_addr + "/generate", url,
headers=self.headers, headers=self.headers,
json=params, json=params,
timeout=self.timeout, timeout=self.timeout,
@ -89,8 +94,10 @@ class RemoteModelWorker(ModelWorker):
"""Get embeddings for input""" """Get embeddings for input"""
import requests import requests
url = self.worker_addr + "/embeddings"
logger.debug(f"Send embeddings to url {url}, params: {params}")
response = requests.post( response = requests.post(
self.worker_addr + "/embeddings", url,
headers=self.headers, headers=self.headers,
json=params, json=params,
timeout=self.timeout, timeout=self.timeout,
@ -102,8 +109,10 @@ class RemoteModelWorker(ModelWorker):
        import httpx

        async with httpx.AsyncClient() as client:
+           url = self.worker_addr + "/embeddings"
+           logger.debug(f"Send async_embeddings to url {url}")
            response = await client.post(
-               self.worker_addr + "/embeddings",
+               url,
                headers=self.headers,
                json=params,
                timeout=self.timeout,

View File

@ -86,7 +86,13 @@ class ModelWorkerParameters(BaseModelParameters):
@dataclass
- class EmbeddingModelParameters(BaseModelParameters):
+ class BaseEmbeddingModelParameters(BaseModelParameters):
+     def build_kwargs(self, **kwargs) -> Dict:
+         pass
+
+
+ @dataclass
+ class EmbeddingModelParameters(BaseEmbeddingModelParameters):
      device: Optional[str] = field(
          default=None,
          metadata={
@ -268,3 +274,81 @@ class ProxyModelParameters(BaseModelParameters):
    max_context_size: Optional[int] = field(
        default=4096, metadata={"help": "Maximum context size"}
    )
@dataclass
class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
    proxy_server_url: str = field(
        metadata={
            "help": "Proxy base url(OPENAI_API_BASE), such as https://api.openai.com/v1"
        },
    )
    proxy_api_key: str = field(
        metadata={
            "tags": "privacy",
            "help": "The api key of the current embedding model(OPENAI_API_KEY)",
        },
    )
    device: Optional[str] = field(
        default=None,
        metadata={"help": "Device to run model. Not working for proxy embedding model"},
    )
    proxy_api_type: Optional[str] = field(
        default=None,
        metadata={
            "help": "The api type of the current proxy embedding model (OPENAI_API_TYPE); if you use Azure, it can be: azure"
        },
    )
    proxy_api_version: Optional[str] = field(
        default=None,
        metadata={
            "help": "The api version of the current proxy embedding model (OPENAI_API_VERSION)"
        },
    )
    proxy_backend: Optional[str] = field(
        default="text-embedding-ada-002",
        metadata={
            "help": "The model name actually passed to the current proxy server url, such as text-embedding-ada-002"
        },
    )
    proxy_deployment: Optional[str] = field(
        default="text-embedding-ada-002",
        metadata={"help": "To support Azure OpenAI Service custom deployment names"},
    )

    def build_kwargs(self, **kwargs) -> Dict:
        params = {
            "openai_api_base": self.proxy_server_url,
            "openai_api_key": self.proxy_api_key,
            "openai_api_type": self.proxy_api_type if self.proxy_api_type else None,
            "openai_api_version": self.proxy_api_version
            if self.proxy_api_version
            else None,
            "model": self.proxy_backend,
            "deployment": self.proxy_deployment
            if self.proxy_deployment
            else self.proxy_backend,
        }
        # Iterate over items(); iterating the dict directly would yield only keys.
        for k, v in kwargs.items():
            params[k] = v
        return params


_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
    ProxyEmbeddingParameters: "proxy_openai,proxy_azure"
}

EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}


def _update_embedding_config():
    global EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG
    for param_cls, models in _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG.items():
        models = [m.strip() for m in models.split(",")]
        for model in models:
            if model not in EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG:
                EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG[model] = param_cls


_update_embedding_config()
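
A brief illustrative sketch (not part of the commit) of how the registry added above resolves a parameter class; any name not registered falls back to the default dataclass:

# Hypothetical sketch of the lookup populated by _update_embedding_config().
from pilot.model.parameter import (
    EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
    EmbeddingModelParameters,
    ProxyEmbeddingParameters,
)

param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get("proxy_openai", EmbeddingModelParameters)
assert param_cls is ProxyEmbeddingParameters

# "proxy_azure" maps to the same dataclass; other names fall back to EmbeddingModelParameters.
fallback = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get("text2vec", EmbeddingModelParameters)
assert fallback is EmbeddingModelParameters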

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from typing import Any, Type, TYPE_CHECKING

from pilot.componet import SystemApp
@ -8,7 +10,6 @@ from pilot.embedding_engine.embedding_factory import (
    DefaultEmbeddingFactory,
)
from pilot.server.base import WebWerverParameters
- from pilot.utils.parameter_utils import EnvArgumentParser

if TYPE_CHECKING:
    from langchain.embeddings.base import Embeddings
@ -46,18 +47,11 @@ def _initialize_embedding_model(
            RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
        )
    else:
-       from pilot.model.parameter import EmbeddingModelParameters
-       from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
-
-       model_params: EmbeddingModelParameters = _parse_embedding_params(
-           model_name=embedding_model_name,
-           model_path=embedding_model_path,
-           param_cls=EmbeddingModelParameters,
-       )
-       kwargs = model_params.build_kwargs(model_name=embedding_model_path)
-       logger.info(f"Register local DefaultEmbeddingFactory with kwargs: {kwargs}")
+       logger.info(f"Register local LocalEmbeddingFactory")
        system_app.register(
-           DefaultEmbeddingFactory, default_model_name=embedding_model_path, **kwargs
+           LocalEmbeddingFactory,
+           default_model_name=embedding_model_name,
+           default_model_path=embedding_model_path,
        )
@ -82,3 +76,51 @@ class RemoteEmbeddingFactory(EmbeddingFactory):
            raise NotImplementedError
        # Ignore model_name args
        return RemoteEmbeddings(self._default_model_name, self._worker_manager)

class LocalEmbeddingFactory(EmbeddingFactory):
    def __init__(
        self,
        system_app,
        default_model_name: str = None,
        default_model_path: str = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(system_app=system_app)
        self._default_model_name = default_model_name
        self._default_model_path = default_model_path
        self._kwargs = kwargs
        self._model = self._load_model()

    def init_app(self, system_app):
        pass

    def create(
        self, model_name: str = None, embedding_cls: Type = None
    ) -> "Embeddings":
        if embedding_cls:
            raise NotImplementedError
        return self._model

    def _load_model(self) -> "Embeddings":
        from pilot.model.parameter import (
            EmbeddingModelParameters,
            BaseEmbeddingModelParameters,
            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
        )
        from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
        from pilot.model.cluster.embedding.loader import EmbeddingLoader

        param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
            self._default_model_name, EmbeddingModelParameters
        )
        model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
            model_name=self._default_model_name,
            model_path=self._default_model_path,
            param_cls=param_cls,
            **self._kwargs,
        )
        logger.info(model_params)
        loader = EmbeddingLoader()
        # Ignore model_name args
        return loader.load(self._default_model_name, model_params)
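
A hedged usage sketch (not in the diff): constructing the factory directly for a local HuggingFace model; the model name and path are placeholders, and passing system_app=None for a standalone script is an assumption:

# Hypothetical standalone usage of LocalEmbeddingFactory.
factory = LocalEmbeddingFactory(
    None,                                                 # system_app; None assumed OK outside the server
    default_model_name="text2vec",                        # placeholder model name
    default_model_path="models/text2vec-large-chinese",   # placeholder path
)
embeddings = factory.create()                             # returns the cached LangChain Embeddings instance
print(embeddings.embed_query("DB-GPT")[:3])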

View File

@ -37,6 +37,7 @@ opencv-python==4.7.0.72
iopath==0.1.10
tenacity==8.2.2
peft
# TODO remove pycocoevalcap
pycocoevalcap
cpm_kernels
umap-learn
@ -57,7 +58,7 @@ seaborn
auto-gpt-plugin-template
pymdown-extensions
gTTS==2.3.1
langchain>=0.0.286  # was: langchain
nltk
python-dotenv==1.0.0

View File

@ -0,0 +1,19 @@
#!/bin/bash
eval "$(conda shell.bash hook)"
source ~/.bashrc
# source /etc/network_turbo
# unset http_proxy && unset https_proxy
conda create -n dbgpt python=3.10 -y
conda activate dbgpt
apt-get update -y && apt-get install git-lfs -y
cd /root && git clone https://github.com/eosphoros-ai/DB-GPT.git
mkdir -p /root/DB-GPT/models && cd /root/DB-GPT/models
git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese
git clone https://huggingface.co/THUDM/chatglm2-6b-int4

View File

@ -319,6 +319,13 @@ def all_datasource_requires():
setup_spec.extras["datasource"] = ["pymssql", "pymysql"] setup_spec.extras["datasource"] = ["pymssql", "pymysql"]
def openai_requires():
"""
pip install "db-gpt[openai]"
"""
setup_spec.extras["openai"] = ["openai", "tiktoken"]
def all_requires(): def all_requires():
requires = set() requires = set()
for _, pkgs in setup_spec.extras.items(): for _, pkgs in setup_spec.extras.items():
@ -339,6 +346,7 @@ llama_cpp_requires()
quantization_requires()
all_vector_store_requires()
all_datasource_requires()
openai_requires()

# must be last
all_requires()