feat(model): Support deploy rerank model (#1522)

Fangyin Cheng 2024-05-16 14:50:16 +08:00 committed by GitHub
parent 559affe87d
commit 593e974405
29 changed files with 814 additions and 75 deletions

View File

@@ -97,6 +97,23 @@ KNOWLEDGE_SEARCH_REWRITE=False
# proxy_http_openapi_proxy_api_key=1dce29a6d66b4e2dbfec67044edbb924
# proxy_http_openapi_proxy_backend=text2vec
#*******************************************************************#
#** RERANK SETTINGS **#
#*******************************************************************#
## Rerank model
# RERANK_MODEL=bge-reranker-base
## If you do not set RERANK_MODEL_PATH, DB-GPT will read the model path from EMBEDDING_MODEL_CONFIG based on RERANK_MODEL.
# RERANK_MODEL_PATH=
## The number of rerank results to return
# RERANK_TOP_K=3
## Common HTTP rerank model
# RERANK_MODEL=rerank_proxy_http_openapi
# rerank_proxy_http_openapi_proxy_server_url=http://127.0.0.1:8100/api/v1/beta/relevance
# rerank_proxy_http_openapi_proxy_api_key={your-api-key}
# rerank_proxy_http_openapi_proxy_backend=bge-reranker-base
#*******************************************************************#
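For a local reranker, the minimal active configuration is just the model name and top-k; the model path then falls back to the `EMBEDDING_MODEL_CONFIG` entry for `RERANK_MODEL`:

```shell
RERANK_MODEL=bge-reranker-base
RERANK_TOP_K=3
```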

View File

@@ -247,6 +247,10 @@ class Config(metaclass=Singleton):
### EMBEDDING Configuration
self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
# Rerank model configuration
self.RERANK_MODEL = os.getenv("RERANK_MODEL")
self.RERANK_MODEL_PATH = os.getenv("RERANK_MODEL_PATH")
self.RERANK_TOP_K = int(os.getenv("RERANK_TOP_K", 3))
self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
self.KNOWLEDGE_CHUNK_OVERLAP = int(os.getenv("KNOWLEDGE_CHUNK_OVERLAP", 50))
self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5))

View File

@@ -0,0 +1,45 @@
"""Host tool resource module."""
from ...resource.tool.base import tool
@tool(description="Get current host CPU status.")
def get_current_host_cpu_status() -> str:
"""Get current host CPU status."""
import platform
import psutil
cpu_architecture = platform.machine()
cpu_count_physical = psutil.cpu_count(logical=False)
cpu_count_logical = psutil.cpu_count(logical=True)
cpu_usage = psutil.cpu_percent(interval=1)
return (
f"CPU Architecture: {cpu_architecture}\n"
f"Physical CPU Cores: {cpu_count_physical}\n"
f"Logical CPU Cores: {cpu_count_logical}\n"
f"CPU Usage: {cpu_usage}%"
)
@tool(description="Get current host memory status.")
def get_current_host_memory_status() -> str:
"""Get current host memory status."""
import psutil
memory = psutil.virtual_memory()
return (
f"Total: {memory.total / (1024**3):.2f} GB\n"
f"Available: {memory.available / (1024**3):.2f} GB\n"
f"Used: {memory.used / (1024**3):.2f} GB\n"
f"Percent: {memory.percent}%"
)
@tool(description="Get current host system load.")
def get_current_host_system_load() -> str:
"""Get current host system load."""
import os
load1, load5, load15 = os.getloadavg()
return f"System load average: {load1:.2f}, {load5:.2f}, {load15:.2f}"

View File

@@ -227,6 +227,14 @@ class WebServerParameters(BaseParameters):
"text2vec --model_name xxx --model_path xxx`"
},
)
remote_rerank: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to enable remote rerank models. If it is True, you need"
" to start a rerank model through `dbgpt start worker --worker_type "
"text2vec --rerank --model_name xxx --model_path xxx`"
},
)
log_level: Optional[str] = field(
default=None,
metadata={

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import logging
from typing import Optional
from dbgpt._private.config import Config
from dbgpt.app.base import WebServerParameters
@@ -18,9 +19,14 @@ def initialize_components(
system_app: SystemApp,
embedding_model_name: str,
embedding_model_path: str,
rerank_model_name: Optional[str] = None,
rerank_model_path: Optional[str] = None,
):
# Lazy import to avoid high time cost
from dbgpt.app.initialization.embedding_component import _initialize_embedding_model
from dbgpt.app.initialization.embedding_component import (
_initialize_embedding_model,
_initialize_rerank_model,
)
from dbgpt.app.initialization.scheduler import DefaultScheduler
from dbgpt.app.initialization.serve_initialization import register_serve_apps
from dbgpt.datasource.manages.connector_manager import ConnectorManager
@@ -45,6 +51,7 @@ def initialize_components(
_initialize_embedding_model(
param, system_app, embedding_model_name, embedding_model_path
)
_initialize_rerank_model(param, system_app, rerank_model_name, rerank_model_path)
_initialize_model_cache(system_app)
_initialize_awel(system_app, param)
# Initialize resource manager of agent
@@ -89,6 +96,11 @@ def _initialize_agent(system_app: SystemApp):
def _initialize_resource_manager(system_app: SystemApp):
from dbgpt.agent.expand.resources.dbgpt_tool import list_dbgpt_support_models
from dbgpt.agent.expand.resources.host_tool import (
get_current_host_cpu_status,
get_current_host_memory_status,
get_current_host_system_load,
)
from dbgpt.agent.expand.resources.search_tool import baidu_search
from dbgpt.agent.resource.base import ResourceType
from dbgpt.agent.resource.manage import get_resource_manager, initialize_resource
@@ -104,6 +116,10 @@ def _initialize_resource_manager(system_app: SystemApp):
# Register a search tool
rm.register_resource(resource_instance=baidu_search)
rm.register_resource(resource_instance=list_dbgpt_support_models)
# Register host tools
rm.register_resource(resource_instance=get_current_host_cpu_status)
rm.register_resource(resource_instance=get_current_host_memory_status)
rm.register_resource(resource_instance=get_current_host_system_load)
def _initialize_openapi(system_app: SystemApp):

View File

@@ -146,11 +146,24 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
embedding_model_name = CFG.EMBEDDING_MODEL
embedding_model_path = EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
rerank_model_name = CFG.RERANK_MODEL
rerank_model_path = None
if rerank_model_name:
rerank_model_path = CFG.RERANK_MODEL_PATH or EMBEDDING_MODEL_CONFIG.get(
rerank_model_name
)
server_init(param, system_app)
mount_routers(app)
model_start_listener = _create_model_start_listener(system_app)
initialize_components(param, system_app, embedding_model_name, embedding_model_path)
initialize_components(
param,
system_app,
embedding_model_name,
embedding_model_path,
rerank_model_name,
rerank_model_path,
)
system_app.on_init()
# Migration db storage, so you db models must be imported before this
@@ -161,7 +174,13 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
if not param.light:
print("Model Unified Deployment Mode!")
if not param.remote_embedding:
# Embedding model is running in the same process, set embedding_model_name
# and embedding_model_path to None
embedding_model_name, embedding_model_path = None, None
if not param.remote_rerank:
# Rerank model is running in the same process, set rerank_model_name and
# rerank_model_path to None
rerank_model_name, rerank_model_path = None, None
initialize_worker_manager_in_client(
app=app,
model_name=model_name,
@@ -169,6 +188,8 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
local_port=param.port,
embedding_model_name=embedding_model_name,
embedding_model_path=embedding_model_path,
rerank_model_name=rerank_model_name,
rerank_model_path=rerank_model_path,
start_listener=model_start_listener,
system_app=system_app,
)

View File

@@ -1,10 +1,14 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Type
from typing import TYPE_CHECKING, Any, Optional, Type
from dbgpt.component import ComponentType, SystemApp
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.core import Embeddings, RerankEmbeddings
from dbgpt.rag.embedding.embedding_factory import (
EmbeddingFactory,
RerankEmbeddingFactory,
)
if TYPE_CHECKING:
from langchain.embeddings.base import Embeddings
@@ -32,6 +36,26 @@ def _initialize_embedding_model(
)
def _initialize_rerank_model(
param: "WebServerParameters",
system_app: SystemApp,
rerank_model_name: Optional[str] = None,
rerank_model_path: Optional[str] = None,
):
if not rerank_model_name:
return
if param.remote_rerank:
logger.info("Register remote RemoteRerankEmbeddingFactory")
system_app.register(RemoteRerankEmbeddingFactory, model_name=rerank_model_name)
else:
logger.info(f"Register local LocalRerankEmbeddingFactory")
system_app.register(
LocalRerankEmbeddingFactory,
default_model_name=rerank_model_name,
default_model_path=rerank_model_path,
)
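Whichever factory gets registered, callers resolve it through the component system; this is the same pattern `ChatKnowledge` uses later in this diff:

```python
from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory

# system_app: the running SystemApp with a registered rerank factory
rerank_embeddings = RerankEmbeddingFactory.get_instance(system_app).create()
scores = rerank_embeddings.predict("a query", ["candidate one", "candidate two"])
```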
class RemoteEmbeddingFactory(EmbeddingFactory):
def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
super().__init__(system_app=system_app)
@@ -105,3 +129,81 @@ class LocalEmbeddingFactory(EmbeddingFactory):
loader = EmbeddingLoader()
# Ignore model_name args
return loader.load(self._default_model_name, model_params)
class RemoteRerankEmbeddingFactory(RerankEmbeddingFactory):
def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
super().__init__(system_app=system_app)
self._default_model_name = model_name
self.kwargs = kwargs
self.system_app = system_app
def init_app(self, system_app):
self.system_app = system_app
def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "RerankEmbeddings":
from dbgpt.model.cluster import WorkerManagerFactory
from dbgpt.model.cluster.embedding.remote_embedding import (
RemoteRerankEmbeddings,
)
if embedding_cls:
raise NotImplementedError
worker_manager = self.system_app.get_component(
ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
).create()
return RemoteRerankEmbeddings(
model_name or self._default_model_name, worker_manager
)
class LocalRerankEmbeddingFactory(RerankEmbeddingFactory):
def __init__(
self,
system_app,
default_model_name: str = None,
default_model_path: str = None,
**kwargs: Any,
) -> None:
super().__init__(system_app=system_app)
self._default_model_name = default_model_name
self._default_model_path = default_model_path
self._kwargs = kwargs
self._model = self._load_model()
def init_app(self, system_app):
pass
def create(
self, model_name: str = None, embedding_cls: Type = None
) -> "RerankEmbeddings":
if embedding_cls:
raise NotImplementedError
return self._model
def _load_model(self) -> "RerankEmbeddings":
from dbgpt.model.adapter.embeddings_loader import (
EmbeddingLoader,
_parse_embedding_params,
)
from dbgpt.model.parameter import (
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
BaseEmbeddingModelParameters,
EmbeddingModelParameters,
)
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
self._default_model_name, EmbeddingModelParameters
)
model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
model_name=self._default_model_name,
model_path=self._default_model_path,
param_cls=param_cls,
**self._kwargs,
)
logger.info(model_params)
loader = EmbeddingLoader()
# Ignore model_name args
return loader.load_rerank_model(self._default_model_name, model_params)

View File

@@ -18,6 +18,7 @@ from dbgpt.core import (
MessagesPlaceholder,
SystemPromptTemplate,
)
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
from dbgpt.rag.retriever.rewrite import QueryRewrite
from dbgpt.util.tracer import root_tracer, trace
@@ -37,7 +38,10 @@ class ChatKnowledge(BaseChat):
- model_name:(str) llm model name
- select_param:(str) space name
"""
from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
from dbgpt.rag.embedding.embedding_factory import (
EmbeddingFactory,
RerankEmbeddingFactory,
)
self.knowledge_space = chat_param["select_param"]
chat_param["chat_mode"] = ChatScene.ChatKnowledge
@@ -83,10 +87,22 @@ model_name=self.llm_model,
model_name=self.llm_model,
language=CFG.LANGUAGE,
)
reranker = None
retriever_top_k = self.top_k
if CFG.RERANK_MODEL:
rerank_embeddings = RerankEmbeddingFactory.get_instance(
CFG.SYSTEM_APP
).create()
reranker = RerankEmbeddingsRanker(rerank_embeddings, topk=CFG.RERANK_TOP_K)
if retriever_top_k < CFG.RERANK_TOP_K or retriever_top_k < 20:
# When a reranker is used, retrieve at least 20 candidates
# (or RERANK_TOP_K, if larger) so the reranker has enough to score
retriever_top_k = max(CFG.RERANK_TOP_K, 20)
self.embedding_retriever = EmbeddingRetriever(
top_k=self.top_k,
top_k=retriever_top_k,
vector_store_connector=vector_store_connector,
query_rewrite=query_rewrite,
rerank=reranker,
)
self.prompt_template.template_is_strict = False
self.relations = None

View File

@@ -209,6 +209,11 @@ EMBEDDING_MODEL_CONFIG = {
# Common HTTP embedding model
"proxy_http_openapi": "proxy_http_openapi",
"proxy_ollama": "proxy_ollama",
# Rerank models; a rerank model is a special kind of embedding model
"bge-reranker-base": os.path.join(MODEL_PATH, "bge-reranker-base"),
"bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
# Proxy rerank model
"rerank_proxy_http_openapi": "rerank_proxy_http_openapi",
}

View File

@@ -7,7 +7,7 @@ from dbgpt.core.interface.cache import ( # noqa: F401
CachePolicy,
CacheValue,
)
from dbgpt.core.interface.embeddings import Embeddings # noqa: F401
from dbgpt.core.interface.embeddings import Embeddings, RerankEmbeddings # noqa: F401
from dbgpt.core.interface.knowledge import Chunk, Document # noqa: F401
from dbgpt.core.interface.llm import ( # noqa: F401
DefaultMessageConverter,
@@ -106,6 +106,7 @@ __ALL__ = [
"QuerySpec",
"StorageError",
"Embeddings",
"RerankEmbeddings",
"Chunk",
"Document",
]

View File

@@ -1,9 +1,24 @@
"""Interface for embedding models."""
import asyncio
from abc import ABC, abstractmethod
from typing import List
class RerankEmbeddings(ABC):
"""Interface for rerank models."""
@abstractmethod
def predict(self, query: str, candidates: List[str]) -> List[float]:
"""Predict the scores of the candidates."""
async def apredict(self, query: str, candidates: List[str]) -> List[float]:
"""Asynchronously predict the scores of the candidates."""
return await asyncio.get_running_loop().run_in_executor(
None, self.predict, query, candidates
)
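Since `apredict` already falls back to running `predict` in a thread pool, a custom implementation only has to supply `predict`. A toy sketch (hypothetical class, for illustration only):

```python
from typing import List

from dbgpt.core import RerankEmbeddings


class KeywordOverlapRerank(RerankEmbeddings):
    """Toy reranker: score is the fraction of query words found in a candidate."""

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        words = set(query.lower().split())
        return [
            len(words & set(candidate.lower().split())) / max(len(words), 1)
            for candidate in candidates
        ]
```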
class Embeddings(ABC):
"""Interface for embedding models.

View File

@@ -166,6 +166,23 @@ class EmbeddingsResponse(BaseModel):
usage: UsageInfo = Field(..., description="Usage info")
class RelevanceRequest(BaseModel):
"""Relevance request entity."""
model: str = Field(..., description="Rerank model name")
query: str = Field(..., description="Query text")
documents: List[str] = Field(..., description="Document texts")
class RelevanceResponse(BaseModel):
"""Relevance response entity."""
object: str = Field("list", description="Object type")
model: str = Field(..., description="Rerank model name")
data: List[float] = Field(..., description="Data list, relevance scores")
usage: UsageInfo = Field(..., description="Usage info")
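The wire format these entities imply, with illustrative values (usage fields elided):

```json
{"model": "bge-reranker-base", "query": "what is awel talk about?", "documents": ["doc one", "doc two"]}
```

and the corresponding response:

```json
{"object": "list", "model": "bge-reranker-base", "data": [0.97, 0.02], "usage": {}}
```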
class ModelPermission(BaseModel):
"""Model permission entity."""

View File

@@ -4,7 +4,7 @@ import logging
from typing import List, Optional, Type, cast
from dbgpt.configs.model_config import get_device
from dbgpt.core import Embeddings
from dbgpt.core import Embeddings, RerankEmbeddings
from dbgpt.model.parameter import (
BaseEmbeddingModelParameters,
EmbeddingModelParameters,
@@ -66,6 +66,38 @@ class EmbeddingLoader:
kwargs = param.build_kwargs(model_name=param.model_path)
return HuggingFaceEmbeddings(**kwargs)
def load_rerank_model(
self, model_name: str, param: BaseEmbeddingModelParameters
) -> RerankEmbeddings:
metadata = {
"model_name": model_name,
"run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
"params": _get_dict_from_obj(param),
"sys_infos": _get_dict_from_obj(get_system_info()),
}
with root_tracer.start_span(
"EmbeddingLoader.load_rerank_model",
span_type=SpanType.RUN,
metadata=metadata,
):
if model_name in ["rerank_proxy_http_openapi"]:
from dbgpt.rag.embedding.rerank import OpenAPIRerankEmbeddings
proxy_param = cast(ProxyEmbeddingParameters, param)
openapi_param = {}
if proxy_param.proxy_server_url:
openapi_param["api_url"] = proxy_param.proxy_server_url
if proxy_param.proxy_api_key:
openapi_param["api_key"] = proxy_param.proxy_api_key
if proxy_param.proxy_backend:
openapi_param["model_name"] = proxy_param.proxy_backend
return OpenAPIRerankEmbeddings(**openapi_param)
else:
from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings
kwargs = param.build_kwargs(model_name=param.model_path)
return CrossEncoderRerankEmbeddings(**kwargs)
def _parse_embedding_params(
model_name: Optional[str] = None,

View File

@@ -147,15 +147,14 @@ def _dynamic_model_parser() -> Optional[List[Type[BaseModelParameters]]]:
model_path = pre_args.get("model_path")
worker_type = pre_args.get("worker_type")
model_type = pre_args.get("model_type")
if model_name is None and model_type != ModelType.VLLM:
return None
if worker_type == WorkerType.TEXT2VEC:
return [
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
model_name, EmbeddingModelParameters
)
]
if model_name is None and model_type != ModelType.VLLM:
return None
llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type)
param_class = llm_adapter.model_param_class()
return [param_class]

View File

@@ -3,6 +3,7 @@
Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
"""
import asyncio
import json
import logging
@@ -34,6 +35,8 @@ from dbgpt.core.schema.api import (
ModelCard,
ModelList,
ModelPermission,
RelevanceRequest,
RelevanceResponse,
UsageInfo,
)
from dbgpt.model.base import ModelInstance
@@ -368,6 +371,28 @@ class APIServer(BaseComponent):
}
return await worker_manager.embeddings(params)
async def relevance_generate(
self, model: str, query: str, texts: List[str]
) -> List[float]:
"""Generate embeddings
Args:
model (str): Model name
query (str): Query text
texts (List[str]): Texts to embed
Returns:
List[List[float]]: The embeddings of texts
"""
worker_manager: WorkerManager = self.get_worker_manager()
params = {
"input": texts,
"model": model,
"query": query,
}
scores = await worker_manager.embeddings(params)
return scores[0]
def get_api_server() -> APIServer:
api_server = global_system_app.get_component(
@@ -456,6 +481,26 @@ async def create_embeddings(
)
@router.post(
"/v1/beta/relevance",
dependencies=[Depends(check_api_key)],
response_model=RelevanceResponse,
)
async def create_relevance(
request: RelevanceRequest, api_server: APIServer = Depends(get_api_server)
):
"""Generate relevance scores for a query and a list of documents."""
await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
scores = await api_server.relevance_generate(
request.model, request.query, request.documents
)
return model_to_dict(
RelevanceResponse(data=scores, model=request.model, usage=UsageInfo()),
exclude_none=True,
)
def _initialize_all(controller_addr: str, system_app: SystemApp):
from dbgpt.model.cluster.controller.controller import ModelRegistryClient
from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from dbgpt._private.pydantic import BaseModel
from dbgpt.core.interface.message import ModelMessage
@@ -31,7 +31,9 @@ class PromptRequest(BaseModel):
class EmbeddingsRequest(BaseModel):
model: str
input: List[str]
span_id: str = None
span_id: Optional[str] = None
query: Optional[str] = None
"""For rerank model, query is required"""
class CountTokenRequest(BaseModel):

View File

@@ -1,6 +1,6 @@
from typing import List
from dbgpt.core import Embeddings
from dbgpt.core import Embeddings, RerankEmbeddings
from dbgpt.model.cluster.manager_base import WorkerManager
@@ -26,3 +26,30 @@ class RemoteEmbeddings(Embeddings):
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text."""
return (await self.aembed_documents([text]))[0]
class RemoteRerankEmbeddings(RerankEmbeddings):
def __init__(self, model_name: str, worker_manager: WorkerManager) -> None:
self.model_name = model_name
self.worker_manager = worker_manager
def predict(self, query: str, candidates: List[str]) -> List[float]:
"""Predict the scores of the candidates."""
params = {
"model": self.model_name,
"input": candidates,
"query": query,
}
return self.worker_manager.sync_embeddings(params)[0]
async def apredict(self, query: str, candidates: List[str]) -> List[float]:
"""Asynchronously predict the scores of the candidates."""
params = {
"model": self.model_name,
"input": candidates,
"query": query,
}
# Reuse the embeddings interface to fetch the rerank scores
scores = await self.worker_manager.embeddings(params)
# The worker wraps the scores as a single-element batch
return scores[0]

View File

@@ -1,7 +1,7 @@
import logging
from typing import Dict, List, Type
from typing import Dict, List, Type, Union
from dbgpt.core import ModelMetadata
from dbgpt.core import Embeddings, ModelMetadata, RerankEmbeddings
from dbgpt.model.adapter.embeddings_loader import (
EmbeddingLoader,
_parse_embedding_params,
@@ -20,13 +20,12 @@ logger = logging.getLogger(__name__)
class EmbeddingsModelWorker(ModelWorker):
def __init__(self) -> None:
from dbgpt.rag.embedding import Embeddings
self._embeddings_impl: Embeddings = None
def __init__(self, rerank_model: bool = False) -> None:
self._embeddings_impl: Union[Embeddings, RerankEmbeddings, None] = None
self._model_params = None
self.model_name = None
self.model_path = None
self._rerank_model = rerank_model
self._loader = EmbeddingLoader()
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
@@ -64,8 +63,17 @@
"""Start model worker"""
if not model_params:
model_params = self.parse_parameters(command_args)
if self._rerank_model:
model_params.rerank = True # type: ignore
self._model_params = model_params
self._embeddings_impl = self._loader.load(self.model_name, model_params)
if model_params.is_rerank_model():
logger.info(f"Load rerank embeddings model: {self.model_name}")
self._embeddings_impl = self._loader.load_rerank_model(
self.model_name, model_params
)
else:
logger.info(f"Load embeddings model: {self.model_name}")
self._embeddings_impl = self._loader.load(self.model_name, model_params)
def __del__(self):
self.stop()
@@ -96,5 +104,10 @@
def embeddings(self, params: Dict) -> List[List[float]]:
model = params.get("model")
logger.info(f"Receive embeddings request, model: {model}")
input: List[str] = params["input"]
return self._embeddings_impl.embed_documents(input)
texts: List[str] = params["input"]
if isinstance(self._embeddings_impl, RerankEmbeddings):
query = params["query"]
scores: List[float] = self._embeddings_impl.predict(query, texts)
return [scores]
else:
return self._embeddings_impl.embed_documents(texts)
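With this branch in place, one `params` schema serves both worker kinds; a rerank request is simply an embeddings request that also carries `query` (shapes sketched from the code above):

```python
# Plain embedding request: returns one vector per input text.
embed_params = {"model": "text2vec", "input": ["hello", "world"]}

# Rerank request: same entrypoint, plus a query; returns [[score, ...]],
# i.e. the scores wrapped as a single-element batch.
rerank_params = {
    "model": "bge-reranker-base",
    "query": "what is awel talk about?",
    "input": ["candidate one", "candidate two"],
}
```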

View File

@@ -952,7 +952,10 @@
)
def _build_worker(worker_params: ModelWorkerParameters):
def _build_worker(
worker_params: ModelWorkerParameters,
ext_worker_kwargs: Optional[Dict[str, Any]] = None,
):
worker_class = worker_params.worker_class
if worker_class:
from dbgpt.util.module_utils import import_from_checked_string
@@ -976,11 +979,16 @@ def _build_worker(worker_params: ModelWorkerParameters):
else:
raise Exception("Unsupported worker type: {worker_params.worker_type}")
return worker_cls()
if ext_worker_kwargs:
return worker_cls(**ext_worker_kwargs)
else:
return worker_cls()
def _start_local_worker(
worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters
worker_manager: WorkerManagerAdapter,
worker_params: ModelWorkerParameters,
ext_worker_kwargs: Optional[Dict[str, Any]] = None,
):
with root_tracer.start_span(
"WorkerManager._start_local_worker",
@@ -991,7 +999,7 @@ def _start_local_worker(
"sys_infos": _get_dict_from_obj(get_system_info()),
},
):
worker = _build_worker(worker_params)
worker = _build_worker(worker_params, ext_worker_kwargs=ext_worker_kwargs)
if not worker_manager.worker_manager:
worker_manager.worker_manager = _create_local_model_manager(worker_params)
worker_manager.worker_manager.add_worker(worker, worker_params)
@@ -1001,6 +1009,7 @@ def _start_local_embedding_worker(
worker_manager: WorkerManagerAdapter,
embedding_model_name: str = None,
embedding_model_path: str = None,
ext_worker_kwargs: Optional[Dict[str, Any]] = None,
):
if not embedding_model_name or not embedding_model_path:
return
@@ -1013,21 +1022,25 @@
logger.info(
f"Start local embedding worker with embedding parameters\n{embedding_worker_params}"
)
_start_local_worker(worker_manager, embedding_worker_params)
_start_local_worker(
worker_manager, embedding_worker_params, ext_worker_kwargs=ext_worker_kwargs
)
def initialize_worker_manager_in_client(
app=None,
include_router: bool = True,
model_name: str = None,
model_path: str = None,
model_name: Optional[str] = None,
model_path: Optional[str] = None,
run_locally: bool = True,
controller_addr: str = None,
controller_addr: Optional[str] = None,
local_port: int = 5670,
embedding_model_name: str = None,
embedding_model_path: str = None,
start_listener: Callable[["WorkerManager"], None] = None,
system_app: SystemApp = None,
embedding_model_name: Optional[str] = None,
embedding_model_path: Optional[str] = None,
rerank_model_name: Optional[str] = None,
rerank_model_path: Optional[str] = None,
start_listener: Optional[Callable[["WorkerManager"], None]] = None,
system_app: Optional[SystemApp] = None,
):
"""Initialize WorkerManager in client.
If run_locally is True:
@@ -1063,6 +1076,12 @@ def initialize_worker_manager_in_client(
_start_local_embedding_worker(
worker_manager, embedding_model_name, embedding_model_path
)
_start_local_embedding_worker(
worker_manager,
rerank_model_name,
rerank_model_path,
ext_worker_kwargs={"rerank_model": True},
)
else:
from dbgpt.model.cluster.controller.controller import (
ModelRegistryClient,
@@ -1072,7 +1091,6 @@
if not worker_params.controller_addr:
raise ValueError("Controller can't be None")
controller_addr = worker_params.controller_addr
logger.info(f"Worker params: {worker_params}")
client = ModelRegistryClient(worker_params.controller_addr)
worker_manager.worker_manager = RemoteWorkerManager(client)

View File

@@ -255,6 +255,10 @@ class BaseEmbeddingModelParameters(BaseModelParameters):
def build_kwargs(self, **kwargs) -> Dict:
pass
def is_rerank_model(self) -> bool:
"""Check if the model is a rerank model"""
return False
@dataclass
class EmbeddingModelParameters(BaseEmbeddingModelParameters):
@@ -272,6 +276,19 @@ class EmbeddingModelParameters(BaseEmbeddingModelParameters):
},
)
rerank: Optional[bool] = field(
default=False, metadata={"help": "Whether the model is a rerank model"}
)
max_length: Optional[int] = field(
default=None,
metadata={
"help": "Max length for input sequences. Longer sequences will be "
"truncated. If None, max length of the model will be used, just for rerank"
" model now."
},
)
def build_kwargs(self, **kwargs) -> Dict:
model_kwargs, encode_kwargs = None, None
if self.device:
@@ -280,10 +297,16 @@ class EmbeddingModelParameters(BaseEmbeddingModelParameters):
encode_kwargs = {"normalize_embeddings": self.normalize_embeddings}
if model_kwargs:
kwargs["model_kwargs"] = model_kwargs
if self.is_rerank_model():
kwargs["max_length"] = self.max_length
if encode_kwargs:
kwargs["encode_kwargs"] = encode_kwargs
return kwargs
def is_rerank_model(self) -> bool:
"""Check if the model is a rerank model"""
return self.rerank
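A configuration sketch showing how these fields travel through `build_kwargs` (assumes `model_name`/`model_path` come from the base dataclass, as used by `EmbeddingLoader`):

```python
from dbgpt.model.parameter import EmbeddingModelParameters

params = EmbeddingModelParameters(
    model_name="bge-reranker-base",
    model_path="/app/models/bge-reranker-base",
    rerank=True,
    max_length=512,
)
assert params.is_rerank_model()
# For rerank models, build_kwargs forwards max_length to the loader.
kwargs = params.build_kwargs(model_name=params.model_path)
```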
@dataclass
class ModelParameters(BaseModelParameters):
@@ -537,26 +560,35 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
metadata={"help": "Tto support Azure OpenAI Service custom deployment names"},
)
rerank: Optional[bool] = field(
default=False, metadata={"help": "Whether the model is a rerank model"}
)
def build_kwargs(self, **kwargs) -> Dict:
params = {
"openai_api_base": self.proxy_server_url,
"openai_api_key": self.proxy_api_key,
"openai_api_type": self.proxy_api_type if self.proxy_api_type else None,
"openai_api_version": self.proxy_api_version
if self.proxy_api_version
else None,
"openai_api_version": (
self.proxy_api_version if self.proxy_api_version else None
),
"model": self.proxy_backend,
"deployment": self.proxy_deployment
if self.proxy_deployment
else self.proxy_backend,
"deployment": (
self.proxy_deployment if self.proxy_deployment else self.proxy_backend
),
}
for k, v in kwargs.items():
params[k] = v
return params
def is_rerank_model(self) -> bool:
"""Check if the model is a rerank model"""
return self.rerank
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama",
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,"
"proxy_ollama,rerank_proxy_http_openapi",
}
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}

View File

@@ -15,6 +15,7 @@ from .embeddings import ( # noqa: F401
OllamaEmbeddings,
OpenAPIEmbeddings,
)
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
__ALL__ = [
"Embeddings",
@@ -28,4 +29,6 @@ __ALL__ = [
"DefaultEmbeddingFactory",
"EmbeddingFactory",
"WrappedEmbeddingFactory",
"CrossEncoderRerankEmbeddings",
"OpenAPIRerankEmbeddings",
]

View File

@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
from typing import Any, List, Optional, Type
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.core import Embeddings
from dbgpt.core import Embeddings, RerankEmbeddings
from dbgpt.core.awel import DAGVar
from dbgpt.core.awel.flow import ResourceCategory, register_resource
from dbgpt.util.i18n_utils import _
@@ -34,6 +34,26 @@ class EmbeddingFactory(BaseComponent, ABC):
"""
class RerankEmbeddingFactory(BaseComponent, ABC):
"""Class for RerankEmbeddingFactory."""
name = "rerank_embedding_factory"
@abstractmethod
def create(
self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
) -> RerankEmbeddings:
"""Create an embedding instance.
Args:
model_name (str): The model name.
embedding_cls (Type): The embedding class.
Returns:
RerankEmbeddings: The embedding instance.
"""
class DefaultEmbeddingFactory(EmbeddingFactory):
"""The default embedding factory."""

View File

@@ -1,6 +1,5 @@
"""Embedding implementations."""
from typing import Any, Dict, List, Optional
import aiohttp

View File

@@ -0,0 +1,131 @@
"""Re-rank embeddings."""
from typing import Any, Dict, List, Optional, cast
import aiohttp
import requests
from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
from dbgpt.core import RerankEmbeddings
class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
"""CrossEncoder Rerank Embeddings."""
model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())
client: Any #: :meta private:
model_name: str = "BAAI/bge-reranker-base"
max_length: Optional[int] = None
"""Max length for input sequences. Longer sequences will be truncated. If None, max
length of the model will be used"""
"""Model name to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the model."""
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
try:
from sentence_transformers import CrossEncoder
except ImportError:
raise ImportError(
"The sentence-transformers python package is not installed. "
"Please install it with `pip install sentence-transformers`"
)
kwargs["client"] = CrossEncoder(
kwargs.get("model_name"),
max_length=kwargs.get("max_length"),
**(kwargs.get("model_kwargs") or {}),
)
super().__init__(**kwargs)
def predict(self, query: str, candidates: List[str]) -> List[float]:
"""Predict the rank scores of the candidates.
Args:
query: The query text.
candidates: The list of candidate texts.
Returns:
List[float]: The rank scores of the candidates.
"""
from sentence_transformers import CrossEncoder
query_content_pairs = [[query, candidate] for candidate in candidates]
_model = cast(CrossEncoder, self.client)
rank_scores = _model.predict(sentences=query_content_pairs)
return rank_scores.tolist()
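A local usage sketch (assumes the `BAAI/bge-reranker-base` weights can be resolved by sentence-transformers):

```python
from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings

reranker = CrossEncoderRerankEmbeddings(
    model_name="BAAI/bge-reranker-base", model_kwargs={}
)
scores = reranker.predict(
    query="what is awel talk about?",
    candidates=[
        "AWEL is a workflow expression language for LLM applications.",
        "Completely unrelated text.",
    ],
)
print(scores)  # higher score means more relevant
```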
class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
"""OpenAPI Rerank Embeddings."""
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
api_url: str = Field(
default="http://localhost:8100/v1/beta/relevance",
description="The URL of the embeddings API.",
)
api_key: Optional[str] = Field(
default=None, description="The API key for the embeddings API."
)
model_name: str = Field(
default="bge-reranker-base", description="The name of the model to use."
)
timeout: int = Field(
default=60, description="The timeout for the request in seconds."
)
session: Optional[requests.Session] = None
def __init__(self, **kwargs):
"""Initialize the OpenAPIEmbeddings."""
try:
import requests
except ImportError:
raise ValueError(
"The requests python package is not installed. "
"Please install it with `pip install requests`"
)
if "session" not in kwargs: # noqa: SIM401
session = requests.Session()
else:
session = kwargs["session"]
api_key = kwargs.get("api_key")
if api_key:
session.headers.update({"Authorization": f"Bearer {api_key}"})
kwargs["session"] = session
super().__init__(**kwargs)
def predict(self, query: str, candidates: List[str]) -> List[float]:
"""Predict the rank scores of the candidates.
Args:
query: The query text.
candidates: The list of candidate texts.
Returns:
List[float]: The rank scores of the candidates.
"""
if not candidates:
return []
data = {"model": self.model_name, "query": query, "documents": candidates}
response = self.session.post( # type: ignore
self.api_url, json=data, timeout=self.timeout
)
response.raise_for_status()
return response.json()["data"]
async def apredict(self, query: str, candidates: List[str]) -> List[float]:
"""Predict the rank scores of the candidates asynchronously."""
headers = {"Authorization": f"Bearer {self.api_key}"}
async with aiohttp.ClientSession(
headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
) as session:
data = {"model": self.model_name, "query": query, "documents": candidates}
async with session.post(self.api_url, json=data) as resp:
resp.raise_for_status()
response_data = await resp.json()
if "data" not in response_data:
raise RuntimeError(response_data["detail"])
return response_data["data"]

View File

@@ -1,4 +1,5 @@
"""Embedding retriever."""
from functools import reduce
from typing import Any, Dict, List, Optional, cast
@@ -207,7 +208,7 @@ class EmbeddingRetriever(BaseRetriever):
"rerank_cls": self._rerank.__class__.__name__,
},
):
new_candidates_with_score = self._rerank.rank(
new_candidates_with_score = await self._rerank.arank(
new_candidates_with_score, query
)
return new_candidates_with_score

View File

@@ -1,9 +1,10 @@
"""Rerank module for RAG retriever."""
import asyncio
from abc import ABC, abstractmethod
from typing import Callable, List, Optional
from dbgpt.core import Chunk
from dbgpt.core import Chunk, RerankEmbeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.util.i18n_utils import _
@@ -39,6 +40,24 @@ class Ranker(ABC):
List[Chunk]
"""
async def arank(
self, candidates_with_scores: List[Chunk], query: Optional[str] = None
) -> List[Chunk]:
"""Return top k chunks after ranker.
Rank algorithm implementation return topk documents by candidates
similarity score
Args:
candidates_with_scores: List[Tuple]
query: Optional[str]
Return:
List[Chunk]
"""
return await asyncio.get_running_loop().run_in_executor(
None, self.rank, candidates_with_scores, query
)
def _filter(self, candidates_with_scores: List) -> List[Chunk]:
"""Filter duplicate candidates documents."""
candidates_with_scores = sorted(
@@ -52,6 +71,18 @@ class Ranker(ABC):
visited_docs.add(candidate_chunk.content)
return new_candidates
def _rerank_with_scores(
self, candidates_with_scores: List[Chunk], rank_scores: List[float]
) -> List[Chunk]:
"""Rerank candidates with scores."""
for candidate, score in zip(candidates_with_scores, rank_scores):
candidate.score = float(score)
new_candidates_with_scores = sorted(
candidates_with_scores, key=lambda x: x.score, reverse=True
)
return new_candidates_with_scores
@register_resource(
_("Default Ranker"),
@@ -225,3 +256,59 @@ class CrossEncoderRanker(Ranker):
candidates_with_scores, key=lambda x: x.score, reverse=True
)
return new_candidates_with_scores[: self.topk]
class RerankEmbeddingsRanker(Ranker):
"""Rerank Embeddings Ranker."""
def __init__(
self,
rerank_embeddings: RerankEmbeddings,
topk: int = 4,
rank_fn: Optional[RANK_FUNC] = None,
):
"""Rerank Embeddings rank algorithm implementation."""
self._model = rerank_embeddings
super().__init__(topk, rank_fn)
def rank(
self, candidates_with_scores: List[Chunk], query: Optional[str] = None
) -> List[Chunk]:
"""Rerank Embeddings rank algorithm implementation.
Args:
candidates_with_scores: List[Chunk], candidates with scores
query: Optional[str], query text
Returns:
List[Chunk], reranked candidates
"""
if not candidates_with_scores or not query:
return candidates_with_scores
contents = [candidate.content for candidate in candidates_with_scores]
rank_scores = self._model.predict(query, contents)
new_candidates_with_scores = self._rerank_with_scores(
candidates_with_scores, rank_scores
)
return new_candidates_with_scores[: self.topk]
async def arank(
self, candidates_with_scores: List[Chunk], query: Optional[str] = None
) -> List[Chunk]:
"""Rerank Embeddings rank algorithm implementation.
Args:
candidates_with_scores: List[Chunk], candidates with scores
query: Optional[str], query text
Returns:
List[Chunk], reranked candidates
"""
if not candidates_with_scores or not query:
return candidates_with_scores
contents = [candidate.content for candidate in candidates_with_scores]
rank_scores = await self._model.apredict(query, contents)
new_candidates_with_scores = self._rerank_with_scores(
candidates_with_scores, rank_scores
)
return new_candidates_with_scores[: self.topk]
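Tying the ranker to retrieval results, a sketch (chunk contents invented; `rerank_embeddings` is any `RerankEmbeddings` implementation, local or remote):

```python
from dbgpt.core import Chunk
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker

ranker = RerankEmbeddingsRanker(rerank_embeddings, topk=2)
chunks = [
    Chunk(content="AWEL is a workflow expression language."),
    Chunk(content="Unrelated content about electric cars."),
]
top_chunks = ranker.rank(chunks, query="what is awel talk about?")
```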

View File

@@ -55,7 +55,6 @@ curl http://127.0.0.1:8100/api/v1/embeddings \
```
## Verify via OpenAI SDK
```python
@@ -72,3 +71,49 @@ completion = openai.ChatCompletion.create(
print(completion.choices[0].message.content)
```
## (Experimental) Rerank Open API
The rerank API is an experimental feature that can be used to rerank the candidate list.
1. Use cURL to verify the rerank API.
```bash
curl http://127.0.0.1:8100/api/v1/beta/relevance \
-H "Authorization: Bearer EMPTY" \
-H "Content-Type: application/json" \
-d '{
"model": "bge-reranker-base",
"query": "what is awel talk about?",
"documents": [
"Agentic Workflow Expression Language(AWEL) is a set of intelligent agent workflow expression language specially designed for large model application development.",
"Autonomous agents have long been a research focus in academic and industry communities",
"AWEL is divided into three levels in deign, namely the operator layer, AgentFream layer and DSL layer.",
"Elon musk is a famous entrepreneur and inventor, he is the founder of SpaceX and Tesla."
]
}'
```
2. Use Python to verify the rerank API.
```python
from dbgpt.rag.embedding import OpenAPIRerankEmbeddings
rerank = OpenAPIRerankEmbeddings(api_key="EMPTY", model_name="bge-reranker-base")
rerank.predict(
query="what is awel talk about?",
candidates=[
"Agentic Workflow Expression Language(AWEL) is a set of intelligent agent workflow expression language specially designed for large model application development.",
"Autonomous agents have long been a research focus in academic and industry communities",
"AWEL is divided into three levels in deign, namely the operator layer, AgentFream layer and DSL layer.",
"Elon musk is a famous entrepreneur and inventor, he is the founder of SpaceX and Tesla."
]
)
```
The output is as follows:
```bash
[
0.9685816764831543,
3.7338297261158004e-05,
0.03692878410220146,
3.73825132555794e-05
]
```

View File

@@ -3,21 +3,21 @@
## Install command line tools
All the following operations are completed through the `dbgpt` command. To use the `dbgpt` command, you first need to install the `DB-GPT` project. You can install it with the following command:
```python
```shell
$ pip install -e ".[default]"
```
It can also be run in script mode:
```python
```shell
$ python pilot/scripts/cli_scripts.py
```
## Start Model Controller
```python
```shell
$ dbgpt start controller
```
## View log
```python
```shell
$ docker logs db-gpt-webserver-1 -f
```
By default, `Model Server` will start on port `8000`
@@ -28,7 +28,7 @@ By default, `Model Server` will start on port `8000`
Start `chatglm2-6b` model Worker
:::
```python
```shell
dbgpt start worker --model_name chatglm2-6b \
--model_path /app/models/chatglm2-6b \
--port 8001 \
@@ -40,7 +40,7 @@ dbgpt start worker --model_name chatglm2-6b \
Start `vicuna-13b-v1.5` model Worker
:::
```python
```shell
dbgpt start worker --model_name vicuna-13b-v1.5 \
--model_path /app/models/vicuna-13b-v1.5 \
--port 8002 \
@@ -52,9 +52,9 @@ dbgpt start worker --model_name vicuna-13b-v1.5 \
:::
## Start the embedding model service
## Start Embedding Model Worker
```python
```shell
dbgpt start worker --model_name text2vec \
--model_path /app/models/text2vec-large-chinese \
--worker_type text2vec \
@@ -66,24 +66,41 @@ dbgpt start worker --model_name text2vec \
:::
## Start Reranking Model Worker
```shell
dbgpt start worker --worker_type text2vec \
--rerank \
--model_path /app/models/bge-reranker-base \
--model_name bge-reranker-base \
--port 8004 \
--controller_addr http://127.0.0.1:8000
```
:::info note
⚠️ Make sure to use your own model name and model path.
:::
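To consume this worker from a light-mode webserver, the `remote_rerank` option added in this commit can be enabled (flag names inferred from the `WebServerParameters` fields, as with `remote_embedding`):

```shell
dbgpt start webserver --light --remote_embedding --remote_rerank
```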
:::tip
View and inspect deployed models
:::
```python
```shell
$ dbgpt model list
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
| Model Name | Model Type | Host | Port | Healthy | Enabled | Prompt Template | Last Heartbeat |
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
| chatglm2-6b | llm | 172.17.0.2 | 8001 | True | True | | 2023-09-12T23:04:31.287654 |
| WorkerManager | service | 172.17.0.2 | 8001 | True | True | | 2023-09-12T23:04:31.286668 |
| WorkerManager | service | 172.17.0.2 | 8003 | True | True | | 2023-09-12T23:04:29.845617 |
| WorkerManager | service | 172.17.0.2 | 8002 | True | True | | 2023-09-12T23:04:24.598439 |
| text2vec | text2vec | 172.17.0.2 | 8003 | True | True | | 2023-09-12T23:04:29.844796 |
| vicuna-13b-v1.5 | llm | 172.17.0.2 | 8002 | True | True | | 2023-09-12T23:04:24.597775 |
+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
+-------------------+------------+------------+------+---------+---------+-----------------+----------------------------+
| Model Name | Model Type | Host | Port | Healthy | Enabled | Prompt Template | Last Heartbeat |
+-------------------+------------+------------+------+---------+---------+-----------------+----------------------------+
| chatglm2-6b | llm | 172.17.0.2 | 8001 | True | True | | 2023-09-12T23:04:31.287654 |
| WorkerManager | service | 172.17.0.2 | 8001 | True | True | | 2023-09-12T23:04:31.286668 |
| WorkerManager | service | 172.17.0.2 | 8003 | True | True | | 2023-09-12T23:04:29.845617 |
| WorkerManager | service | 172.17.0.2 | 8002 | True | True | | 2023-09-12T23:04:24.598439 |
| WorkerManager | service | 172.21.0.5 | 8004 | True | True | | 2023-09-12T23:04:24.598439 |
| text2vec | text2vec | 172.17.0.2 | 8003 | True | True | | 2023-09-12T23:04:29.844796 |
| vicuna-13b-v1.5 | llm | 172.17.0.2 | 8002 | True | True | | 2023-09-12T23:04:24.597775 |
| bge-reranker-base | text2vec | 172.21.0.5 | 8004 | True | True | | 2024-05-15T11:36:12.935012 |
+-------------------+------------+------------+------+---------+---------+-----------------+----------------------------+
```
@@ -91,13 +108,13 @@ $ dbgpt model list
The model service deployed above can be used through dbgpt_server. First modify the `.env` configuration file to change the model connection address.
```python
```shell
dbgpt start webserver --light
```
## Start Webserver
```python
```shell
LLM_MODEL=vicuna-13b-v1.5
# The current default MODEL_SERVER address is the address of the Model Controller
MODEL_SERVER=http://127.0.0.1:8000
@@ -106,7 +123,7 @@ MODEL_SERVER=http://127.0.0.1:8000
Or it can be started directly from the command line, specifying the model.
```python
```shell
LLM_MODEL=chatglm2-6b dbgpt start webserver --light
```
@@ -118,7 +135,7 @@ For more information about the use of the command line, you can view the command
View dbgpt help `dbgpt --help`
:::
```python
```shell
dbgpt --help
Already connect 'dbgpt'
@@ -143,7 +160,7 @@ Commands:
Check the dbgpt start command `dbgpt start --help`
:::
```python
```shell
dbgpt start --help
Already connect 'dbgpt'
@@ -166,7 +183,7 @@ Commands:
View the dbgpt start model service help command `dbgpt start worker --help`
:::
```python
```shell
dbgpt start worker --help
Already connect 'dbgpt'
@@ -239,7 +256,7 @@ Options:
View dbgpt model service related commands `dbgpt model --help`
:::
```python
```shell
dbgpt model --help

View File

@@ -13,4 +13,5 @@ types-requests
types-beautifulsoup4
types-Markdown
types-tqdm
pandas-stubs
pandas-stubs
types-psutil