diff --git a/.env.template b/.env.template
index ad9c412a4..799113fc2 100644
--- a/.env.template
+++ b/.env.template
@@ -97,6 +97,23 @@ KNOWLEDGE_SEARCH_REWRITE=False
 # proxy_http_openapi_proxy_api_key=1dce29a6d66b4e2dbfec67044edbb924
 # proxy_http_openapi_proxy_backend=text2vec
+#*******************************************************************#
+#**                        RERANK SETTINGS                         **#
+#*******************************************************************#
+## Rerank model
+# RERANK_MODEL=bge-reranker-base
+## If you do not set RERANK_MODEL_PATH, DB-GPT will read the model path from EMBEDDING_MODEL_CONFIG based on the RERANK_MODEL.
+# RERANK_MODEL_PATH=
+## The number of rerank results to return
+# RERANK_TOP_K=3
+
+## Common HTTP rerank model
+# RERANK_MODEL=rerank_proxy_http_openapi
+# rerank_proxy_http_openapi_proxy_server_url=http://127.0.0.1:8100/api/v1/beta/relevance
+# rerank_proxy_http_openapi_proxy_api_key={your-api-key}
+# rerank_proxy_http_openapi_proxy_backend=bge-reranker-base
+
+
 #*******************************************************************#
diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py
index 223314d87..60c3908f7 100644
--- a/dbgpt/_private/config.py
+++ b/dbgpt/_private/config.py
@@ -247,6 +247,10 @@ class Config(metaclass=Singleton):
         ### EMBEDDING Configuration
         self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")
+        # Rerank model configuration
+        self.RERANK_MODEL = os.getenv("RERANK_MODEL")
+        self.RERANK_MODEL_PATH = os.getenv("RERANK_MODEL_PATH")
+        self.RERANK_TOP_K = int(os.getenv("RERANK_TOP_K", 3))
         self.KNOWLEDGE_CHUNK_SIZE = int(os.getenv("KNOWLEDGE_CHUNK_SIZE", 100))
         self.KNOWLEDGE_CHUNK_OVERLAP = int(os.getenv("KNOWLEDGE_CHUNK_OVERLAP", 50))
         self.KNOWLEDGE_SEARCH_TOP_SIZE = int(os.getenv("KNOWLEDGE_SEARCH_TOP_SIZE", 5))
diff --git a/dbgpt/agent/expand/resources/host_tool.py b/dbgpt/agent/expand/resources/host_tool.py
new file mode 100644
index 000000000..562dcc1cd
--- /dev/null
+++ b/dbgpt/agent/expand/resources/host_tool.py
@@ -0,0 +1,45 @@
+"""Host tool resource module."""
+
+from ...resource.tool.base import tool
+
+
+@tool(description="Get current host CPU status.")
+def get_current_host_cpu_status() -> str:
+    """Get current host CPU status."""
+    import platform
+
+    import psutil
+
+    cpu_architecture = platform.machine()
+    cpu_count_physical = psutil.cpu_count(logical=False)
+    cpu_count_logical = psutil.cpu_count(logical=True)
+    cpu_usage = psutil.cpu_percent(interval=1)
+    return (
+        f"CPU Architecture: {cpu_architecture}\n"
+        f"Physical CPU Cores: {cpu_count_physical}\n"
+        f"Logical CPU Cores: {cpu_count_logical}\n"
+        f"CPU Usage: {cpu_usage}%"
+    )
+
+
+@tool(description="Get current host memory status.")
+def get_current_host_memory_status() -> str:
+    """Get current host memory status."""
+    import psutil
+
+    memory = psutil.virtual_memory()
+    return (
+        f"Total: {memory.total / (1024**3):.2f} GB\n"
+        f"Available: {memory.available / (1024**3):.2f} GB\n"
+        f"Used: {memory.used / (1024**3):.2f} GB\n"
+        f"Percent: {memory.percent}%"
+    )
+
+
+@tool(description="Get current host system load.")
+def get_current_host_system_load() -> str:
+    """Get current host system load."""
+    import os
+
+    load1, load5, load15 = os.getloadavg()
+    return f"System load average: {load1:.2f}, {load5:.2f}, {load15:.2f}"
diff --git a/dbgpt/app/base.py b/dbgpt/app/base.py
index 885da4871..e5a371870 100644
--- a/dbgpt/app/base.py
+++ b/dbgpt/app/base.py
@@ -227,6 +227,14 @@ class WebServerParameters(BaseParameters):
             "text2vec --model_name xxx --model_path xxx`"
         },
     )
+
remote_rerank: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to enable remote rerank models. If it is True, you need" + " to start a rerank model through `dbgpt start worker --worker_type " + "text2vec --rerank --model_name xxx --model_path xxx`" + }, + ) log_level: Optional[str] = field( default=None, metadata={ diff --git a/dbgpt/app/component_configs.py b/dbgpt/app/component_configs.py index 3c4cb95a8..3ef08d4bc 100644 --- a/dbgpt/app/component_configs.py +++ b/dbgpt/app/component_configs.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from typing import Optional from dbgpt._private.config import Config from dbgpt.app.base import WebServerParameters @@ -18,9 +19,14 @@ def initialize_components( system_app: SystemApp, embedding_model_name: str, embedding_model_path: str, + rerank_model_name: Optional[str] = None, + rerank_model_path: Optional[str] = None, ): # Lazy import to avoid high time cost - from dbgpt.app.initialization.embedding_component import _initialize_embedding_model + from dbgpt.app.initialization.embedding_component import ( + _initialize_embedding_model, + _initialize_rerank_model, + ) from dbgpt.app.initialization.scheduler import DefaultScheduler from dbgpt.app.initialization.serve_initialization import register_serve_apps from dbgpt.datasource.manages.connector_manager import ConnectorManager @@ -45,6 +51,7 @@ def initialize_components( _initialize_embedding_model( param, system_app, embedding_model_name, embedding_model_path ) + _initialize_rerank_model(param, system_app, rerank_model_name, rerank_model_path) _initialize_model_cache(system_app) _initialize_awel(system_app, param) # Initialize resource manager of agent @@ -89,6 +96,11 @@ def _initialize_agent(system_app: SystemApp): def _initialize_resource_manager(system_app: SystemApp): from dbgpt.agent.expand.resources.dbgpt_tool import list_dbgpt_support_models + from dbgpt.agent.expand.resources.host_tool import ( + get_current_host_cpu_status, + get_current_host_memory_status, + get_current_host_system_load, + ) from dbgpt.agent.expand.resources.search_tool import baidu_search from dbgpt.agent.resource.base import ResourceType from dbgpt.agent.resource.manage import get_resource_manager, initialize_resource @@ -104,6 +116,10 @@ def _initialize_resource_manager(system_app: SystemApp): # Register a search tool rm.register_resource(resource_instance=baidu_search) rm.register_resource(resource_instance=list_dbgpt_support_models) + # Register host tools + rm.register_resource(resource_instance=get_current_host_cpu_status) + rm.register_resource(resource_instance=get_current_host_memory_status) + rm.register_resource(resource_instance=get_current_host_system_load) def _initialize_openapi(system_app: SystemApp): diff --git a/dbgpt/app/dbgpt_server.py b/dbgpt/app/dbgpt_server.py index 24c852525..9e7481b48 100644 --- a/dbgpt/app/dbgpt_server.py +++ b/dbgpt/app/dbgpt_server.py @@ -146,11 +146,24 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None): embedding_model_name = CFG.EMBEDDING_MODEL embedding_model_path = EMBEDDING_MODEL_CONFIG[CFG.EMBEDDING_MODEL] + rerank_model_name = CFG.RERANK_MODEL + rerank_model_path = None + if rerank_model_name: + rerank_model_path = CFG.RERANK_MODEL_PATH or EMBEDDING_MODEL_CONFIG.get( + rerank_model_name + ) server_init(param, system_app) mount_routers(app) model_start_listener = _create_model_start_listener(system_app) - initialize_components(param, system_app, embedding_model_name, embedding_model_path) + 
initialize_components(
+        param,
+        system_app,
+        embedding_model_name,
+        embedding_model_path,
+        rerank_model_name,
+        rerank_model_path,
+    )
     system_app.on_init()
     # Migration db storage, so your db models must be imported before this
@@ -161,7 +174,13 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
     if not param.light:
         print("Model Unified Deployment Mode!")
         if not param.remote_embedding:
+            # Embedding model is running in the same process; set
+            # embedding_model_name and embedding_model_path to None
             embedding_model_name, embedding_model_path = None, None
+        if not param.remote_rerank:
+            # Rerank model is running in the same process; set rerank_model_name
+            # and rerank_model_path to None
+            rerank_model_name, rerank_model_path = None, None
         initialize_worker_manager_in_client(
             app=app,
             model_name=model_name,
@@ -169,6 +188,8 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
             local_port=param.port,
             embedding_model_name=embedding_model_name,
             embedding_model_path=embedding_model_path,
+            rerank_model_name=rerank_model_name,
+            rerank_model_path=rerank_model_path,
             start_listener=model_start_listener,
             system_app=system_app,
         )
diff --git a/dbgpt/app/initialization/embedding_component.py b/dbgpt/app/initialization/embedding_component.py
index 71070d466..3330fb1f4 100644
--- a/dbgpt/app/initialization/embedding_component.py
+++ b/dbgpt/app/initialization/embedding_component.py
@@ -1,10 +1,14 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Any, Type
+from typing import TYPE_CHECKING, Any, Optional, Type
 
 from dbgpt.component import ComponentType, SystemApp
-from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory
+from dbgpt.core import Embeddings, RerankEmbeddings
+from dbgpt.rag.embedding.embedding_factory import (
+    EmbeddingFactory,
+    RerankEmbeddingFactory,
+)
 
 if TYPE_CHECKING:
     from langchain.embeddings.base import Embeddings
@@ -32,6 +36,26 @@ def _initialize_embedding_model(
     )
 
 
+def _initialize_rerank_model(
+    param: "WebServerParameters",
+    system_app: SystemApp,
+    rerank_model_name: Optional[str] = None,
+    rerank_model_path: Optional[str] = None,
+):
+    if not rerank_model_name:
+        return
+    if param.remote_rerank:
+        logger.info("Register RemoteRerankEmbeddingFactory")
+        system_app.register(RemoteRerankEmbeddingFactory, model_name=rerank_model_name)
+    else:
+        logger.info("Register LocalRerankEmbeddingFactory")
+        system_app.register(
+            LocalRerankEmbeddingFactory,
+            default_model_name=rerank_model_name,
+            default_model_path=rerank_model_path,
+        )
+
+
 class RemoteEmbeddingFactory(EmbeddingFactory):
     def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
         super().__init__(system_app=system_app)
         self._default_model_name = model_name
@@ -105,3 +129,81 @@ class LocalEmbeddingFactory(EmbeddingFactory):
         loader = EmbeddingLoader()
         # Ignore model_name args
         return loader.load(self._default_model_name, model_params)
+
+
+class RemoteRerankEmbeddingFactory(RerankEmbeddingFactory):
+    def __init__(self, system_app, model_name: str = None, **kwargs: Any) -> None:
+        super().__init__(system_app=system_app)
+        self._default_model_name = model_name
+        self.kwargs = kwargs
+        self.system_app = system_app
+
+    def init_app(self, system_app):
+        self.system_app = system_app
+
+    def create(
+        self, model_name: str = None, embedding_cls: Type = None
+    ) -> "RerankEmbeddings":
+        from dbgpt.model.cluster import WorkerManagerFactory
+        from dbgpt.model.cluster.embedding.remote_embedding import (
+            RemoteRerankEmbeddings,
+        )
+ + if embedding_cls: + raise NotImplementedError + worker_manager = self.system_app.get_component( + ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory + ).create() + return RemoteRerankEmbeddings( + model_name or self._default_model_name, worker_manager + ) + + +class LocalRerankEmbeddingFactory(RerankEmbeddingFactory): + def __init__( + self, + system_app, + default_model_name: str = None, + default_model_path: str = None, + **kwargs: Any, + ) -> None: + super().__init__(system_app=system_app) + self._default_model_name = default_model_name + self._default_model_path = default_model_path + self._kwargs = kwargs + self._model = self._load_model() + + def init_app(self, system_app): + pass + + def create( + self, model_name: str = None, embedding_cls: Type = None + ) -> "RerankEmbeddings": + if embedding_cls: + raise NotImplementedError + return self._model + + def _load_model(self) -> "RerankEmbeddings": + from dbgpt.model.adapter.embeddings_loader import ( + EmbeddingLoader, + _parse_embedding_params, + ) + from dbgpt.model.parameter import ( + EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG, + BaseEmbeddingModelParameters, + EmbeddingModelParameters, + ) + + param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get( + self._default_model_name, EmbeddingModelParameters + ) + model_params: BaseEmbeddingModelParameters = _parse_embedding_params( + model_name=self._default_model_name, + model_path=self._default_model_path, + param_cls=param_cls, + **self._kwargs, + ) + logger.info(model_params) + loader = EmbeddingLoader() + # Ignore model_name args + return loader.load_rerank_model(self._default_model_name, model_params) diff --git a/dbgpt/app/scene/chat_knowledge/v1/chat.py b/dbgpt/app/scene/chat_knowledge/v1/chat.py index bc84f1937..297d9dd04 100644 --- a/dbgpt/app/scene/chat_knowledge/v1/chat.py +++ b/dbgpt/app/scene/chat_knowledge/v1/chat.py @@ -18,6 +18,7 @@ from dbgpt.core import ( MessagesPlaceholder, SystemPromptTemplate, ) +from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker from dbgpt.rag.retriever.rewrite import QueryRewrite from dbgpt.util.tracer import root_tracer, trace @@ -37,7 +38,10 @@ class ChatKnowledge(BaseChat): - model_name:(str) llm model name - select_param:(str) space name """ - from dbgpt.rag.embedding.embedding_factory import EmbeddingFactory + from dbgpt.rag.embedding.embedding_factory import ( + EmbeddingFactory, + RerankEmbeddingFactory, + ) self.knowledge_space = chat_param["select_param"] chat_param["chat_mode"] = ChatScene.ChatKnowledge @@ -83,10 +87,22 @@ class ChatKnowledge(BaseChat): model_name=self.llm_model, language=CFG.LANGUAGE, ) + reranker = None + retriever_top_k = self.top_k + if CFG.RERANK_MODEL: + rerank_embeddings = RerankEmbeddingFactory.get_instance( + CFG.SYSTEM_APP + ).create() + reranker = RerankEmbeddingsRanker(rerank_embeddings, topk=CFG.RERANK_TOP_K) + if retriever_top_k < CFG.RERANK_TOP_K or retriever_top_k < 20: + # We use reranker, so if the top_k is less than 20, + # we need to set it to 20 + retriever_top_k = max(CFG.RERANK_TOP_K, 20) self.embedding_retriever = EmbeddingRetriever( - top_k=self.top_k, + top_k=retriever_top_k, vector_store_connector=vector_store_connector, query_rewrite=query_rewrite, + rerank=reranker, ) self.prompt_template.template_is_strict = False self.relations = None diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py index 51dc0c2d2..f45f6503b 100644 --- a/dbgpt/configs/model_config.py +++ b/dbgpt/configs/model_config.py @@ -209,6 +209,11 @@ EMBEDDING_MODEL_CONFIG = { # 
Common HTTP embedding model
     "proxy_http_openapi": "proxy_http_openapi",
     "proxy_ollama": "proxy_ollama",
+    # Rerank models (a rerank model is a special kind of embedding model)
+    "bge-reranker-base": os.path.join(MODEL_PATH, "bge-reranker-base"),
+    "bge-reranker-large": os.path.join(MODEL_PATH, "bge-reranker-large"),
+    # Proxy rerank model
+    "rerank_proxy_http_openapi": "rerank_proxy_http_openapi",
 }
diff --git a/dbgpt/core/__init__.py b/dbgpt/core/__init__.py
index 718a23158..1aac0eab4 100644
--- a/dbgpt/core/__init__.py
+++ b/dbgpt/core/__init__.py
@@ -7,7 +7,7 @@ from dbgpt.core.interface.cache import (  # noqa: F401
     CachePolicy,
     CacheValue,
 )
-from dbgpt.core.interface.embeddings import Embeddings  # noqa: F401
+from dbgpt.core.interface.embeddings import Embeddings, RerankEmbeddings  # noqa: F401
 from dbgpt.core.interface.knowledge import Chunk, Document  # noqa: F401
 from dbgpt.core.interface.llm import (  # noqa: F401
     DefaultMessageConverter,
@@ -106,6 +106,7 @@ __ALL__ = [
     "QuerySpec",
     "StorageError",
     "Embeddings",
+    "RerankEmbeddings",
     "Chunk",
     "Document",
 ]
diff --git a/dbgpt/core/interface/embeddings.py b/dbgpt/core/interface/embeddings.py
index e3756fd1c..5f5848b61 100644
--- a/dbgpt/core/interface/embeddings.py
+++ b/dbgpt/core/interface/embeddings.py
@@ -1,9 +1,24 @@
 """Interface for embedding models."""
+
 import asyncio
 from abc import ABC, abstractmethod
 from typing import List
 
 
+class RerankEmbeddings(ABC):
+    """Interface for rerank models."""
+
+    @abstractmethod
+    def predict(self, query: str, candidates: List[str]) -> List[float]:
+        """Predict the scores of the candidates."""
+
+    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
+        """Asynchronously predict the scores of the candidates."""
+        return await asyncio.get_running_loop().run_in_executor(
+            None, self.predict, query, candidates
+        )
+
+
 class Embeddings(ABC):
     """Interface for embedding models.
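The contract added here is deliberately small: implementors only provide the synchronous `predict`, and the inherited `apredict` offloads it to the default executor. A minimal sketch of a custom implementation against this interface (the `KeywordOverlapRerank` class and its scoring rule are illustrative placeholders, not part of this patch):

```python
import asyncio
from typing import List

from dbgpt.core import RerankEmbeddings


class KeywordOverlapRerank(RerankEmbeddings):
    """Toy reranker: score each candidate by term overlap with the query."""

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        query_terms = set(query.lower().split())
        return [
            len(query_terms & set(candidate.lower().split()))
            / max(len(query_terms), 1)
            for candidate in candidates
        ]


async def main():
    reranker = KeywordOverlapRerank()
    # apredict is inherited from RerankEmbeddings: it runs the synchronous
    # predict in the default thread executor of the running event loop.
    scores = await reranker.apredict(
        "what is awel", ["AWEL is a workflow expression language", "Unrelated text"]
    )
    print(scores)


asyncio.run(main())
```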
diff --git a/dbgpt/core/schema/api.py b/dbgpt/core/schema/api.py index ca8b4c4c4..f5829f589 100644 --- a/dbgpt/core/schema/api.py +++ b/dbgpt/core/schema/api.py @@ -166,6 +166,23 @@ class EmbeddingsResponse(BaseModel): usage: UsageInfo = Field(..., description="Usage info") +class RelevanceRequest(BaseModel): + """Relevance request entity.""" + + model: str = Field(..., description="Rerank model name") + query: str = Field(..., description="Query text") + documents: List[str] = Field(..., description="Document texts") + + +class RelevanceResponse(BaseModel): + """Relevance response entity.""" + + object: str = Field("list", description="Object type") + model: str = Field(..., description="Rerank model name") + data: List[float] = Field(..., description="Data list, relevance scores") + usage: UsageInfo = Field(..., description="Usage info") + + class ModelPermission(BaseModel): """Model permission entity.""" diff --git a/dbgpt/model/adapter/embeddings_loader.py b/dbgpt/model/adapter/embeddings_loader.py index 6b97c9c3d..1731d21e5 100644 --- a/dbgpt/model/adapter/embeddings_loader.py +++ b/dbgpt/model/adapter/embeddings_loader.py @@ -4,7 +4,7 @@ import logging from typing import List, Optional, Type, cast from dbgpt.configs.model_config import get_device -from dbgpt.core import Embeddings +from dbgpt.core import Embeddings, RerankEmbeddings from dbgpt.model.parameter import ( BaseEmbeddingModelParameters, EmbeddingModelParameters, @@ -66,6 +66,38 @@ class EmbeddingLoader: kwargs = param.build_kwargs(model_name=param.model_path) return HuggingFaceEmbeddings(**kwargs) + def load_rerank_model( + self, model_name: str, param: BaseEmbeddingModelParameters + ) -> RerankEmbeddings: + metadata = { + "model_name": model_name, + "run_service": SpanTypeRunName.EMBEDDING_MODEL.value, + "params": _get_dict_from_obj(param), + "sys_infos": _get_dict_from_obj(get_system_info()), + } + with root_tracer.start_span( + "EmbeddingLoader.load_rerank_model", + span_type=SpanType.RUN, + metadata=metadata, + ): + if model_name in ["rerank_proxy_http_openapi"]: + from dbgpt.rag.embedding.rerank import OpenAPIRerankEmbeddings + + proxy_param = cast(ProxyEmbeddingParameters, param) + openapi_param = {} + if proxy_param.proxy_server_url: + openapi_param["api_url"] = proxy_param.proxy_server_url + if proxy_param.proxy_api_key: + openapi_param["api_key"] = proxy_param.proxy_api_key + if proxy_param.proxy_backend: + openapi_param["model_name"] = proxy_param.proxy_backend + return OpenAPIRerankEmbeddings(**openapi_param) + else: + from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings + + kwargs = param.build_kwargs(model_name=param.model_path) + return CrossEncoderRerankEmbeddings(**kwargs) + def _parse_embedding_params( model_name: Optional[str] = None, diff --git a/dbgpt/model/adapter/model_adapter.py b/dbgpt/model/adapter/model_adapter.py index 3bdebe946..8b1231224 100644 --- a/dbgpt/model/adapter/model_adapter.py +++ b/dbgpt/model/adapter/model_adapter.py @@ -147,15 +147,14 @@ def _dynamic_model_parser() -> Optional[List[Type[BaseModelParameters]]]: model_path = pre_args.get("model_path") worker_type = pre_args.get("worker_type") model_type = pre_args.get("model_type") - if model_name is None and model_type != ModelType.VLLM: - return None if worker_type == WorkerType.TEXT2VEC: return [ EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get( model_name, EmbeddingModelParameters ) ] - + if model_name is None and model_type != ModelType.VLLM: + return None llm_adapter = get_llm_model_adapter(model_name, model_path, 
model_type=model_type)
     param_class = llm_adapter.model_param_class()
     return [param_class]
diff --git a/dbgpt/model/cluster/apiserver/api.py b/dbgpt/model/cluster/apiserver/api.py
index ab58d75fb..2f7303059 100644
--- a/dbgpt/model/cluster/apiserver/api.py
+++ b/dbgpt/model/cluster/apiserver/api.py
@@ -3,6 +3,7 @@
 Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
 """
+
 import asyncio
 import json
 import logging
@@ -34,6 +35,8 @@ from dbgpt.core.schema.api import (
     ModelCard,
     ModelList,
     ModelPermission,
+    RelevanceRequest,
+    RelevanceResponse,
     UsageInfo,
 )
 from dbgpt.model.base import ModelInstance
@@ -368,6 +371,28 @@ class APIServer(BaseComponent):
         }
         return await worker_manager.embeddings(params)
 
+    async def relevance_generate(
+        self, model: str, query: str, texts: List[str]
+    ) -> List[float]:
+        """Generate relevance scores for a query against a list of texts.
+
+        Args:
+            model (str): Model name
+            query (str): Query text
+            texts (List[str]): Candidate texts to score
+
+        Returns:
+            List[float]: The relevance scores of the texts
+        """
+        worker_manager: WorkerManager = self.get_worker_manager()
+        params = {
+            "input": texts,
+            "model": model,
+            "query": query,
+        }
+        scores = await worker_manager.embeddings(params)
+        return scores[0]
+
 
 def get_api_server() -> APIServer:
     api_server = global_system_app.get_component(
@@ -456,6 +481,26 @@ async def create_embeddings(
     )
 
 
+@router.post(
+    "/v1/beta/relevance",
+    dependencies=[Depends(check_api_key)],
+    response_model=RelevanceResponse,
+)
+async def create_relevance(
+    request: RelevanceRequest, api_server: APIServer = Depends(get_api_server)
+):
+    """Generate relevance scores for a query and a list of documents."""
+    await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
+
+    scores = await api_server.relevance_generate(
+        request.model, request.query, request.documents
+    )
+    return model_to_dict(
+        RelevanceResponse(data=scores, model=request.model, usage=UsageInfo()),
+        exclude_none=True,
+    )
+
+
 def _initialize_all(controller_addr: str, system_app: SystemApp):
     from dbgpt.model.cluster.controller.controller import ModelRegistryClient
     from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
diff --git a/dbgpt/model/cluster/base.py b/dbgpt/model/cluster/base.py
index 8d1832b56..4106d6c97 100644
--- a/dbgpt/model/cluster/base.py
+++ b/dbgpt/model/cluster/base.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 from dbgpt._private.pydantic import BaseModel
 from dbgpt.core.interface.message import ModelMessage
@@ -31,7 +31,9 @@ class PromptRequest(BaseModel):
 class EmbeddingsRequest(BaseModel):
     model: str
     input: List[str]
-    span_id: str = None
+    span_id: Optional[str] = None
+    query: Optional[str] = None
+    """For rerank models, query is required"""
 
 
 class CountTokenRequest(BaseModel):
diff --git a/dbgpt/model/cluster/embedding/remote_embedding.py b/dbgpt/model/cluster/embedding/remote_embedding.py
index 56ede440a..0743dfffd 100644
--- a/dbgpt/model/cluster/embedding/remote_embedding.py
+++ b/dbgpt/model/cluster/embedding/remote_embedding.py
@@ -1,6 +1,6 @@
 from typing import List
 
-from dbgpt.core import Embeddings
+from dbgpt.core import Embeddings, RerankEmbeddings
 from dbgpt.model.cluster.manager_base import WorkerManager
 
@@ -26,3 +26,30 @@ class RemoteEmbeddings(Embeddings):
     async def aembed_query(self, text: str) -> List[float]:
         """Asynchronous Embed query text."""
         return await self.aembed_documents([text])[0]
+
+
+class RemoteRerankEmbeddings(RerankEmbeddings):
+    def __init__(self, model_name: str, worker_manager: WorkerManager) -> None:
+        self.model_name = model_name
+        self.worker_manager = worker_manager
+
+    def predict(self, query: str, candidates: List[str]) -> List[float]:
+        """Predict the scores of the candidates."""
+        params = {
+            "model": self.model_name,
+            "input": candidates,
+            "query": query,
+        }
+        return self.worker_manager.sync_embeddings(params)[0]
+
+    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
+        """Asynchronously predict the scores of the candidates."""
+        params = {
+            "model": self.model_name,
+            "input": candidates,
+            "query": query,
+        }
+        # Reuse the embeddings interface to fetch the ranker's scores
+        scores = await self.worker_manager.embeddings(params)
+        # The response wraps a single row of scores for the query; return it
+        return scores[0]
diff --git a/dbgpt/model/cluster/worker/embedding_worker.py b/dbgpt/model/cluster/worker/embedding_worker.py
index e3ac95f02..a00e32c6f 100644
--- a/dbgpt/model/cluster/worker/embedding_worker.py
+++ b/dbgpt/model/cluster/worker/embedding_worker.py
@@ -1,7 +1,7 @@
 import logging
-from typing import Dict, List, Type
+from typing import Dict, List, Type, Union
 
-from dbgpt.core import ModelMetadata
+from dbgpt.core import Embeddings, ModelMetadata, RerankEmbeddings
 from dbgpt.model.adapter.embeddings_loader import (
     EmbeddingLoader,
     _parse_embedding_params,
@@ -20,13 +20,12 @@ logger = logging.getLogger(__name__)
 
 
 class EmbeddingsModelWorker(ModelWorker):
-    def __init__(self) -> None:
-        from dbgpt.rag.embedding import Embeddings
-
-        self._embeddings_impl: Embeddings = None
+    def __init__(self, rerank_model: bool = False) -> None:
+        self._embeddings_impl: Union[Embeddings, RerankEmbeddings, None] = None
         self._model_params = None
         self.model_name = None
         self.model_path = None
+        self._rerank_model = rerank_model
         self._loader = EmbeddingLoader()
 
     def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
@@ -64,8 +63,17 @@ class EmbeddingsModelWorker(ModelWorker):
         """Start model worker"""
         if not model_params:
             model_params = self.parse_parameters(command_args)
+        if self._rerank_model:
+            model_params.rerank = True  # type: ignore
         self._model_params = model_params
-        self._embeddings_impl = self._loader.load(self.model_name, model_params)
+        if model_params.is_rerank_model():
+            logger.info(f"Load rerank embeddings model: {self.model_name}")
+            self._embeddings_impl = self._loader.load_rerank_model(
+                self.model_name, model_params
+            )
+        else:
+            logger.info(f"Load embeddings model: {self.model_name}")
+            self._embeddings_impl = self._loader.load(self.model_name, model_params)
 
     def __del__(self):
         self.stop()
@@ -96,5 +104,10 @@ class EmbeddingsModelWorker(ModelWorker):
     def embeddings(self, params: Dict) -> List[List[float]]:
         model = params.get("model")
         logger.info(f"Receive embeddings request, model: {model}")
-        input: List[str] = params["input"]
-        return self._embeddings_impl.embed_documents(input)
+        texts: List[str] = params["input"]
+        if isinstance(self._embeddings_impl, RerankEmbeddings):
+            query = params["query"]
+            scores: List[float] = self._embeddings_impl.predict(query, texts)
+            return [scores]
+        else:
+            return self._embeddings_impl.embed_documents(texts)
diff --git a/dbgpt/model/cluster/worker/manager.py b/dbgpt/model/cluster/worker/manager.py
index b57698af9..9bc2a80c7 100644
--- a/dbgpt/model/cluster/worker/manager.py
+++ b/dbgpt/model/cluster/worker/manager.py
@@ -952,7 +952,10 @@ def _create_local_model_manager(
     )
 
 
-def _build_worker(worker_params: ModelWorkerParameters):
+def
_build_worker( + worker_params: ModelWorkerParameters, + ext_worker_kwargs: Optional[Dict[str, Any]] = None, +): worker_class = worker_params.worker_class if worker_class: from dbgpt.util.module_utils import import_from_checked_string @@ -976,11 +979,16 @@ def _build_worker(worker_params: ModelWorkerParameters): else: raise Exception("Unsupported worker type: {worker_params.worker_type}") - return worker_cls() + if ext_worker_kwargs: + return worker_cls(**ext_worker_kwargs) + else: + return worker_cls() def _start_local_worker( - worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters + worker_manager: WorkerManagerAdapter, + worker_params: ModelWorkerParameters, + ext_worker_kwargs: Optional[Dict[str, Any]] = None, ): with root_tracer.start_span( "WorkerManager._start_local_worker", @@ -991,7 +999,7 @@ def _start_local_worker( "sys_infos": _get_dict_from_obj(get_system_info()), }, ): - worker = _build_worker(worker_params) + worker = _build_worker(worker_params, ext_worker_kwargs=ext_worker_kwargs) if not worker_manager.worker_manager: worker_manager.worker_manager = _create_local_model_manager(worker_params) worker_manager.worker_manager.add_worker(worker, worker_params) @@ -1001,6 +1009,7 @@ def _start_local_embedding_worker( worker_manager: WorkerManagerAdapter, embedding_model_name: str = None, embedding_model_path: str = None, + ext_worker_kwargs: Optional[Dict[str, Any]] = None, ): if not embedding_model_name or not embedding_model_path: return @@ -1013,21 +1022,25 @@ def _start_local_embedding_worker( logger.info( f"Start local embedding worker with embedding parameters\n{embedding_worker_params}" ) - _start_local_worker(worker_manager, embedding_worker_params) + _start_local_worker( + worker_manager, embedding_worker_params, ext_worker_kwargs=ext_worker_kwargs + ) def initialize_worker_manager_in_client( app=None, include_router: bool = True, - model_name: str = None, - model_path: str = None, + model_name: Optional[str] = None, + model_path: Optional[str] = None, run_locally: bool = True, - controller_addr: str = None, + controller_addr: Optional[str] = None, local_port: int = 5670, - embedding_model_name: str = None, - embedding_model_path: str = None, - start_listener: Callable[["WorkerManager"], None] = None, - system_app: SystemApp = None, + embedding_model_name: Optional[str] = None, + embedding_model_path: Optional[str] = None, + rerank_model_name: Optional[str] = None, + rerank_model_path: Optional[str] = None, + start_listener: Optional[Callable[["WorkerManager"], None]] = None, + system_app: Optional[SystemApp] = None, ): """Initialize WorkerManager in client. 
If run_locally is True: @@ -1063,6 +1076,12 @@ def initialize_worker_manager_in_client( _start_local_embedding_worker( worker_manager, embedding_model_name, embedding_model_path ) + _start_local_embedding_worker( + worker_manager, + rerank_model_name, + rerank_model_path, + ext_worker_kwargs={"rerank_model": True}, + ) else: from dbgpt.model.cluster.controller.controller import ( ModelRegistryClient, @@ -1072,7 +1091,6 @@ def initialize_worker_manager_in_client( if not worker_params.controller_addr: raise ValueError("Controller can`t be None") - controller_addr = worker_params.controller_addr logger.info(f"Worker params: {worker_params}") client = ModelRegistryClient(worker_params.controller_addr) worker_manager.worker_manager = RemoteWorkerManager(client) diff --git a/dbgpt/model/parameter.py b/dbgpt/model/parameter.py index 470d46b9a..f3396d564 100644 --- a/dbgpt/model/parameter.py +++ b/dbgpt/model/parameter.py @@ -255,6 +255,10 @@ class BaseEmbeddingModelParameters(BaseModelParameters): def build_kwargs(self, **kwargs) -> Dict: pass + def is_rerank_model(self) -> bool: + """Check if the model is a rerank model""" + return False + @dataclass class EmbeddingModelParameters(BaseEmbeddingModelParameters): @@ -272,6 +276,19 @@ class EmbeddingModelParameters(BaseEmbeddingModelParameters): }, ) + rerank: Optional[bool] = field( + default=False, metadata={"help": "Whether the model is a rerank model"} + ) + + max_length: Optional[int] = field( + default=None, + metadata={ + "help": "Max length for input sequences. Longer sequences will be " + "truncated. If None, max length of the model will be used, just for rerank" + " model now." + }, + ) + def build_kwargs(self, **kwargs) -> Dict: model_kwargs, encode_kwargs = None, None if self.device: @@ -280,10 +297,16 @@ class EmbeddingModelParameters(BaseEmbeddingModelParameters): encode_kwargs = {"normalize_embeddings": self.normalize_embeddings} if model_kwargs: kwargs["model_kwargs"] = model_kwargs + if self.is_rerank_model(): + kwargs["max_length"] = self.max_length if encode_kwargs: kwargs["encode_kwargs"] = encode_kwargs return kwargs + def is_rerank_model(self) -> bool: + """Check if the model is a rerank model""" + return self.rerank + @dataclass class ModelParameters(BaseModelParameters): @@ -537,26 +560,35 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters): metadata={"help": "Tto support Azure OpenAI Service custom deployment names"}, ) + rerank: Optional[bool] = field( + default=False, metadata={"help": "Whether the model is a rerank model"} + ) + def build_kwargs(self, **kwargs) -> Dict: params = { "openai_api_base": self.proxy_server_url, "openai_api_key": self.proxy_api_key, "openai_api_type": self.proxy_api_type if self.proxy_api_type else None, - "openai_api_version": self.proxy_api_version - if self.proxy_api_version - else None, + "openai_api_version": ( + self.proxy_api_version if self.proxy_api_version else None + ), "model": self.proxy_backend, - "deployment": self.proxy_deployment - if self.proxy_deployment - else self.proxy_backend, + "deployment": ( + self.proxy_deployment if self.proxy_deployment else self.proxy_backend + ), } for k, v in kwargs: params[k] = v return params + def is_rerank_model(self) -> bool: + """Check if the model is a rerank model""" + return self.rerank + _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = { - ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama", + ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi," + 
"proxy_ollama,rerank_proxy_http_openapi", } EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {} diff --git a/dbgpt/rag/embedding/__init__.py b/dbgpt/rag/embedding/__init__.py index c60f01d6e..ce38d77e1 100644 --- a/dbgpt/rag/embedding/__init__.py +++ b/dbgpt/rag/embedding/__init__.py @@ -15,6 +15,7 @@ from .embeddings import ( # noqa: F401 OllamaEmbeddings, OpenAPIEmbeddings, ) +from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401 __ALL__ = [ "Embeddings", @@ -28,4 +29,6 @@ __ALL__ = [ "DefaultEmbeddingFactory", "EmbeddingFactory", "WrappedEmbeddingFactory", + "CrossEncoderRerankEmbeddings", + "OpenAPIRerankEmbeddings", ] diff --git a/dbgpt/rag/embedding/embedding_factory.py b/dbgpt/rag/embedding/embedding_factory.py index 659ebe868..deca1d93e 100644 --- a/dbgpt/rag/embedding/embedding_factory.py +++ b/dbgpt/rag/embedding/embedding_factory.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from typing import Any, List, Optional, Type from dbgpt.component import BaseComponent, SystemApp -from dbgpt.core import Embeddings +from dbgpt.core import Embeddings, RerankEmbeddings from dbgpt.core.awel import DAGVar from dbgpt.core.awel.flow import ResourceCategory, register_resource from dbgpt.util.i18n_utils import _ @@ -34,6 +34,26 @@ class EmbeddingFactory(BaseComponent, ABC): """ +class RerankEmbeddingFactory(BaseComponent, ABC): + """Class for RerankEmbeddingFactory.""" + + name = "rerank_embedding_factory" + + @abstractmethod + def create( + self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None + ) -> RerankEmbeddings: + """Create an embedding instance. + + Args: + model_name (str): The model name. + embedding_cls (Type): The embedding class. + + Returns: + RerankEmbeddings: The embedding instance. + """ + + class DefaultEmbeddingFactory(EmbeddingFactory): """The default embedding factory.""" diff --git a/dbgpt/rag/embedding/embeddings.py b/dbgpt/rag/embedding/embeddings.py index b51456313..2457cd8f7 100644 --- a/dbgpt/rag/embedding/embeddings.py +++ b/dbgpt/rag/embedding/embeddings.py @@ -1,6 +1,5 @@ """Embedding implementations.""" - from typing import Any, Dict, List, Optional import aiohttp diff --git a/dbgpt/rag/embedding/rerank.py b/dbgpt/rag/embedding/rerank.py new file mode 100644 index 000000000..5d901f7b9 --- /dev/null +++ b/dbgpt/rag/embedding/rerank.py @@ -0,0 +1,131 @@ +"""Re-rank embeddings.""" + +from typing import Any, Dict, List, Optional, cast + +import aiohttp +import requests + +from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field +from dbgpt.core import RerankEmbeddings + + +class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings): + """CrossEncoder Rerank Embeddings.""" + + model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=()) + + client: Any #: :meta private: + model_name: str = "BAAI/bge-reranker-base" + max_length: Optional[int] = None + """Max length for input sequences. Longer sequences will be truncated. 
If None, max
+    length of the model will be used."""
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Keyword arguments to pass to the model."""
+
+    def __init__(self, **kwargs: Any):
+        """Initialize the sentence_transformer."""
+        try:
+            from sentence_transformers import CrossEncoder
+        except ImportError:
+            raise ImportError(
+                "please `pip install sentence-transformers`",
+            )
+
+        kwargs["client"] = CrossEncoder(
+            kwargs.get("model_name", "BAAI/bge-reranker-base"),
+            max_length=kwargs.get("max_length"),
+            **(kwargs.get("model_kwargs") or {}),
+        )
+        super().__init__(**kwargs)
+
+    def predict(self, query: str, candidates: List[str]) -> List[float]:
+        """Predict the rank scores of the candidates.
+
+        Args:
+            query: The query text.
+            candidates: The list of candidate texts.
+
+        Returns:
+            List[float]: The rank scores of the candidates.
+        """
+        from sentence_transformers import CrossEncoder
+
+        query_content_pairs = [[query, candidate] for candidate in candidates]
+        _model = cast(CrossEncoder, self.client)
+        rank_scores = _model.predict(sentences=query_content_pairs)
+        return rank_scores.tolist()
+
+
+class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
+    """OpenAPI Rerank Embeddings."""
+
+    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
+
+    api_url: str = Field(
+        default="http://localhost:8100/v1/beta/relevance",
+        description="The URL of the rerank API.",
+    )
+    api_key: Optional[str] = Field(
+        default=None, description="The API key for the rerank API."
+    )
+    model_name: str = Field(
+        default="bge-reranker-base", description="The name of the model to use."
+    )
+    timeout: int = Field(
+        default=60, description="The timeout for the request in seconds."
+    )
+
+    session: Optional[requests.Session] = None
+
+    def __init__(self, **kwargs):
+        """Initialize the OpenAPIRerankEmbeddings."""
+        try:
+            import requests
+        except ImportError:
+            raise ValueError(
+                "The requests python package is not installed. "
+                "Please install it with `pip install requests`"
+            )
+        if "session" not in kwargs:  # noqa: SIM401
+            session = requests.Session()
+        else:
+            session = kwargs["session"]
+        api_key = kwargs.get("api_key")
+        if api_key:
+            session.headers.update({"Authorization": f"Bearer {api_key}"})
+        kwargs["session"] = session
+        super().__init__(**kwargs)
+
+    def predict(self, query: str, candidates: List[str]) -> List[float]:
+        """Predict the rank scores of the candidates.
+
+        Args:
+            query: The query text.
+            candidates: The list of candidate texts.
+
+        Returns:
+            List[float]: The rank scores of the candidates.
+        """
+        if not candidates:
+            return []
+        data = {"model": self.model_name, "query": query, "documents": candidates}
+        response = self.session.post(  # type: ignore
+            self.api_url, json=data, timeout=self.timeout
+        )
+        response.raise_for_status()
+        return response.json()["data"]
+
+    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
+        """Predict the rank scores of the candidates asynchronously."""
+        headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
+        async with aiohttp.ClientSession(
+            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
+        ) as session:
+            data = {"model": self.model_name, "query": query, "documents": candidates}
+            async with session.post(self.api_url, json=data) as resp:
+                resp.raise_for_status()
+                response_data = await resp.json()
+                if "data" not in response_data:
+                    raise RuntimeError(response_data["detail"])
+                return response_data["data"]
diff --git a/dbgpt/rag/retriever/embedding.py b/dbgpt/rag/retriever/embedding.py
index 0ce896761..ae53d6f68 100644
--- a/dbgpt/rag/retriever/embedding.py
+++ b/dbgpt/rag/retriever/embedding.py
@@ -1,4 +1,5 @@
 """Embedding retriever."""
+
 from functools import reduce
 from typing import Any, Dict, List, Optional, cast
 
@@ -207,7 +208,7 @@ class EmbeddingRetriever(BaseRetriever):
                 "rerank_cls": self._rerank.__class__.__name__,
             },
         ):
-            new_candidates_with_score = self._rerank.rank(
+            new_candidates_with_score = await self._rerank.arank(
                 new_candidates_with_score, query
             )
             return new_candidates_with_score
diff --git a/dbgpt/rag/retriever/rerank.py b/dbgpt/rag/retriever/rerank.py
index fde7cf550..2eca8b622 100644
--- a/dbgpt/rag/retriever/rerank.py
+++ b/dbgpt/rag/retriever/rerank.py
@@ -1,9 +1,10 @@
 """Rerank module for RAG retriever."""
+import asyncio
 from abc import ABC, abstractmethod
 from typing import Callable, List, Optional
 
-from dbgpt.core import Chunk
+from dbgpt.core import Chunk, RerankEmbeddings
 from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
 from dbgpt.util.i18n_utils import _
 
@@ -39,6 +40,24 @@ class Ranker(ABC):
             List[Chunk]
         """
 
+    async def arank(
+        self, candidates_with_scores: List[Chunk], query: Optional[str] = None
+    ) -> List[Chunk]:
+        """Return top k chunks after ranker.
+ + Rank algorithm implementation return topk documents by candidates + similarity score + + Args: + candidates_with_scores: List[Tuple] + query: Optional[str] + Return: + List[Chunk] + """ + return await asyncio.get_running_loop().run_in_executor( + None, self.rank, candidates_with_scores, query + ) + def _filter(self, candidates_with_scores: List) -> List[Chunk]: """Filter duplicate candidates documents.""" candidates_with_scores = sorted( @@ -52,6 +71,18 @@ class Ranker(ABC): visited_docs.add(candidate_chunk.content) return new_candidates + def _rerank_with_scores( + self, candidates_with_scores: List[Chunk], rank_scores: List[float] + ) -> List[Chunk]: + """Rerank candidates with scores.""" + for candidate, score in zip(candidates_with_scores, rank_scores): + candidate.score = float(score) + + new_candidates_with_scores = sorted( + candidates_with_scores, key=lambda x: x.score, reverse=True + ) + return new_candidates_with_scores + @register_resource( _("Default Ranker"), @@ -225,3 +256,59 @@ class CrossEncoderRanker(Ranker): candidates_with_scores, key=lambda x: x.score, reverse=True ) return new_candidates_with_scores[: self.topk] + + +class RerankEmbeddingsRanker(Ranker): + """Rerank Embeddings Ranker.""" + + def __init__( + self, + rerank_embeddings: RerankEmbeddings, + topk: int = 4, + rank_fn: Optional[RANK_FUNC] = None, + ): + """Rerank Embeddings rank algorithm implementation.""" + self._model = rerank_embeddings + super().__init__(topk, rank_fn) + + def rank( + self, candidates_with_scores: List[Chunk], query: Optional[str] = None + ) -> List[Chunk]: + """Rerank Embeddings rank algorithm implementation. + + Args: + candidates_with_scores: List[Chunk], candidates with scores + query: Optional[str], query text + Returns: + List[Chunk], reranked candidates + """ + if not candidates_with_scores or not query: + return candidates_with_scores + + contents = [candidate.content for candidate in candidates_with_scores] + rank_scores = self._model.predict(query, contents) + new_candidates_with_scores = self._rerank_with_scores( + candidates_with_scores, rank_scores + ) + return new_candidates_with_scores[: self.topk] + + async def arank( + self, candidates_with_scores: List[Chunk], query: Optional[str] = None + ) -> List[Chunk]: + """Rerank Embeddings rank algorithm implementation. + + Args: + candidates_with_scores: List[Chunk], candidates with scores + query: Optional[str], query text + Returns: + List[Chunk], reranked candidates + """ + if not candidates_with_scores or not query: + return candidates_with_scores + + contents = [candidate.content for candidate in candidates_with_scores] + rank_scores = await self._model.apredict(query, contents) + new_candidates_with_scores = self._rerank_with_scores( + candidates_with_scores, rank_scores + ) + return new_candidates_with_scores[: self.topk] diff --git a/docs/docs/installation/advanced_usage/OpenAI_SDK_call.md b/docs/docs/installation/advanced_usage/OpenAI_SDK_call.md index 079d4d87c..a1a618eb1 100644 --- a/docs/docs/installation/advanced_usage/OpenAI_SDK_call.md +++ b/docs/docs/installation/advanced_usage/OpenAI_SDK_call.md @@ -55,7 +55,6 @@ curl http://127.0.0.1:8100/api/v1/embeddings \ ``` - ## Verify via OpenAI SDK ```bash @@ -72,3 +71,49 @@ completion = openai.ChatCompletion.create( print(completion.choices[0].message.content) ``` +## (Experimental) Rerank Open API + +The rerank API is an experimental feature that can be used to rerank the candidate list. + +1. Use cURL to verify the rerank API. 
+```bash
+curl http://127.0.0.1:8100/api/v1/beta/relevance \
+-H "Authorization: Bearer EMPTY" \
+-H "Content-Type: application/json" \
+-d '{
+  "model": "bge-reranker-base",
+  "query": "what is awel talk about?",
+  "documents": [
+    "Agentic Workflow Expression Language(AWEL) is a set of intelligent agent workflow expression language specially designed for large model application development.",
+    "Autonomous agents have long been a research focus in academic and industry communities",
+    "AWEL is divided into three levels in design, namely the operator layer, AgentFream layer and DSL layer.",
+    "Elon Musk is a famous entrepreneur and inventor, he is the founder of SpaceX and Tesla."
+  ]
+}'
+```
+
+2. Use Python to verify the rerank API.
+```python
+from dbgpt.rag.embedding import OpenAPIRerankEmbeddings
+
+rerank = OpenAPIRerankEmbeddings(api_key="EMPTY", model_name="bge-reranker-base")
+rerank.predict(
+    query="what is awel talk about?",
+    candidates=[
+        "Agentic Workflow Expression Language(AWEL) is a set of intelligent agent workflow expression language specially designed for large model application development.",
+        "Autonomous agents have long been a research focus in academic and industry communities",
+        "AWEL is divided into three levels in design, namely the operator layer, AgentFream layer and DSL layer.",
+        "Elon Musk is a famous entrepreneur and inventor, he is the founder of SpaceX and Tesla."
+    ]
+)
+```
+
+The output is as follows:
+```bash
+[
+    0.9685816764831543,
+    3.7338297261158004e-05,
+    0.03692878410220146,
+    3.73825132555794e-05
+]
+```
diff --git a/docs/docs/installation/model_service/cluster.md b/docs/docs/installation/model_service/cluster.md
index 13919495d..11ee6e0b4 100644
--- a/docs/docs/installation/model_service/cluster.md
+++ b/docs/docs/installation/model_service/cluster.md
@@ -3,21 +3,21 @@
 
 ## Install command line tools
 All the following operations are completed through the `dbgpt` command. To use the `dbgpt` command, you first need to install the `DB-GPT` project.
You can install it through the following command -```python +```shell $ pip install -e ".[default]" ``` It can also be used in script mode -```python +```shell $ python pilot/scripts/cli_scripts.py ``` ## Start Model Controller -```python +```shell $ dbgpt start controller ``` ## View log -```python +```shell $ docker logs db-gpt-webserver-1 -f ``` By default, `Model Server` will start on port `8000` @@ -28,7 +28,7 @@ By default, `Model Server` will start on port `8000` Start `chatglm2-6b` model Worker ::: -```python +```shell dbgpt start worker --model_name chatglm2-6b \ --model_path /app/models/chatglm2-6b \ --port 8001 \ @@ -40,7 +40,7 @@ dbgpt start worker --model_name chatglm2-6b \ Start `vicuna-13b-v1.5` model Worker ::: -```python +```shell dbgpt start worker --model_name vicuna-13b-v1.5 \ --model_path /app/models/vicuna-13b-v1.5 \ --port 8002 \ @@ -52,9 +52,9 @@ dbgpt start worker --model_name vicuna-13b-v1.5 \ ::: -## Start the embedding model service +## Start Embedding Model Worker -```python +```shell dbgpt start worker --model_name text2vec \ --model_path /app/models/text2vec-large-chinese \ --worker_type text2vec \ @@ -66,24 +66,41 @@ dbgpt start worker --model_name text2vec \ ::: +## Start Reranking Model Worker + +```shell +dbgpt start worker --worker_type text2vec \ +--rerank \ +--model_path /app/models/bge-reranker-base \ +--model_name bge-reranker-base \ +--port 8004 \ +--controller_addr http://127.0.0.1:8000 +``` +:::info note +⚠️ Make sure to use your own model name and model path. + +::: + :::tip View and inspect deployed models ::: -```python +```shell $ dbgpt model list -+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+ -| Model Name | Model Type | Host | Port | Healthy | Enabled | Prompt Template | Last Heartbeat | -+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+ -| chatglm2-6b | llm | 172.17.0.2 | 8001 | True | True | | 2023-09-12T23:04:31.287654 | -| WorkerManager | service | 172.17.0.2 | 8001 | True | True | | 2023-09-12T23:04:31.286668 | -| WorkerManager | service | 172.17.0.2 | 8003 | True | True | | 2023-09-12T23:04:29.845617 | -| WorkerManager | service | 172.17.0.2 | 8002 | True | True | | 2023-09-12T23:04:24.598439 | -| text2vec | text2vec | 172.17.0.2 | 8003 | True | True | | 2023-09-12T23:04:29.844796 | -| vicuna-13b-v1.5 | llm | 172.17.0.2 | 8002 | True | True | | 2023-09-12T23:04:24.597775 | -+-----------------+------------+------------+------+---------+---------+-----------------+----------------------------+ ++-------------------+------------+------------+------+---------+---------+-----------------+----------------------------+ +| Model Name | Model Type | Host | Port | Healthy | Enabled | Prompt Template | Last Heartbeat | ++-------------------+------------+------------+------+---------+---------+-----------------+----------------------------+ +| chatglm2-6b | llm | 172.17.0.2 | 8001 | True | True | | 2023-09-12T23:04:31.287654 | +| WorkerManager | service | 172.17.0.2 | 8001 | True | True | | 2023-09-12T23:04:31.286668 | +| WorkerManager | service | 172.17.0.2 | 8003 | True | True | | 2023-09-12T23:04:29.845617 | +| WorkerManager | service | 172.17.0.2 | 8002 | True | True | | 2023-09-12T23:04:24.598439 | +| WorkerManager | service | 172.21.0.5 | 8004 | True | True | | 2023-09-12T23:04:24.598439 | +| text2vec | text2vec | 172.17.0.2 | 8003 | True | True | | 2023-09-12T23:04:29.844796 | +| vicuna-13b-v1.5 | 
llm | 172.17.0.2 | 8002 | True | True | | 2023-09-12T23:04:24.597775 | +| bge-reranker-base | text2vec | 172.21.0.5 | 8004 | True | True | | 2024-05-15T11:36:12.935012 | ++-------------------+------------+------------+------+---------+---------+-----------------+----------------------------+ ``` @@ -91,13 +108,13 @@ $ dbgpt model list The model service deployed as above can be used through dbgpt_server. First modify the `.env` configuration file to change the connection model address -```python +```shell dbgpt start webserver --light ``` ## Start Webserver -```python +```shell LLM_MODEL=vicuna-13b-v1.5 # The current default MODEL_SERVER address is the address of the Model Controller MODEL_SERVER=http://127.0.0.1:8000 @@ -106,7 +123,7 @@ MODEL_SERVER=http://127.0.0.1:8000 Or it can be started directly by command to formulate the model. -```python +```shell LLM_MODEL=chatglm2-6b dbgpt start webserver --light ``` @@ -118,7 +135,7 @@ For more information about the use of the command line, you can view the command View dbgpt help `dbgpt --help` ::: -```python +```shell dbgpt --help Already connect 'dbgpt' @@ -143,7 +160,7 @@ Commands: Check the dbgpt start command `dbgpt start --help` ::: -```python +```shell dbgpt start --help Already connect 'dbgpt' @@ -166,7 +183,7 @@ Commands: View the dbgpt start model service help command `dbgpt start worker --help` ::: -```python +```shell dbgpt start worker --help Already connect 'dbgpt' @@ -239,7 +256,7 @@ Options: View dbgpt model service related commands `dbgpt model --help` ::: -```python +```shell dbgpt model --help diff --git a/requirements/lint-requirements.txt b/requirements/lint-requirements.txt index 733bf89f4..2b57c455d 100644 --- a/requirements/lint-requirements.txt +++ b/requirements/lint-requirements.txt @@ -13,4 +13,5 @@ types-requests types-beautifulsoup4 types-Markdown types-tqdm -pandas-stubs \ No newline at end of file +pandas-stubs +types-psutil \ No newline at end of file
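For reference, a sketch of how the new pieces compose outside the webserver: load a local cross-encoder through `CrossEncoderRerankEmbeddings` and rerank retrieved chunks with `RerankEmbeddingsRanker`. It assumes the `BAAI/bge-reranker-base` weights can be fetched from Hugging Face and that `Chunk` accepts a `content` keyword, mirroring how the ranker reads `candidate.content`:

```python
from dbgpt.core import Chunk
from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker

# Load a local cross-encoder; weights are downloaded on first use.
rerank_embeddings = CrossEncoderRerankEmbeddings(model_name="BAAI/bge-reranker-base")
ranker = RerankEmbeddingsRanker(rerank_embeddings, topk=2)

candidates = [
    Chunk(content="AWEL is a workflow expression language for LLM apps."),
    Chunk(content="Autonomous agents are a long-standing research focus."),
    Chunk(content="Elon Musk founded SpaceX and Tesla."),
]
# The cross-encoder scores overwrite each chunk's retrieval score and the
# top 2 chunks come back in descending score order.
reranked = ranker.rank(candidates, query="what is awel talk about?")
for chunk in reranked:
    print(f"{chunk.score:.4f} {chunk.content}")
```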