mirror of https://github.com/csunny/DB-GPT.git
synced 2025-08-01 00:03:29 +00:00

commit 5aa9cb455e (parent e331936bff)

feat(model): support openai embedding model
@@ -61,6 +61,12 @@ KNOWLEDGE_SEARCH_TOP_SIZE=5
 # EMBEDDING_TOKENIZER=all-MiniLM-L6-v2
 # EMBEDDING_TOKEN_LIMIT=8191
+
+## Openai embedding model, See /pilot/model/parameter.py
+# EMBEDDING_MODEL=proxy_openai
+# proxy_openai_proxy_server_url=https://api.openai.com/v1
+# proxy_openai_proxy_api_key={your-openai-sk}
+# proxy_openai_proxy_backend=text-embedding-ada-002


 #*******************************************************************#
 #**                     DATABASE SETTINGS                         **#
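Uncommented, a minimal working fragment for switching embeddings to the OpenAI proxy would look like this (the API-key placeholder is left exactly as in the template):

EMBEDDING_MODEL=proxy_openai
proxy_openai_proxy_server_url=https://api.openai.com/v1
proxy_openai_proxy_api_key={your-openai-sk}
proxy_openai_proxy_backend=text-embedding-ada-002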
@@ -49,7 +49,7 @@ services:
           capabilities: [gpu]
   webserver:
     image: eosphorosai/dbgpt:latest
-    command: dbgpt start webserver --light
+    command: dbgpt start webserver --light --remote_embedding
     environment:
       - DBGPT_LOG_LEVEL=DEBUG
      - LOCAL_DB_PATH=data/default_sqlite.db
@@ -1,6 +1,7 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
+import sys
 from typing import Type, Dict, TypeVar, Optional, Union, TYPE_CHECKING
 from enum import Enum
 import logging

@@ -152,8 +153,15 @@ class SystemApp(LifeCycle):
         @self.app.on_event("startup")
         async def startup_event():
             """ASGI app startup event handler."""
-            # TODO catch exception and shutdown if worker manager start failed
-            asyncio.create_task(self.async_after_start())
+
+            async def _startup_func():
+                try:
+                    await self.async_after_start()
+                except Exception as e:
+                    logger.error(f"Error starting system app: {e}")
+                    sys.exit(1)
+
+            asyncio.create_task(_startup_func())
             self.after_start()

         @self.app.on_event("shutdown")
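The change above replaces a fire-and-forget task with a wrapper that terminates the process when async startup fails, so the error is no longer swallowed by the event loop. A self-contained sketch of the same pattern, with async_after_start as a hypothetical stand-in for the app's real hook:

import asyncio
import logging
import sys

logger = logging.getLogger(__name__)


async def async_after_start() -> None:
    """Hypothetical stand-in for the app's async startup hook."""


async def main() -> None:
    async def _startup_func():
        try:
            await async_after_start()
        except Exception as e:
            # Fail fast: otherwise the exception would only reach the event
            # loop's exception handler and the app would keep running.
            logger.error(f"Error starting system app: {e}")
            sys.exit(1)

    await asyncio.create_task(_startup_func())


if __name__ == "__main__":
    asyncio.run(main())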
@@ -128,8 +128,8 @@ class Config(metaclass=Singleton):

         ### default Local database connection configuration
         self.LOCAL_DB_HOST = os.getenv("LOCAL_DB_HOST")
-        self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "")
-        self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "mysql")
+        self.LOCAL_DB_PATH = os.getenv("LOCAL_DB_PATH", "data/default_sqlite.db")
+        self.LOCAL_DB_TYPE = os.getenv("LOCAL_DB_TYPE", "sqlite")
         if self.LOCAL_DB_HOST is None and self.LOCAL_DB_PATH == "":
             self.LOCAL_DB_HOST = "127.0.0.1"

@@ -141,7 +141,7 @@ class Config(metaclass=Singleton):
         self.LOCAL_DB_MANAGE = None

         ### LLM Model Service Configuration
-        self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b")
+        self.LLM_MODEL = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
         ### Proxy llm backend, this configuration is only valid when "LLM_MODEL=proxyllm"
         ### When we use the rest API provided by deployment frameworks like fastchat as a proxyllm, "PROXYLLM_BACKEND" is the model they actually deploy.
         ### We need to use "PROXYLLM_BACKEND" to load the prompt of the corresponding scene.
pilot/configs/model_config.py
@@ -88,6 +88,8 @@ EMBEDDING_MODEL_CONFIG = {
     "bge-large-zh": os.path.join(MODEL_PATH, "bge-large-zh"),
     "bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
+    "proxy_openai": "proxy_openai",
+    "proxy_azure": "proxy_azure",
 }

 # Load model config
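Unlike the local entries, the proxy entries map a model name to itself rather than to a filesystem path; downstream code uses that sentinel value to pick the proxy loader branch. A tiny illustration, with MODEL_PATH assumed for the example:

import os

MODEL_PATH = "./models"  # assumed value for illustration

EMBEDDING_MODEL_CONFIG = {
    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
    "proxy_openai": "proxy_openai",  # identifier, not a path
}

# A local name resolves to a directory; a proxy name resolves to itself.
print(EMBEDDING_MODEL_CONFIG["sentence-transforms"])  # ./models/all-MiniLM-L6-v2
print(EMBEDDING_MODEL_CONFIG["proxy_openai"])         # proxy_openai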
@@ -1,3 +1,4 @@
+from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import Any, Type, TYPE_CHECKING

@@ -96,7 +96,11 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:

 def _dynamic_model_parser() -> Callable[[None], List[Type]]:
     from pilot.utils.parameter_utils import _SimpleArgParser
-    from pilot.model.parameter import EmbeddingModelParameters, WorkerType
+    from pilot.model.parameter import (
+        EmbeddingModelParameters,
+        WorkerType,
+        EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
+    )

     pre_args = _SimpleArgParser("model_name", "model_path", "worker_type")
     pre_args.parse()

@@ -106,7 +110,11 @@ def _dynamic_model_parser() -> Callable[[None], List[Type]]:
     if model_name is None:
         return None
     if worker_type == WorkerType.TEXT2VEC:
-        return [EmbeddingModelParameters]
+        return [
+            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
+                model_name, EmbeddingModelParameters
+            )
+        ]

     llm_adapter = get_llm_model_adapter(model_name, model_path)
     param_class = llm_adapter.model_param_class()
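The .get(model_name, EmbeddingModelParameters) lookup means any embedding model without a dedicated parameter class silently falls back to the generic one. A simplified, self-contained sketch of that resolution (stand-in classes, not the project's real hierarchy):

from dataclasses import dataclass


@dataclass
class EmbeddingModelParameters:
    """Generic local-embedding parameters (simplified stand-in)."""


@dataclass
class ProxyEmbeddingParameters(EmbeddingModelParameters):
    """Proxy (OpenAI/Azure) parameters (simplified stand-in)."""


EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {
    "proxy_openai": ProxyEmbeddingParameters,
    "proxy_azure": ProxyEmbeddingParameters,
}


def resolve_param_class(model_name: str) -> type:
    # Unknown embedding models fall back to the generic parameter class.
    return EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
        model_name, EmbeddingModelParameters
    )


assert resolve_param_class("proxy_openai") is ProxyEmbeddingParameters
assert resolve_param_class("text2vec") is EmbeddingModelParameters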
pilot/model/cluster/embedding/loader.py (new file, +27)
@@ -0,0 +1,27 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from pilot.model.parameter import BaseEmbeddingModelParameters
+
+if TYPE_CHECKING:
+    from langchain.embeddings.base import Embeddings
+
+
+class EmbeddingLoader:
+    def __init__(self) -> None:
+        pass
+
+    def load(
+        self, model_name: str, param: BaseEmbeddingModelParameters
+    ) -> "Embeddings":
+        # add more models
+        if model_name in ["proxy_openai", "proxy_azure"]:
+            from langchain.embeddings import OpenAIEmbeddings
+
+            return OpenAIEmbeddings(**param.build_kwargs())
+        else:
+            from langchain.embeddings import HuggingFaceEmbeddings
+
+            kwargs = param.build_kwargs(model_name=param.model_path)
+            return HuggingFaceEmbeddings(**kwargs)
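Assuming langchain (plus openai for the proxy branch) is installed, usage might look like the sketch below; the direct dataclass construction is illustrative and omits any required base-class fields:

from pilot.model.cluster.embedding.loader import EmbeddingLoader
from pilot.model.parameter import ProxyEmbeddingParameters

params = ProxyEmbeddingParameters(
    proxy_server_url="https://api.openai.com/v1",
    proxy_api_key="sk-...",  # illustrative placeholder
)
loader = EmbeddingLoader()
embeddings = loader.load("proxy_openai", params)  # langchain OpenAIEmbeddings
vector = embeddings.embed_query("hello world")    # issues a remote API call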
pilot/model/cluster/worker/embedding_worker.py
@@ -5,13 +5,16 @@ from pilot.configs.model_config import get_device
 from pilot.model.loader import _get_model_real_path
 from pilot.model.parameter import (
     EmbeddingModelParameters,
+    BaseEmbeddingModelParameters,
     WorkerType,
+    EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
 )
 from pilot.model.cluster.worker_base import ModelWorker
+from pilot.model.cluster.embedding.loader import EmbeddingLoader
 from pilot.utils.model_utils import _clear_torch_cache
 from pilot.utils.parameter_utils import EnvArgumentParser

-logger = logging.getLogger("model_worker")
+logger = logging.getLogger(__name__)


 class EmbeddingsModelWorker(ModelWorker):

@@ -26,6 +29,9 @@ class EmbeddingsModelWorker(ModelWorker):
         ) from exc
         self._embeddings_impl: Embeddings = None
         self._model_params = None
+        self.model_name = None
+        self.model_path = None
+        self._loader = EmbeddingLoader()

     def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
         if model_path.endswith("/"):

@@ -39,11 +45,13 @@ class EmbeddingsModelWorker(ModelWorker):
         return WorkerType.TEXT2VEC

     def model_param_class(self) -> Type:
-        return EmbeddingModelParameters
+        return EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
+            self.model_name, EmbeddingModelParameters
+        )

     def parse_parameters(
         self, command_args: List[str] = None
-    ) -> EmbeddingModelParameters:
+    ) -> BaseEmbeddingModelParameters:
         param_cls = self.model_param_class()
         return _parse_embedding_params(
             model_name=self.model_name,

@@ -58,15 +66,10 @@ class EmbeddingsModelWorker(ModelWorker):
         command_args: List[str] = None,
     ) -> None:
         """Start model worker"""
-        from langchain.embeddings import HuggingFaceEmbeddings
-
         if not model_params:
             model_params = self.parse_parameters(command_args)
         self._model_params = model_params
-        kwargs = model_params.build_kwargs(model_name=model_params.model_path)
-        logger.info(f"Start HuggingFaceEmbeddings with kwargs: {kwargs}")
-        self._embeddings_impl = HuggingFaceEmbeddings(**kwargs)
+        self._embeddings_impl = self._loader.load(self.model_name, model_params)

     def __del__(self):
         self.stop()

@@ -101,7 +104,7 @@ def _parse_embedding_params(
 ):
     model_args = EnvArgumentParser()
     env_prefix = EnvArgumentParser.get_env_prefix(model_name)
-    model_params: EmbeddingModelParameters = model_args.parse_args_into_dataclass(
+    model_params: BaseEmbeddingModelParameters = model_args.parse_args_into_dataclass(
         param_cls,
         env_prefix=env_prefix,
         command_args=command_args,
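The .env keys in the template tie into this parsing: EnvArgumentParser.get_env_prefix(model_name) derives a per-model prefix, so proxy_openai_proxy_api_key populates the proxy_api_key field when model_name is proxy_openai. A simplified sketch of that prefixed lookup (not the project's actual EnvArgumentParser):

import os
from dataclasses import dataclass, fields


@dataclass
class ProxyEmbeddingParams:
    proxy_server_url: str = ""
    proxy_api_key: str = ""
    proxy_backend: str = "text-embedding-ada-002"


def parse_from_env(param_cls, model_name: str):
    # e.g. model_name "proxy_openai" -> env prefix "proxy_openai_"
    prefix = model_name + "_"
    values = {}
    for f in fields(param_cls):
        env_value = os.getenv(prefix + f.name)
        if env_value is not None:
            values[f.name] = env_value
    return param_cls(**values)


os.environ["proxy_openai_proxy_api_key"] = "sk-..."  # illustrative value
params = parse_from_env(ProxyEmbeddingParams, "proxy_openai")
assert params.proxy_api_key == "sk-..."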
@@ -2,6 +2,7 @@ import asyncio
 import itertools
 import json
 import os
+import sys
 import random
 import time
 from concurrent.futures import ThreadPoolExecutor

@@ -129,8 +130,6 @@ class LocalWorkerManager(WorkerManager):
         command_args: List[str] = None,
     ) -> bool:
         if not command_args:
-            import sys
-
             command_args = sys.argv[1:]
         worker.load_worker(**asdict(worker_params))

@@ -635,8 +634,15 @@ def _setup_fastapi(worker_params: ModelWorkerParameters, app=None):

     @app.on_event("startup")
     async def startup_event():
-        # TODO catch exception and shutdown if worker manager start failed
-        asyncio.create_task(worker_manager.start())
+        async def start_worker_manager():
+            try:
+                await worker_manager.start()
+            except Exception as e:
+                logger.error(f"Error starting worker manager: {e}")
+                sys.exit(1)
+
+        # It cannot be blocked here because the startup of worker_manager depends on the fastapi app (registered to the controller)
+        asyncio.create_task(start_worker_manager())

     @app.on_event("shutdown")
     async def startup_event():
@@ -6,6 +6,9 @@ from pilot.model.parameter import ModelParameters
 from pilot.model.cluster.worker_base import ModelWorker


+logger = logging.getLogger(__name__)
+
+
 class RemoteModelWorker(ModelWorker):
     def __init__(self) -> None:
         self.headers = {}

@@ -46,13 +49,14 @@ class RemoteModelWorker(ModelWorker):
         """Asynchronous generate stream"""
         import httpx

-        logging.debug(f"Send async_generate_stream, params: {params}")
         async with httpx.AsyncClient() as client:
             delimiter = b"\0"
             buffer = b""
+            url = self.worker_addr + "/generate_stream"
+            logger.debug(f"Send async_generate_stream to url {url}, params: {params}")
             async with client.stream(
                 "POST",
-                self.worker_addr + "/generate_stream",
+                url,
                 headers=self.headers,
                 json=params,
                 timeout=self.timeout,

@@ -75,10 +79,11 @@ class RemoteModelWorker(ModelWorker):
         """Asynchronous generate non stream"""
         import httpx

-        logging.debug(f"Send async_generate_stream, params: {params}")
         async with httpx.AsyncClient() as client:
+            url = self.worker_addr + "/generate"
+            logger.debug(f"Send async_generate to url {url}, params: {params}")
             response = await client.post(
-                self.worker_addr + "/generate",
+                url,
                 headers=self.headers,
                 json=params,
                 timeout=self.timeout,

@@ -89,8 +94,10 @@ class RemoteModelWorker(ModelWorker):
         """Get embeddings for input"""
         import requests

+        url = self.worker_addr + "/embeddings"
+        logger.debug(f"Send embeddings to url {url}, params: {params}")
         response = requests.post(
-            self.worker_addr + "/embeddings",
+            url,
             headers=self.headers,
             json=params,
             timeout=self.timeout,

@@ -102,8 +109,10 @@ class RemoteModelWorker(ModelWorker):
         import httpx

         async with httpx.AsyncClient() as client:
+            url = self.worker_addr + "/embeddings"
+            logger.debug(f"Send async_embeddings to url {url}")
             response = await client.post(
-                self.worker_addr + "/embeddings",
+                url,
                 headers=self.headers,
                 json=params,
                 timeout=self.timeout,
pilot/model/parameter.py
@@ -86,7 +86,13 @@ class ModelWorkerParameters(BaseModelParameters):


 @dataclass
-class EmbeddingModelParameters(BaseModelParameters):
+class BaseEmbeddingModelParameters(BaseModelParameters):
+    def build_kwargs(self, **kwargs) -> Dict:
+        pass
+
+
+@dataclass
+class EmbeddingModelParameters(BaseEmbeddingModelParameters):
     device: Optional[str] = field(
         default=None,
         metadata={

@@ -268,3 +274,81 @@ class ProxyModelParameters(BaseModelParameters):
     max_context_size: Optional[int] = field(
         default=4096, metadata={"help": "Maximum context size"}
     )
+
+
+@dataclass
+class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
+    proxy_server_url: str = field(
+        metadata={
+            "help": "Proxy base url(OPENAI_API_BASE), such as https://api.openai.com/v1"
+        },
+    )
+    proxy_api_key: str = field(
+        metadata={
+            "tags": "privacy",
+            "help": "The api key of the current embedding model(OPENAI_API_KEY)",
+        },
+    )
+    device: Optional[str] = field(
+        default=None,
+        metadata={"help": "Device to run model. Not working for proxy embedding model"},
+    )
+    proxy_api_type: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The api type of current proxy the current embedding model(OPENAI_API_TYPE), if you use Azure, it can be: azure"
+        },
+    )
+    proxy_api_version: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The api version of current proxy the current embedding model(OPENAI_API_VERSION)"
+        },
+    )
+    proxy_backend: Optional[str] = field(
+        default="text-embedding-ada-002",
+        metadata={
+            "help": "The model name actually pass to current proxy server url, such as text-embedding-ada-002"
+        },
+    )
+
+    proxy_deployment: Optional[str] = field(
+        default="text-embedding-ada-002",
+        metadata={"help": "To support Azure OpenAI Service custom deployment names"},
+    )
+
+    def build_kwargs(self, **kwargs) -> Dict:
+        params = {
+            "openai_api_base": self.proxy_server_url,
+            "openai_api_key": self.proxy_api_key,
+            "openai_api_type": self.proxy_api_type if self.proxy_api_type else None,
+            "openai_api_version": self.proxy_api_version
+            if self.proxy_api_version
+            else None,
+            "model": self.proxy_backend,
+            "deployment": self.proxy_deployment
+            if self.proxy_deployment
+            else self.proxy_backend,
+        }
+        for k, v in kwargs.items():
+            params[k] = v
+        return params
+
+
+_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
+    ProxyEmbeddingParameters: "proxy_openai,proxy_azure"
+}
+
+EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
+
+
+def _update_embedding_config():
+    global EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG
+    for param_cls, models in _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG.items():
+        models = [m.strip() for m in models.split(",")]
+        for model in models:
+            if model not in EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG:
+                EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG[model] = param_cls
+
+
+_update_embedding_config()
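For a plain OpenAI setup, build_kwargs() assembles exactly the keyword arguments that langchain's OpenAIEmbeddings constructor accepts. An illustrative call, assuming the inherited base-class fields permit direct construction with just these two values:

params = ProxyEmbeddingParameters(
    proxy_server_url="https://api.openai.com/v1",
    proxy_api_key="sk-...",  # illustrative placeholder
)
print(params.build_kwargs())
# {'openai_api_base': 'https://api.openai.com/v1',
#  'openai_api_key': 'sk-...',
#  'openai_api_type': None,
#  'openai_api_version': None,
#  'model': 'text-embedding-ada-002',
#  'deployment': 'text-embedding-ada-002'}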
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import Any, Type, TYPE_CHECKING

 from pilot.componet import SystemApp

@@ -8,7 +10,6 @@ from pilot.embedding_engine.embedding_factory import (
     DefaultEmbeddingFactory,
 )
 from pilot.server.base import WebWerverParameters
-from pilot.utils.parameter_utils import EnvArgumentParser

 if TYPE_CHECKING:
     from langchain.embeddings.base import Embeddings

@@ -46,18 +47,11 @@ def _initialize_embedding_model(
             RemoteEmbeddingFactory, worker_manager, model_name=embedding_model_name
         )
     else:
-        from pilot.model.parameter import EmbeddingModelParameters
-        from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
-
-        model_params: EmbeddingModelParameters = _parse_embedding_params(
-            model_name=embedding_model_name,
-            model_path=embedding_model_path,
-            param_cls=EmbeddingModelParameters,
-        )
-        kwargs = model_params.build_kwargs(model_name=embedding_model_path)
-        logger.info(f"Register local DefaultEmbeddingFactory with kwargs: {kwargs}")
+        logger.info(f"Register local LocalEmbeddingFactory")
         system_app.register(
-            DefaultEmbeddingFactory, default_model_name=embedding_model_path, **kwargs
+            LocalEmbeddingFactory,
+            default_model_name=embedding_model_name,
+            default_model_path=embedding_model_path,
         )

@@ -82,3 +76,51 @@ class RemoteEmbeddingFactory(EmbeddingFactory):
             raise NotImplementedError
         # Ignore model_name args
         return RemoteEmbeddings(self._default_model_name, self._worker_manager)
+
+
+class LocalEmbeddingFactory(EmbeddingFactory):
+    def __init__(
+        self,
+        system_app,
+        default_model_name: str = None,
+        default_model_path: str = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(system_app=system_app)
+        self._default_model_name = default_model_name
+        self._default_model_path = default_model_path
+        self._kwargs = kwargs
+        self._model = self._load_model()
+
+    def init_app(self, system_app):
+        pass
+
+    def create(
+        self, model_name: str = None, embedding_cls: Type = None
+    ) -> "Embeddings":
+        if embedding_cls:
+            raise NotImplementedError
+        return self._model
+
+    def _load_model(self) -> "Embeddings":
+        from pilot.model.parameter import (
+            EmbeddingModelParameters,
+            BaseEmbeddingModelParameters,
+            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
+        )
+        from pilot.model.cluster.worker.embedding_worker import _parse_embedding_params
+        from pilot.model.cluster.embedding.loader import EmbeddingLoader
+
+        param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
+            self._default_model_name, EmbeddingModelParameters
+        )
+        model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
+            model_name=self._default_model_name,
+            model_path=self._default_model_path,
+            param_cls=param_cls,
+            **self._kwargs,
+        )
+        logger.info(model_params)
+        loader = EmbeddingLoader()
+        # Ignore model_name args
+        return loader.load(self._default_model_name, model_params)
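Taken together with the docker-compose change, the selection logic is: --remote_embedding registers RemoteEmbeddingFactory (embeddings served by a model worker over HTTP), otherwise LocalEmbeddingFactory loads the model in-process at registration time. A condensed, stubbed sketch of that branch (stand-in classes, not the project's real ones):

class EmbeddingFactory:  # minimal stand-ins for the real classes
    def __init__(self, **kwargs):
        self.kwargs = kwargs


class RemoteEmbeddingFactory(EmbeddingFactory): ...
class LocalEmbeddingFactory(EmbeddingFactory): ...


def choose_factory(remote_embedding: bool, model_name: str, model_path: str):
    # Mirrors _initialize_embedding_model: remote worker vs in-process load.
    if remote_embedding:
        return RemoteEmbeddingFactory(model_name=model_name)
    return LocalEmbeddingFactory(
        default_model_name=model_name, default_model_path=model_path
    )


factory = choose_factory(False, "text2vec", "./models/text2vec-large-chinese")
print(type(factory).__name__)  # LocalEmbeddingFactory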
@@ -37,6 +37,7 @@ opencv-python==4.7.0.72
 iopath==0.1.10
 tenacity==8.2.2
 peft
+# TODO remove pycocoevalcap
 pycocoevalcap
 cpm_kernels
 umap-learn
setup.py (+9)
@@ -319,6 +319,14 @@ def all_datasource_requires():
     setup_spec.extras["datasource"] = ["pymssql", "pymysql"]


+def openai_requires():
+    """
+    pip install "db-gpt[openai]"
+    """
+    setup_spec.extras["openai"] = ["openai", "tiktoken"]
+    llama_cpp_python_cuda_requires()
+
+
 def all_requires():
     requires = set()
     for _, pkgs in setup_spec.extras.items():

@@ -339,6 +347,7 @@ llama_cpp_requires()
 quantization_requires()
 all_vector_store_requires()
 all_datasource_requires()
+openai_requires()

 # must be last
 all_requires()
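With this extra registered, the OpenAI embedding dependencies install via pip install "db-gpt[openai]", as the docstring notes. openai_requires() also calls llama_cpp_python_cuda_requires(), so CUDA llama-cpp wheels may be pulled in as well, depending on the environment checks that function performs.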