Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-01 08:11:45 +00:00)
feat(model): Support deploy rerank model (#1522)
This commit is contained in: parent 559affe87d, commit 593e974405
@@ -97,6 +97,23 @@ KNOWLEDGE_SEARCH_REWRITE=False
 # proxy_http_openapi_proxy_api_key=1dce29a6d66b4e2dbfec67044edbb924
 # proxy_http_openapi_proxy_backend=text2vec

+#*******************************************************************#
+#**                        RERANK SETTINGS                        **#
+#*******************************************************************#
+## Rerank model
+# RERANK_MODEL=bge-reranker-base
+## If you do not set RERANK_MODEL_PATH, DB-GPT will read the model path from EMBEDDING_MODEL_CONFIG based on the RERANK_MODEL.
+# RERANK_MODEL_PATH=
+## The number of rerank results to return
+# RERANK_TOP_K=3
+
+## Common HTTP rerank model
+# RERANK_MODEL=rerank_proxy_http_openapi
+# rerank_proxy_http_openapi_proxy_server_url=http://127.0.0.1:8100/api/v1/beta/relevance
+# rerank_proxy_http_openapi_proxy_api_key={your-api-key}
+# rerank_proxy_http_openapi_proxy_backend=bge-reranker-base
+
+
 #*******************************************************************#
@@ -247,6 +247,10 @@ class Config(metaclass=Singleton):

         ### EMBEDDING Configuration
         self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
+        # Rerank model configuration
+        self.RERANK_MODEL = os.getenv("RERANK_MODEL")
+        self.RERANK_MODEL_PATH = os.getenv("RERANK_MODEL_PATH")
+        self.RERANK_TOP_K = int(os.getenv("RERANK_TOP_K", 3))
         self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
         self.KNOWLEDGE_CHUNK_OVERLAP = int(os.getenv("KNOWLEDGE_CHUNK_OVERLAP", 50))
         self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5))
dbgpt/agent/expand/resources/host_tool.py (new file, 45 lines)
@@ -0,0 +1,45 @@
+"""Host tool resource module."""
+
+from ...resource.tool.base import tool
+
+
+@tool(description="Get current host CPU status.")
+def get_current_host_cpu_status() -> str:
+    """Get current host CPU status."""
+    import platform
+
+    import psutil
+
+    cpu_architecture = platform.machine()
+    cpu_count_physical = psutil.cpu_count(logical=False)
+    cpu_count_logical = psutil.cpu_count(logical=True)
+    cpu_usage = psutil.cpu_percent(interval=1)
+    return (
+        f"CPU Architecture: {cpu_architecture}\n"
+        f"Physical CPU Cores: {cpu_count_physical}\n"
+        f"Logical CPU Cores: {cpu_count_logical}\n"
+        f"CPU Usage: {cpu_usage}%"
+    )
+
+
+@tool(description="Get current host memory status.")
+def get_current_host_memory_status() -> str:
+    """Get current host memory status."""
+    import psutil
+
+    memory = psutil.virtual_memory()
+    return (
+        f"Total: {memory.total / (1024**3):.2f} GB\n"
+        f"Available: {memory.available / (1024**3):.2f} GB\n"
+        f"Used: {memory.used / (1024**3):.2f} GB\n"
+        f"Percent: {memory.percent}%"
+    )
+
+
+@tool(description="Get current host system load.")
+def get_current_host_system_load() -> str:
+    """Get current host system load."""
+    import os
+
+    load1, load5, load15 = os.getloadavg()
+    return f"System load average: {load1:.2f}, {load5:.2f}, {load15:.2f}"
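The three tools above are thin wrappers over `psutil` and the standard library. For readers unfamiliar with those probes, a minimal sketch of the same calls outside the `@tool` decorator (assumes `psutil` is installed; `os.getloadavg` is POSIX-only, as in the tool itself):

```python
import os
import platform

import psutil

# The same fields the new host tools report, gathered directly:
print(f"CPU: {platform.machine()}, "
      f"{psutil.cpu_count(logical=False)} physical cores, "
      f"{psutil.cpu_percent(interval=1)}% used")
print(f"Memory available: {psutil.virtual_memory().available / (1024**3):.2f} GB")
print("Load average:", os.getloadavg())  # POSIX-only
```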
@@ -227,6 +227,14 @@ class WebServerParameters(BaseParameters):
             "text2vec --model_name xxx --model_path xxx`"
         },
     )
+    remote_rerank: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether to enable remote rerank models. If it is True, you need"
+            " to start a rerank model through `dbgpt start worker --worker_type "
+            "text2vec --rerank --model_name xxx --model_path xxx`"
+        },
+    )
     log_level: Optional[str] = field(
         default=None,
         metadata={
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import logging
+from typing import Optional

 from dbgpt._private.config import Config
 from dbgpt.app.base import WebServerParameters
@@ -18,9 +19,14 @@ def initialize_components(
     system_app: SystemApp,
     embedding_model_name: str,
     embedding_model_path: str,
+    rerank_model_name: Optional[str] = None,
+    rerank_model_path: Optional[str] = None,
 ):
     # Lazy import to avoid high time cost
-    from dbgpt.app.initialization.embedding_component import _initialize_embedding_model
+    from dbgpt.app.initialization.embedding_component import (
+        _initialize_embedding_model,
+        _initialize_rerank_model,
+    )
     from dbgpt.app.initialization.scheduler import DefaultScheduler
     from dbgpt.app.initialization.serve_initialization import register_serve_apps
     from dbgpt.datasource.manages.connector_manager import ConnectorManager
@@ -45,6 +51,7 @@ def initialize_components(
     _initialize_embedding_model(
         param, system_app, embedding_model_name, embedding_model_path
     )
+    _initialize_rerank_model(param, system_app, rerank_model_name, rerank_model_path)
     _initialize_model_cache(system_app)
     _initialize_awel(system_app, param)
     # Initialize resource manager of agent
@@ -89,6 +96,11 @@ def _initialize_agent(system_app: SystemApp):

 def _initialize_resource_manager(system_app: SystemApp):
     from dbgpt.agent.expand.resources.dbgpt_tool import list_dbgpt_support_models
+    from dbgpt.agent.expand.resources.host_tool import (
+        get_current_host_cpu_status,
+        get_current_host_memory_status,
+        get_current_host_system_load,
+    )
     from dbgpt.agent.expand.resources.search_tool import baidu_search
     from dbgpt.agent.resource.base import ResourceType
     from dbgpt.agent.resource.manage import get_resource_manager, initialize_resource
@@ -104,6 +116,10 @@ def _initialize_resource_manager(system_app: SystemApp):
     # Register a search tool
     rm.register_resource(resource_instance=baidu_search)
     rm.register_resource(resource_instance=list_dbgpt_support_models)
+    # Register host tools
+    rm.register_resource(resource_instance=get_current_host_cpu_status)
+    rm.register_resource(resource_instance=get_current_host_memory_status)
+    rm.register_resource(resource_instance=get_current_host_system_load)


 def _initialize_openapi(system_app: SystemApp):
@@ -146,11 +146,24 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):

     embedding_model_name = CFG.EMBEDDING_MODEL
     embedding_model_path = EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL]
+    rerank_model_name = CFG.RERANK_MODEL
+    rerank_model_path = None
+    if rerank_model_name:
+        rerank_model_path = CFG.RERANK_MODEL_PATH or EMBEDDING_MODEL_CONFIG.get(
+            rerank_model_name
+        )

     server_init(param, system_app)
     mount_routers(app)
     model_start_listener = _create_model_start_listener(system_app)
-    initialize_components(param, system_app, embedding_model_name, embedding_model_path)
+    initialize_components(
+        param,
+        system_app,
+        embedding_model_name,
+        embedding_model_path,
+        rerank_model_name,
+        rerank_model_path,
+    )
     system_app.on_init()

     # Migration db storage, so you db models must be imported before this
@@ -161,7 +174,13 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
     if not param.light:
         print("Model Unified Deployment Mode!")
         if not param.remote_embedding:
+            # Embedding model is running in the same process, set embedding_model_name
+            # and embedding_model_path to None
             embedding_model_name, embedding_model_path = None, None
+        if not param.remote_rerank:
+            # Rerank model is running in the same process, set rerank_model_name and
+            # rerank_model_path to None
+            rerank_model_name, rerank_model_path = None, None
         initialize_worker_manager_in_client(
             app=app,
             model_name=model_name,
@@ -169,6 +188,8 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
             local_port=param.port,
             embedding_model_name=embedding_model_name,
             embedding_model_path=embedding_model_path,
+            rerank_model_name=rerank_model_name,
+            rerank_model_path=rerank_model_path,
             start_listener=model_start_listener,
             system_app=system_app,
         )
@@ -1,10 +1,14 @@
 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING, Any, Type
+from typing import TYPE_CHECKING, Any, Optional, Type

 from dbgpt.component import ComponentType, SystemApp
-from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
+from dbgpt.core import Embeddings, RerankEmbeddings
+from dbgpt.rag.embedding.embedding_factory import (
+    EmbeddingFactory,
+    RerankEmbeddingFactory,
+)

 if TYPE_CHECKING:
     from langchain.embeddings.base import Embeddings
@@ -32,6 +36,26 @@ def _initialize_embedding_model(
         )


+def _initialize_rerank_model(
+    param: "WebServerParameters",
+    system_app: SystemApp,
+    rerank_model_name: Optional[str] = None,
+    rerank_model_path: Optional[str] = None,
+):
+    if not rerank_model_name:
+        return
+    if param.remote_rerank:
+        logger.info("Register remote RemoteRerankEmbeddingFactory")
+        system_app.register(RemoteRerankEmbeddingFactory, model_name=rerank_model_name)
+    else:
+        logger.info(f"Register local LocalRerankEmbeddingFactory")
+        system_app.register(
+            LocalRerankEmbeddingFactory,
+            default_model_name=rerank_model_name,
+            default_model_path=rerank_model_path,
+        )
+
+
 class RemoteEmbeddingFactory(EmbeddingFactory):
     def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
         super().__init__(system_app=system_app)
@@ -105,3 +129,81 @@ class LocalEmbeddingFactory(EmbeddingFactory):
         loader = EmbeddingLoader()
         # Ignore model_name args
         return loader.load(self._default_model_name, model_params)
+
+
+class RemoteRerankEmbeddingFactory(RerankEmbeddingFactory):
+    def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
+        super().__init__(system_app=system_app)
+        self._default_model_name = model_name
+        self.kwargs = kwargs
+        self.system_app = system_app
+
+    def init_app(self, system_app):
+        self.system_app = system_app
+
+    def create(
+        self, model_name: str = None, embedding_cls: Type = None
+    ) -> "RerankEmbeddings":
+        from dbgpt.model.cluster import WorkerManagerFactory
+        from dbgpt.model.cluster.embedding.remote_embedding import (
+            RemoteRerankEmbeddings,
+        )
+
+        if embedding_cls:
+            raise NotImplementedError
+        worker_manager = self.system_app.get_component(
+            ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
+        ).create()
+        return RemoteRerankEmbeddings(
+            model_name or self._default_model_name, worker_manager
+        )
+
+
+class LocalRerankEmbeddingFactory(RerankEmbeddingFactory):
+    def __init__(
+        self,
+        system_app,
+        default_model_name: str = None,
+        default_model_path: str = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(system_app=system_app)
+        self._default_model_name = default_model_name
+        self._default_model_path = default_model_path
+        self._kwargs = kwargs
+        self._model = self._load_model()
+
+    def init_app(self, system_app):
+        pass
+
+    def create(
+        self, model_name: str = None, embedding_cls: Type = None
+    ) -> "RerankEmbeddings":
+        if embedding_cls:
+            raise NotImplementedError
+        return self._model
+
+    def _load_model(self) -> "RerankEmbeddings":
+        from dbgpt.model.adapter.embeddings_loader import (
+            EmbeddingLoader,
+            _parse_embedding_params,
+        )
+        from dbgpt.model.parameter import (
+            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
+            BaseEmbeddingModelParameters,
+            EmbeddingModelParameters,
+        )
+
+        param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
+            self._default_model_name, EmbeddingModelParameters
+        )
+        model_params: BaseEmbeddingModelParameters = _parse_embedding_params(
+            model_name=self._default_model_name,
+            model_path=self._default_model_path,
+            param_cls=param_cls,
+            **self._kwargs,
+        )
+        logger.info(model_params)
+        loader = EmbeddingLoader()
+        # Ignore model_name args
+        return loader.load_rerank_model(self._default_model_name, model_params)
@@ -18,6 +18,7 @@ from dbgpt.core import (
     MessagesPlaceholder,
     SystemPromptTemplate,
 )
+from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker
 from dbgpt.rag.retriever.rewrite import QueryRewrite
 from dbgpt.util.tracer import root_tracer, trace

@@ -37,7 +38,10 @@ class ChatKnowledge(BaseChat):
            - model_name:(str) llm model name
            - select_param:(str) space name
         """
-        from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
+        from dbgpt.rag.embedding.embedding_factory import (
+            EmbeddingFactory,
+            RerankEmbeddingFactory,
+        )

         self.knowledge_space = chat_param["select_param"]
         chat_param["chat_mode"] = ChatScene.ChatKnowledge
@@ -83,10 +87,22 @@ class ChatKnowledge(BaseChat):
             model_name=self.llm_model,
             language=CFG.LANGUAGE,
         )
+        reranker = None
+        retriever_top_k = self.top_k
+        if CFG.RERANK_MODEL:
+            rerank_embeddings = RerankEmbeddingFactory.get_instance(
+                CFG.SYSTEM_APP
+            ).create()
+            reranker = RerankEmbeddingsRanker(rerank_embeddings, topk=CFG.RERANK_TOP_K)
+            if retriever_top_k < CFG.RERANK_TOP_K or retriever_top_k < 20:
+                # We use reranker, so if the top_k is less than 20,
+                # we need to set it to 20
+                retriever_top_k = max(CFG.RERANK_TOP_K, 20)
         self.embedding_retriever = EmbeddingRetriever(
-            top_k=self.top_k,
+            top_k=retriever_top_k,
             vector_store_connector=vector_store_connector,
             query_rewrite=query_rewrite,
+            rerank=reranker,
         )
         self.prompt_template.template_is_strict = False
         self.relations = None
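The candidate-widening rule in this hunk is easy to miss: when a rerank model is configured, the vector retriever is asked for at least 20 chunks so the reranker has something to reorder, and the reranker then keeps only `RERANK_TOP_K` of them. A standalone restatement of that rule (the function name is ours, not the code's):

```python
def effective_retriever_top_k(user_top_k: int, rerank_top_k: int, rerank_enabled: bool) -> int:
    """Mirror of the widening logic in the hunk above."""
    if rerank_enabled and (user_top_k < rerank_top_k or user_top_k < 20):
        return max(rerank_top_k, 20)
    return user_top_k


assert effective_retriever_top_k(5, 3, rerank_enabled=True) == 20   # widened for the reranker
assert effective_retriever_top_k(25, 3, rerank_enabled=True) == 25  # already wide enough
assert effective_retriever_top_k(5, 3, rerank_enabled=False) == 5   # no reranker, untouched
```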
@@ -209,6 +209,11 @@ EMBEDDING_MODEL_CONFIG = {
     # Common HTTP embedding model
     "proxy_http_openapi": "proxy_http_openapi",
     "proxy_ollama": "proxy_ollama",
+    # Rerank model, rerank mode is a special embedding model
+    "bge-reranker-base": os.path.join(MODEL_PATH, "bge-reranker-base"),
+    "bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
+    # Proxy rerank model
+    "rerank_proxy_http_openapi": "rerank_proxy_http_openapi",
 }
@@ -7,7 +7,7 @@ from dbgpt.core.interface.cache import (  # noqa: F401
     CachePolicy,
     CacheValue,
 )
-from dbgpt.core.interface.embeddings import Embeddings  # noqa: F401
+from dbgpt.core.interface.embeddings import Embeddings, RerankEmbeddings  # noqa: F401
 from dbgpt.core.interface.knowledge import Chunk, Document  # noqa: F401
 from dbgpt.core.interface.llm import (  # noqa: F401
     DefaultMessageConverter,
@@ -106,6 +106,7 @@ __ALL__ = [
     "QuerySpec",
     "StorageError",
     "Embeddings",
+    "RerankEmbeddings",
     "Chunk",
     "Document",
 ]
@@ -1,9 +1,24 @@
 """Interface for embedding models."""

+import asyncio
 from abc import ABC, abstractmethod
 from typing import List


+class RerankEmbeddings(ABC):
+    """Interface for rerank models."""
+
+    @abstractmethod
+    def predict(self, query: str, candidates: List[str]) -> List[float]:
+        """Predict the scores of the candidates."""
+
+    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
+        """Asynchronously predict the scores of the candidates."""
+        return await asyncio.get_running_loop().run_in_executor(
+            None, self.predict, query, candidates
+        )
+
+
 class Embeddings(ABC):
     """Interface for embedding models.
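Since `predict` is the only abstract method, any scorer can plug into this interface; `apredict` falls back to running the synchronous `predict` in an executor. A toy implementation, purely illustrative and not part of the commit:

```python
from typing import List

from dbgpt.core import RerankEmbeddings  # exported by this commit


class KeywordOverlapReranker(RerankEmbeddings):
    """Toy scorer: fraction of query tokens present in each candidate."""

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        terms = set(query.lower().split())
        return [
            len(terms & set(c.lower().split())) / max(len(terms), 1)
            for c in candidates
        ]


scores = KeywordOverlapReranker().predict("rerank models", ["rerank the models", "unrelated"])
print(scores)  # [1.0, 0.0]; apredict() is inherited and wraps predict() in an executor
```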
@@ -166,6 +166,23 @@ class EmbeddingsResponse(BaseModel):
     usage: UsageInfo = Field(..., description="Usage info")


+class RelevanceRequest(BaseModel):
+    """Relevance request entity."""
+
+    model: str = Field(..., description="Rerank model name")
+    query: str = Field(..., description="Query text")
+    documents: List[str] = Field(..., description="Document texts")
+
+
+class RelevanceResponse(BaseModel):
+    """Relevance response entity."""
+
+    object: str = Field("list", description="Object type")
+    model: str = Field(..., description="Rerank model name")
+    data: List[float] = Field(..., description="Data list, relevance scores")
+    usage: UsageInfo = Field(..., description="Usage info")
+
+
 class ModelPermission(BaseModel):
     """Model permission entity."""
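Concretely, a request/response pair matching these schemas looks like the following sketch (the `usage` field names follow the FastChat-style `UsageInfo` and are an assumption here, not shown in this diff):

```python
request_body = {
    "model": "bge-reranker-base",
    "query": "what is awel talk about?",
    "documents": ["candidate one", "candidate two"],
}
response_body = {
    "object": "list",
    "model": "bge-reranker-base",
    "data": [0.97, 0.03],  # one relevance score per document, in input order
    "usage": {"prompt_tokens": 0, "total_tokens": 0, "completion_tokens": 0},  # assumed fields
}
```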
@@ -4,7 +4,7 @@ import logging
 from typing import List, Optional, Type, cast

 from dbgpt.configs.model_config import get_device
-from dbgpt.core import Embeddings
+from dbgpt.core import Embeddings, RerankEmbeddings
 from dbgpt.model.parameter import (
     BaseEmbeddingModelParameters,
     EmbeddingModelParameters,
@@ -66,6 +66,38 @@ class EmbeddingLoader:
             kwargs = param.build_kwargs(model_name=param.model_path)
             return HuggingFaceEmbeddings(**kwargs)

+    def load_rerank_model(
+        self, model_name: str, param: BaseEmbeddingModelParameters
+    ) -> RerankEmbeddings:
+        metadata = {
+            "model_name": model_name,
+            "run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
+            "params": _get_dict_from_obj(param),
+            "sys_infos": _get_dict_from_obj(get_system_info()),
+        }
+        with root_tracer.start_span(
+            "EmbeddingLoader.load_rerank_model",
+            span_type=SpanType.RUN,
+            metadata=metadata,
+        ):
+            if model_name in ["rerank_proxy_http_openapi"]:
+                from dbgpt.rag.embedding.rerank import OpenAPIRerankEmbeddings
+
+                proxy_param = cast(ProxyEmbeddingParameters, param)
+                openapi_param = {}
+                if proxy_param.proxy_server_url:
+                    openapi_param["api_url"] = proxy_param.proxy_server_url
+                if proxy_param.proxy_api_key:
+                    openapi_param["api_key"] = proxy_param.proxy_api_key
+                if proxy_param.proxy_backend:
+                    openapi_param["model_name"] = proxy_param.proxy_backend
+                return OpenAPIRerankEmbeddings(**openapi_param)
+            else:
+                from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings
+
+                kwargs = param.build_kwargs(model_name=param.model_path)
+                return CrossEncoderRerankEmbeddings(**kwargs)
+

 def _parse_embedding_params(
     model_name: Optional[str] = None,
@@ -147,15 +147,14 @@ def _dynamic_model_parser() -> Optional[List[Type[BaseModelParameters]]]:
     model_path = pre_args.get("model_path")
     worker_type = pre_args.get("worker_type")
     model_type = pre_args.get("model_type")
-    if model_name is None and model_type != ModelType.VLLM:
-        return None
     if worker_type == WorkerType.TEXT2VEC:
         return [
             EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
                 model_name, EmbeddingModelParameters
             )
         ]
-
+    if model_name is None and model_type != ModelType.VLLM:
+        return None
     llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type)
     param_class = llm_adapter.model_param_class()
     return [param_class]
@@ -3,6 +3,7 @@

 Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
 """
+
 import asyncio
 import json
 import logging
@@ -34,6 +35,8 @@ from dbgpt.core.schema.api import (
     ModelCard,
     ModelList,
     ModelPermission,
+    RelevanceRequest,
+    RelevanceResponse,
     UsageInfo,
 )
 from dbgpt.model.base import ModelInstance
@@ -368,6 +371,28 @@ class APIServer(BaseComponent):
         }
         return await worker_manager.embeddings(params)

+    async def relevance_generate(
+        self, model: str, query: str, texts: List[str]
+    ) -> List[float]:
+        """Generate embeddings
+
+        Args:
+            model (str): Model name
+            query (str): Query text
+            texts (List[str]): Texts to embed
+
+        Returns:
+            List[List[float]]: The embeddings of texts
+        """
+        worker_manager: WorkerManager = self.get_worker_manager()
+        params = {
+            "input": texts,
+            "model": model,
+            "query": query,
+        }
+        scores = await worker_manager.embeddings(params)
+        return scores[0]
+

 def get_api_server() -> APIServer:
     api_server = global_system_app.get_component(
@@ -456,6 +481,26 @@ async def create_embeddings(
     )


+@router.post(
+    "/v1/beta/relevance",
+    dependencies=[Depends(check_api_key)],
+    response_model=RelevanceResponse,
+)
+async def create_embeddings(
+    request: RelevanceRequest, api_server: APIServer = Depends(get_api_server)
+):
+    """Generate relevance scores for a query and a list of documents."""
+    await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
+
+    scores = await api_server.relevance_generate(
+        request.model, request.query, request.documents
+    )
+    return model_to_dict(
+        RelevanceResponse(data=scores, model=request.model, usage=UsageInfo()),
+        exclude_none=True,
+    )
+
+
 def _initialize_all(controller_addr: str, system_app: SystemApp):
     from dbgpt.model.cluster.controller.controller import ModelRegistryClient
     from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
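With an API server listening on port 8100 (as in the docs further down), the new endpoint can be exercised with a plain `requests` call; a sketch assuming the placeholder `EMPTY` API key from the docs:

```python
import requests

resp = requests.post(
    "http://127.0.0.1:8100/api/v1/beta/relevance",
    headers={"Authorization": "Bearer EMPTY"},
    json={
        "model": "bge-reranker-base",
        "query": "what is awel talk about?",
        "documents": ["AWEL is a workflow expression language.", "Unrelated text."],
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["data"])  # relevance scores, one per document
```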
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

 from dbgpt._private.pydantic import BaseModel
 from dbgpt.core.interface.message import ModelMessage
@@ -31,7 +31,9 @@ class PromptRequest(BaseModel):
 class EmbeddingsRequest(BaseModel):
     model: str
     input: List[str]
-    span_id: str = None
+    span_id: Optional[str] = None
+    query: Optional[str] = None
+    """For rerank model, query is required"""


 class CountTokenRequest(BaseModel):
@@ -1,6 +1,6 @@
 from typing import List

-from dbgpt.core import Embeddings
+from dbgpt.core import Embeddings, RerankEmbeddings
 from dbgpt.model.cluster.manager_base import WorkerManager


@@ -26,3 +26,30 @@ class RemoteEmbeddings(Embeddings):
     async def aembed_query(self, text: str) -> List[float]:
         """Asynchronous Embed query text."""
         return await self.aembed_documents([text])[0]
+
+
+class RemoteRerankEmbeddings(RerankEmbeddings):
+    def __init__(self, model_name: str, worker_manager: WorkerManager) -> None:
+        self.model_name = model_name
+        self.worker_manager = worker_manager
+
+    def predict(self, query: str, candidates: List[str]) -> List[float]:
+        """Predict the scores of the candidates."""
+        params = {
+            "model": self.model_name,
+            "input": candidates,
+            "query": query,
+        }
+        return self.worker_manager.sync_embeddings(params)[0]
+
+    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
+        """Asynchronously predict the scores of the candidates."""
+        params = {
+            "model": self.model_name,
+            "input": candidates,
+            "query": query,
+        }
+        # Use embeddings interface to get scores of ranker
+        scores = await self.worker_manager.embeddings(params)
+        # The first element is the scores of the query
+        return scores[0]
@@ -1,7 +1,7 @@
 import logging
-from typing import Dict, List, Type
+from typing import Dict, List, Type, Union

-from dbgpt.core import ModelMetadata
+from dbgpt.core import Embeddings, ModelMetadata, RerankEmbeddings
 from dbgpt.model.adapter.embeddings_loader import (
     EmbeddingLoader,
     _parse_embedding_params,
@@ -20,13 +20,12 @@ logger = logging.getLogger(__name__)


 class EmbeddingsModelWorker(ModelWorker):
-    def __init__(self) -> None:
-        from dbgpt.rag.embedding import Embeddings
-
-        self._embeddings_impl: Embeddings = None
+    def __init__(self, rerank_model: bool = False) -> None:
+        self._embeddings_impl: Union[Embeddings, RerankEmbeddings, None] = None
         self._model_params = None
         self.model_name = None
         self.model_path = None
+        self._rerank_model = rerank_model
         self._loader = EmbeddingLoader()

     def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
@@ -64,8 +63,17 @@ class EmbeddingsModelWorker(ModelWorker):
         """Start model worker"""
         if not model_params:
             model_params = self.parse_parameters(command_args)
+        if self._rerank_model:
+            model_params.rerank = True  # type: ignore
         self._model_params = model_params
-        self._embeddings_impl = self._loader.load(self.model_name, model_params)
+        if model_params.is_rerank_model():
+            logger.info(f"Load rerank embeddings model: {self.model_name}")
+            self._embeddings_impl = self._loader.load_rerank_model(
+                self.model_name, model_params
+            )
+        else:
+            logger.info(f"Load embeddings model: {self.model_name}")
+            self._embeddings_impl = self._loader.load(self.model_name, model_params)

     def __del__(self):
         self.stop()
@@ -96,5 +104,10 @@ class EmbeddingsModelWorker(ModelWorker):
     def embeddings(self, params: Dict) -> List[List[float]]:
         model = params.get("model")
         logger.info(f"Receive embeddings request, model: {model}")
-        input: List[str] = params["input"]
-        return self._embeddings_impl.embed_documents(input)
+        textx: List[str] = params["input"]
+        if isinstance(self._embeddings_impl, RerankEmbeddings):
+            query = params["query"]
+            scores: List[float] = self._embeddings_impl.predict(query, textx)
+            return [scores]
+        else:
+            return self._embeddings_impl.embed_documents(textx)
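Note the wire format this worker establishes: rerank requests ride the existing embeddings interface, distinguished only by the extra `query` field, and the worker wraps the score list in an outer list so the return type still matches `List[List[float]]`. Illustrative payloads only, values are made up:

```python
# Plain embedding request: no "query" key.
embedding_params = {"model": "text2vec", "input": ["doc one", "doc two"]}

# Rerank request: same shape plus "query"; "input" carries the candidates.
rerank_params = {
    "model": "bge-reranker-base",
    "input": ["doc one", "doc two"],
    "query": "which doc matters?",
}

# For rerank models the worker returns [scores], so callers unwrap element 0,
# exactly as RemoteRerankEmbeddings does in the hunk above.
```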
@@ -952,7 +952,10 @@ def _create_local_model_manager(
     )


-def _build_worker(worker_params: ModelWorkerParameters):
+def _build_worker(
+    worker_params: ModelWorkerParameters,
+    ext_worker_kwargs: Optional[Dict[str, Any]] = None,
+):
     worker_class = worker_params.worker_class
     if worker_class:
         from dbgpt.util.module_utils import import_from_checked_string
@@ -976,11 +979,16 @@ def _build_worker(worker_params: ModelWorkerParameters):
     else:
         raise Exception("Unsupported worker type: {worker_params.worker_type}")

-    return worker_cls()
+    if ext_worker_kwargs:
+        return worker_cls(**ext_worker_kwargs)
+    else:
+        return worker_cls()


 def _start_local_worker(
-    worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters
+    worker_manager: WorkerManagerAdapter,
+    worker_params: ModelWorkerParameters,
+    ext_worker_kwargs: Optional[Dict[str, Any]] = None,
 ):
     with root_tracer.start_span(
         "WorkerManager._start_local_worker",
@@ -991,7 +999,7 @@ def _start_local_worker(
             "sys_infos": _get_dict_from_obj(get_system_info()),
         },
     ):
-        worker = _build_worker(worker_params)
+        worker = _build_worker(worker_params, ext_worker_kwargs=ext_worker_kwargs)
         if not worker_manager.worker_manager:
             worker_manager.worker_manager = _create_local_model_manager(worker_params)
         worker_manager.worker_manager.add_worker(worker, worker_params)
@@ -1001,6 +1009,7 @@ def _start_local_embedding_worker(
     worker_manager: WorkerManagerAdapter,
     embedding_model_name: str = None,
     embedding_model_path: str = None,
+    ext_worker_kwargs: Optional[Dict[str, Any]] = None,
 ):
     if not embedding_model_name or not embedding_model_path:
         return
@@ -1013,21 +1022,25 @@ def _start_local_embedding_worker(
     logger.info(
         f"Start local embedding worker with embedding parameters\n{embedding_worker_params}"
     )
-    _start_local_worker(worker_manager, embedding_worker_params)
+    _start_local_worker(
+        worker_manager, embedding_worker_params, ext_worker_kwargs=ext_worker_kwargs
+    )


 def initialize_worker_manager_in_client(
     app=None,
     include_router: bool = True,
-    model_name: str = None,
-    model_path: str = None,
+    model_name: Optional[str] = None,
+    model_path: Optional[str] = None,
     run_locally: bool = True,
-    controller_addr: str = None,
+    controller_addr: Optional[str] = None,
     local_port: int = 5670,
-    embedding_model_name: str = None,
-    embedding_model_path: str = None,
-    start_listener: Callable[["WorkerManager"], None] = None,
-    system_app: SystemApp = None,
+    embedding_model_name: Optional[str] = None,
+    embedding_model_path: Optional[str] = None,
+    rerank_model_name: Optional[str] = None,
+    rerank_model_path: Optional[str] = None,
+    start_listener: Optional[Callable[["WorkerManager"], None]] = None,
+    system_app: Optional[SystemApp] = None,
 ):
     """Initialize WorkerManager in client.
     If run_locally is True:
@@ -1063,6 +1076,12 @@ def initialize_worker_manager_in_client(
         _start_local_embedding_worker(
             worker_manager, embedding_model_name, embedding_model_path
         )
+        _start_local_embedding_worker(
+            worker_manager,
+            rerank_model_name,
+            rerank_model_path,
+            ext_worker_kwargs={"rerank_model": True},
+        )
     else:
         from dbgpt.model.cluster.controller.controller import (
             ModelRegistryClient,
@@ -1072,7 +1091,6 @@ def initialize_worker_manager_in_client(

     if not worker_params.controller_addr:
         raise ValueError("Controller can`t be None")
-    controller_addr = worker_params.controller_addr
     logger.info(f"Worker params: {worker_params}")
     client = ModelRegistryClient(worker_params.controller_addr)
     worker_manager.worker_manager = RemoteWorkerManager(client)
@@ -255,6 +255,10 @@ class BaseEmbeddingModelParameters(BaseModelParameters):
     def build_kwargs(self, **kwargs) -> Dict:
         pass

+    def is_rerank_model(self) -> bool:
+        """Check if the model is a rerank model"""
+        return False
+

 @dataclass
 class EmbeddingModelParameters(BaseEmbeddingModelParameters):
@@ -272,6 +276,19 @@ class EmbeddingModelParameters(BaseEmbeddingModelParameters):
         },
     )

+    rerank: Optional[bool] = field(
+        default=False, metadata={"help": "Whether the model is a rerank model"}
+    )
+
+    max_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Max length for input sequences. Longer sequences will be "
+            "truncated. If None, max length of the model will be used, just for rerank"
+            " model now."
+        },
+    )
+
     def build_kwargs(self, **kwargs) -> Dict:
         model_kwargs, encode_kwargs = None, None
         if self.device:
@@ -280,10 +297,16 @@ class EmbeddingModelParameters(BaseEmbeddingModelParameters):
             encode_kwargs = {"normalize_embeddings": self.normalize_embeddings}
         if model_kwargs:
             kwargs["model_kwargs"] = model_kwargs
+        if self.is_rerank_model():
+            kwargs["max_length"] = self.max_length
         if encode_kwargs:
             kwargs["encode_kwargs"] = encode_kwargs
         return kwargs

+    def is_rerank_model(self) -> bool:
+        """Check if the model is a rerank model"""
+        return self.rerank
+

 @dataclass
 class ModelParameters(BaseModelParameters):
@@ -537,26 +560,35 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
         metadata={"help": "Tto support Azure OpenAI Service custom deployment names"},
     )

+    rerank: Optional[bool] = field(
+        default=False, metadata={"help": "Whether the model is a rerank model"}
+    )
+
     def build_kwargs(self, **kwargs) -> Dict:
         params = {
             "openai_api_base": self.proxy_server_url,
             "openai_api_key": self.proxy_api_key,
             "openai_api_type": self.proxy_api_type if self.proxy_api_type else None,
-            "openai_api_version": self.proxy_api_version
-            if self.proxy_api_version
-            else None,
+            "openai_api_version": (
+                self.proxy_api_version if self.proxy_api_version else None
+            ),
             "model": self.proxy_backend,
-            "deployment": self.proxy_deployment
-            if self.proxy_deployment
-            else self.proxy_backend,
+            "deployment": (
+                self.proxy_deployment if self.proxy_deployment else self.proxy_backend
+            ),
         }
         for k, v in kwargs:
             params[k] = v
         return params

+    def is_rerank_model(self) -> bool:
+        """Check if the model is a rerank model"""
+        return self.rerank
+

 _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
-    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama",
+    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,"
+    "proxy_ollama,rerank_proxy_http_openapi",
 }

 EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
@@ -15,6 +15,7 @@ from .embeddings import (  # noqa: F401
     OllamaEmbeddings,
     OpenAPIEmbeddings,
 )
+from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings  # noqa: F401

 __ALL__ = [
     "Embeddings",
@@ -28,4 +29,6 @@ __ALL__ = [
     "DefaultEmbeddingFactory",
     "EmbeddingFactory",
     "WrappedEmbeddingFactory",
+    "CrossEncoderRerankEmbeddings",
+    "OpenAPIRerankEmbeddings",
 ]
|
@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional, Type
|
||||
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
from dbgpt.core import Embeddings
|
||||
from dbgpt.core import Embeddings, RerankEmbeddings
|
||||
from dbgpt.core.awel import DAGVar
|
||||
from dbgpt.core.awel.flow import ResourceCategory, register_resource
|
||||
from dbgpt.util.i18n_utils import _
|
||||
@ -34,6 +34,26 @@ class EmbeddingFactory(BaseComponent, ABC):
|
||||
"""
|
||||
|
||||
|
||||
class RerankEmbeddingFactory(BaseComponent, ABC):
|
||||
"""Class for RerankEmbeddingFactory."""
|
||||
|
||||
name = "rerank_embedding_factory"
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
|
||||
) -> RerankEmbeddings:
|
||||
"""Create an embedding instance.
|
||||
|
||||
Args:
|
||||
model_name (str): The model name.
|
||||
embedding_cls (Type): The embedding class.
|
||||
|
||||
Returns:
|
||||
RerankEmbeddings: The embedding instance.
|
||||
"""
|
||||
|
||||
|
||||
class DefaultEmbeddingFactory(EmbeddingFactory):
|
||||
"""The default embedding factory."""
|
||||
|
||||
|
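Callers resolve a reranker through the `SystemApp` component registry, the same pattern the `ChatKnowledge` hunk uses. A sketch, assuming `system_app` is an initialized `SystemApp` with one of the factories from `embedding_component.py` already registered:

```python
from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory

# `system_app` is assumed to exist with a registered rerank factory
# (RemoteRerankEmbeddingFactory or LocalRerankEmbeddingFactory above).
rerank_embeddings = RerankEmbeddingFactory.get_instance(system_app).create()
scores = rerank_embeddings.predict("the query", ["candidate a", "candidate b"])
```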
@@ -1,6 +1,5 @@
 """Embedding implementations."""

-
 from typing import Any, Dict, List, Optional

 import aiohttp
dbgpt/rag/embedding/rerank.py (new file, 131 lines)
@@ -0,0 +1,131 @@
+"""Re-rank embeddings."""
+
+from typing import Any, Dict, List, Optional, cast
+
+import aiohttp
+import requests
+
+from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
+from dbgpt.core import RerankEmbeddings
+
+
+class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
+    """CrossEncoder Rerank Embeddings."""
+
+    model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())
+
+    client: Any  #: :meta private:
+    model_name: str = "BAAI/bge-reranker-base"
+    max_length: Optional[int] = None
+    """Max length for input sequences. Longer sequences will be truncated. If None, max
+    length of the model will be used"""
+    """Model name to use."""
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Keyword arguments to pass to the model."""
+
+    def __init__(self, **kwargs: Any):
+        """Initialize the sentence_transformer."""
+        try:
+            from sentence_transformers import CrossEncoder
+        except ImportError:
+            raise ImportError(
+                "please `pip install sentence-transformers`",
+            )
+
+        kwargs["client"] = CrossEncoder(
+            kwargs.get("model_name"),
+            max_length=kwargs.get("max_length"),
+            **kwargs.get("model_kwargs"),
+        )
+        super().__init__(**kwargs)
+
+    def predict(self, query: str, candidates: List[str]) -> List[float]:
+        """Predict the rank scores of the candidates.
+
+        Args:
+            query: The query text.
+            candidates: The list of candidate texts.
+
+        Returns:
+            List[float]: The rank scores of the candidates.
+        """
+        from sentence_transformers import CrossEncoder
+
+        query_content_pairs = [[query, candidate] for candidate in candidates]
+        _model = cast(CrossEncoder, self.client)
+        rank_scores = _model.predict(sentences=query_content_pairs)
+        return rank_scores.tolist()
+
+
+class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
+    """OpenAPI Rerank Embeddings."""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
+
+    api_url: str = Field(
+        default="http://localhost:8100/v1/beta/relevance",
+        description="The URL of the embeddings API.",
+    )
+    api_key: Optional[str] = Field(
+        default=None, description="The API key for the embeddings API."
+    )
+    model_name: str = Field(
+        default="bge-reranker-base", description="The name of the model to use."
+    )
+    timeout: int = Field(
+        default=60, description="The timeout for the request in seconds."
+    )
+
+    session: Optional[requests.Session] = None
+
+    def __init__(self, **kwargs):
+        """Initialize the OpenAPIEmbeddings."""
+        try:
+            import requests
+        except ImportError:
+            raise ValueError(
+                "The requests python package is not installed. "
+                "Please install it with `pip install requests`"
+            )
+        if "session" not in kwargs:  # noqa: SIM401
+            session = requests.Session()
+        else:
+            session = kwargs["session"]
+        api_key = kwargs.get("api_key")
+        if api_key:
+            session.headers.update({"Authorization": f"Bearer {api_key}"})
+        kwargs["session"] = session
+        super().__init__(**kwargs)
+
+    def predict(self, query: str, candidates: List[str]) -> List[float]:
+        """Predict the rank scores of the candidates.
+
+        Args:
+            query: The query text.
+            candidates: The list of candidate texts.
+
+        Returns:
+            List[float]: The rank scores of the candidates.
+        """
+        if not candidates:
+            return []
+        data = {"model": self.model_name, "query": query, "documents": candidates}
+        response = self.session.post(  # type: ignore
+            self.api_url, json=data, timeout=self.timeout
+        )
+        response.raise_for_status()
+        return response.json()["data"]
+
+    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
+        """Predict the rank scores of the candidates asynchronously."""
+        headers = {"Authorization": f"Bearer {self.api_key}"}
+        async with aiohttp.ClientSession(
+            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
+        ) as session:
+            data = {"model": self.model_name, "query": query, "documents": candidates}
+            async with session.post(self.api_url, json=data) as resp:
+                resp.raise_for_status()
+                response_data = await resp.json()
+                if "data" not in response_data:
+                    raise RuntimeError(response_data["detail"])
+                return response_data["data"]
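A local smoke test of `CrossEncoderRerankEmbeddings` as committed (assumes `sentence-transformers` is installed and the model can be downloaded; `model_kwargs` is passed explicitly because this constructor unpacks `kwargs.get("model_kwargs")` unconditionally):

```python
from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings

reranker = CrossEncoderRerankEmbeddings(
    model_name="BAAI/bge-reranker-base",
    model_kwargs={},  # required with this constructor; the pydantic default is not applied yet
)
scores = reranker.predict(
    query="what is awel talk about?",
    candidates=[
        "AWEL is a workflow expression language for LLM applications.",
        "Elon Musk founded SpaceX and Tesla.",
    ],
)
print(scores)  # higher score means more relevant
```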
@@ -1,4 +1,5 @@
 """Embedding retriever."""
+
 from functools import reduce
 from typing import Any, Dict, List, Optional, cast

@@ -207,7 +208,7 @@ class EmbeddingRetriever(BaseRetriever):
                 "rerank_cls": self._rerank.__class__.__name__,
             },
         ):
-            new_candidates_with_score = self._rerank.rank(
+            new_candidates_with_score = await self._rerank.arank(
                 new_candidates_with_score, query
            )
            return new_candidates_with_score
@@ -1,9 +1,10 @@
 """Rerank module for RAG retriever."""

+import asyncio
 from abc import ABC, abstractmethod
 from typing import Callable, List, Optional

-from dbgpt.core import Chunk
+from dbgpt.core import Chunk, RerankEmbeddings
 from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
 from dbgpt.util.i18n_utils import _

@@ -39,6 +40,24 @@ class Ranker(ABC):
             List[Chunk]
         """

+    async def arank(
+        self, candidates_with_scores: List[Chunk], query: Optional[str] = None
+    ) -> List[Chunk]:
+        """Return top k chunks after ranker.
+
+        Rank algorithm implementation return topk documents by candidates
+        similarity score
+
+        Args:
+            candidates_with_scores: List[Tuple]
+            query: Optional[str]
+        Return:
+            List[Chunk]
+        """
+        return await asyncio.get_running_loop().run_in_executor(
+            None, self.rank, candidates_with_scores, query
+        )
+
     def _filter(self, candidates_with_scores: List) -> List[Chunk]:
         """Filter duplicate candidates documents."""
         candidates_with_scores = sorted(
@@ -52,6 +71,18 @@ class Ranker(ABC):
             visited_docs.add(candidate_chunk.content)
         return new_candidates

+    def _rerank_with_scores(
+        self, candidates_with_scores: List[Chunk], rank_scores: List[float]
+    ) -> List[Chunk]:
+        """Rerank candidates with scores."""
+        for candidate, score in zip(candidates_with_scores, rank_scores):
+            candidate.score = float(score)
+
+        new_candidates_with_scores = sorted(
+            candidates_with_scores, key=lambda x: x.score, reverse=True
+        )
+        return new_candidates_with_scores
+

 @register_resource(
     _("Default Ranker"),
@@ -225,3 +256,59 @@ class CrossEncoderRanker(Ranker):
             candidates_with_scores, key=lambda x: x.score, reverse=True
         )
         return new_candidates_with_scores[: self.topk]
+
+
+class RerankEmbeddingsRanker(Ranker):
+    """Rerank Embeddings Ranker."""
+
+    def __init__(
+        self,
+        rerank_embeddings: RerankEmbeddings,
+        topk: int = 4,
+        rank_fn: Optional[RANK_FUNC] = None,
+    ):
+        """Rerank Embeddings rank algorithm implementation."""
+        self._model = rerank_embeddings
+        super().__init__(topk, rank_fn)
+
+    def rank(
+        self, candidates_with_scores: List[Chunk], query: Optional[str] = None
+    ) -> List[Chunk]:
+        """Rerank Embeddings rank algorithm implementation.
+
+        Args:
+            candidates_with_scores: List[Chunk], candidates with scores
+            query: Optional[str], query text
+        Returns:
+            List[Chunk], reranked candidates
+        """
+        if not candidates_with_scores or not query:
+            return candidates_with_scores
+
+        contents = [candidate.content for candidate in candidates_with_scores]
+        rank_scores = self._model.predict(query, contents)
+        new_candidates_with_scores = self._rerank_with_scores(
+            candidates_with_scores, rank_scores
+        )
+        return new_candidates_with_scores[: self.topk]
+
+    async def arank(
+        self, candidates_with_scores: List[Chunk], query: Optional[str] = None
+    ) -> List[Chunk]:
+        """Rerank Embeddings rank algorithm implementation.
+
+        Args:
+            candidates_with_scores: List[Chunk], candidates with scores
+            query: Optional[str], query text
+        Returns:
+            List[Chunk], reranked candidates
+        """
+        if not candidates_with_scores or not query:
+            return candidates_with_scores
+
+        contents = [candidate.content for candidate in candidates_with_scores]
+        rank_scores = await self._model.apredict(query, contents)
+        new_candidates_with_scores = self._rerank_with_scores(
+            candidates_with_scores, rank_scores
+        )
+        return new_candidates_with_scores[: self.topk]
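End to end, the ranker takes scored `Chunk`s from the retriever and reorders them by model score, overwriting `Chunk.score`. A sketch with a trivial stand-in scorer (any `RerankEmbeddings` implementation works; the stand-in class is ours):

```python
from typing import List

from dbgpt.core import Chunk, RerankEmbeddings
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker


class _TrivialScorer(RerankEmbeddings):
    """Illustrative stand-in for a real rerank model."""

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        return [float("awel" in c.lower()) for c in candidates]


ranker = RerankEmbeddingsRanker(_TrivialScorer(), topk=2)
chunks = [Chunk(content="unrelated text", score=0.9), Chunk(content="AWEL overview", score=0.1)]
top = ranker.rank(chunks, query="what is AWEL?")
print([c.content for c in top])  # rerank score overrides the vector-similarity order
```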
@@ -55,7 +55,6 @@ curl http://127.0.0.1:8100/api/v1/embeddings \
 ```


-
 ## Verify via OpenAI SDK

 ```bash
@@ -72,3 +71,49 @@ completion = openai.ChatCompletion.create(
 print(completion.choices[0].message.content)
 ```
+
+## (Experimental) Rerank Open API
+
+The rerank API is an experimental feature that can be used to rerank the candidate list.
+
+1. Use cURL to verify the rerank API.
+```bash
+curl http://127.0.0.1:8100/api/v1/beta/relevance \
+-H "Authorization: Bearer EMPTY" \
+-H "Content-Type: application/json" \
+-d '{
+    "model": "bge-reranker-base",
+    "query": "what is awel talk about?",
+    "documents": [
+        "Agentic Workflow Expression Language(AWEL) is a set of intelligent agent workflow expression language specially designed for large model application development.",
+        "Autonomous agents have long been a research focus in academic and industry communities",
+        "AWEL is divided into three levels in design, namely the operator layer, AgentFream layer and DSL layer.",
+        "Elon musk is a famous entrepreneur and inventor, he is the founder of SpaceX and Tesla."
+    ]
+}'
+```
+
+2. Use Python to verify the rerank API.
+```python
+from dbgpt.rag.embedding import OpenAPIRerankEmbeddings
+
+rerank = OpenAPIRerankEmbeddings(api_key="EMPTY", model_name="bge-reranker-base")
+rerank.predict(
+    query="what is awel talk about?",
+    candidates=[
+        "Agentic Workflow Expression Language(AWEL) is a set of intelligent agent workflow expression language specially designed for large model application development.",
+        "Autonomous agents have long been a research focus in academic and industry communities",
+        "AWEL is divided into three levels in design, namely the operator layer, AgentFream layer and DSL layer.",
+        "Elon musk is a famous entrepreneur and inventor, he is the founder of SpaceX and Tesla."
+    ]
+)
+```
+
+The output is as follows:
+```bash
+[
+    0.9685816764831543,
+    3.7338297261158004e-05,
+    0.03692878410220146,
+    3.73825132555794e-05
+]
+```
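The documentation above shows the synchronous `predict`; the same call can be made asynchronously via `apredict` against the same endpoint, under the same assumptions (server on port 8100, `EMPTY` key):

```python
import asyncio

from dbgpt.rag.embedding import OpenAPIRerankEmbeddings


async def main():
    rerank = OpenAPIRerankEmbeddings(api_key="EMPTY", model_name="bge-reranker-base")
    scores = await rerank.apredict(
        query="what is awel talk about?",
        candidates=["AWEL is a workflow expression language.", "Unrelated text."],
    )
    print(scores)


asyncio.run(main())
```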
@@ -3,21 +3,21 @@
 ## Install command line tools
 All the following operations are completed through the `dbgpt` command. To use the `dbgpt` command, you first need to install the `DB-GPT` project. You can install it through the following command

-```python
+```shell
 $ pip install -e ".[default]"
 ```
 It can also be used in script mode
-```python
+```shell
 $ python pilot/scripts/cli_scripts.py
 ```

 ## Start Model Controller
-```python
+```shell
 $ dbgpt start controller
 ```

 ## View log
-```python
+```shell
 $ docker logs db-gpt-webserver-1 -f
 ```
 By default, `Model Server` will start on port `8000`
@@ -28,7 +28,7 @@ By default, `Model Server` will start on port `8000`
 Start `chatglm2-6b` model Worker
 :::

-```python
+```shell
 dbgpt start worker --model_name chatglm2-6b \
 --model_path /app/models/chatglm2-6b \
 --port 8001 \
@@ -40,7 +40,7 @@ dbgpt start worker --model_name chatglm2-6b \
 Start `vicuna-13b-v1.5` model Worker
 :::

-```python
+```shell
 dbgpt start worker --model_name vicuna-13b-v1.5 \
 --model_path /app/models/vicuna-13b-v1.5 \
 --port 8002 \
@@ -52,9 +52,9 @@ dbgpt start worker --model_name vicuna-13b-v1.5 \
 :::


-## Start the embedding model service
+## Start Embedding Model Worker

-```python
+```shell
 dbgpt start worker --model_name text2vec \
 --model_path /app/models/text2vec-large-chinese \
 --worker_type text2vec \
@@ -66,24 +66,41 @@ dbgpt start worker --model_name text2vec \

 :::

+## Start Reranking Model Worker
+
+```shell
+dbgpt start worker --worker_type text2vec \
+--rerank \
+--model_path /app/models/bge-reranker-base \
+--model_name bge-reranker-base \
+--port 8004 \
+--controller_addr http://127.0.0.1:8000
+```
+
+:::info note
+⚠️ Make sure to use your own model name and model path.
+
+:::
+
 :::tip
 View and inspect deployed models
 :::

-```python
+```shell
 $ dbgpt model list

-+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
-| Model Name      | Model Type | Host       | Port | Healthy | Enabled | Prompt Template | Last Heartbeat             |
-+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
-| chatglm2-6b     | llm        | 172.17.0.2 | 8001 | True    | True    |                 | 2023-09-12T23:04:31.287654 |
-| WorkerManager   | service    | 172.17.0.2 | 8001 | True    | True    |                 | 2023-09-12T23:04:31.286668 |
-| WorkerManager   | service    | 172.17.0.2 | 8003 | True    | True    |                 | 2023-09-12T23:04:29.845617 |
-| WorkerManager   | service    | 172.17.0.2 | 8002 | True    | True    |                 | 2023-09-12T23:04:24.598439 |
-| text2vec        | text2vec   | 172.17.0.2 | 8003 | True    | True    |                 | 2023-09-12T23:04:29.844796 |
-| vicuna-13b-v1.5 | llm        | 172.17.0.2 | 8002 | True    | True    |                 | 2023-09-12T23:04:24.597775 |
-+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+
++-------------------+------------+------------+------+---------+---------+-----------------+----------------------------+
+| Model Name        | Model Type | Host       | Port | Healthy | Enabled | Prompt Template | Last Heartbeat             |
++-------------------+------------+------------+------+---------+---------+-----------------+----------------------------+
+| chatglm2-6b       | llm        | 172.17.0.2 | 8001 | True    | True    |                 | 2023-09-12T23:04:31.287654 |
+| WorkerManager     | service    | 172.17.0.2 | 8001 | True    | True    |                 | 2023-09-12T23:04:31.286668 |
+| WorkerManager     | service    | 172.17.0.2 | 8003 | True    | True    |                 | 2023-09-12T23:04:29.845617 |
+| WorkerManager     | service    | 172.17.0.2 | 8002 | True    | True    |                 | 2023-09-12T23:04:24.598439 |
+| WorkerManager     | service    | 172.21.0.5 | 8004 | True    | True    |                 | 2023-09-12T23:04:24.598439 |
+| text2vec          | text2vec   | 172.17.0.2 | 8003 | True    | True    |                 | 2023-09-12T23:04:29.844796 |
+| vicuna-13b-v1.5   | llm        | 172.17.0.2 | 8002 | True    | True    |                 | 2023-09-12T23:04:24.597775 |
+| bge-reranker-base | text2vec   | 172.21.0.5 | 8004 | True    | True    |                 | 2024-05-15T11:36:12.935012 |
++-------------------+------------+------------+------+---------+---------+-----------------+----------------------------+
 ```


@@ -91,13 +108,13 @@ $ dbgpt model list

 The model service deployed as above can be used through dbgpt_server. First modify the `.env` configuration file to change the connection model address

-```python
+```shell
 dbgpt start webserver --light
 ```

 ## Start Webserver

-```python
+```shell
 LLM_MODEL=vicuna-13b-v1.5
 # The current default MODEL_SERVER address is the address of the Model Controller
 MODEL_SERVER=http://127.0.0.1:8000
@@ -106,7 +123,7 @@ MODEL_SERVER=http://127.0.0.1:8000


 Or it can be started directly by command to specify the model.
-```python
+```shell
 LLM_MODEL=chatglm2-6b dbgpt start webserver --light
 ```

@@ -118,7 +135,7 @@ For more information about the use of the command line, you can view the command
 View dbgpt help `dbgpt --help`
 :::

-```python
+```shell
 dbgpt --help

 Already connect 'dbgpt'
@@ -143,7 +160,7 @@ Commands:
 Check the dbgpt start command `dbgpt start --help`
 :::

-```python
+```shell
 dbgpt start --help

 Already connect 'dbgpt'
@@ -166,7 +183,7 @@ Commands:
 View the dbgpt start model service help command `dbgpt start worker --help`
 :::

-```python
+```shell
 dbgpt start worker --help

 Already connect 'dbgpt'
@@ -239,7 +256,7 @@ Options:
 View dbgpt model service related commands `dbgpt model --help`
 :::

-```python
+```shell
 dbgpt model --help
@@ -13,4 +13,5 @@ types-requests
 types-beautifulsoup4
 types-Markdown
 types-tqdm
-pandas-stubs
+pandas-stubs
+types-psutil