Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-05 19:11:52 +00:00)
feat(model): Support deploy rerank model (#1522)
@@ -4,7 +4,7 @@ import logging
 from typing import List, Optional, Type, cast
 
 from dbgpt.configs.model_config import get_device
-from dbgpt.core import Embeddings
+from dbgpt.core import Embeddings, RerankEmbeddings
 from dbgpt.model.parameter import (
     BaseEmbeddingModelParameters,
     EmbeddingModelParameters,
@@ -66,6 +66,38 @@ class EmbeddingLoader:
             kwargs = param.build_kwargs(model_name=param.model_path)
             return HuggingFaceEmbeddings(**kwargs)
 
+    def load_rerank_model(
+        self, model_name: str, param: BaseEmbeddingModelParameters
+    ) -> RerankEmbeddings:
+        metadata = {
+            "model_name": model_name,
+            "run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
+            "params": _get_dict_from_obj(param),
+            "sys_infos": _get_dict_from_obj(get_system_info()),
+        }
+        with root_tracer.start_span(
+            "EmbeddingLoader.load_rerank_model",
+            span_type=SpanType.RUN,
+            metadata=metadata,
+        ):
+            if model_name in ["rerank_proxy_http_openapi"]:
+                from dbgpt.rag.embedding.rerank import OpenAPIRerankEmbeddings
+
+                proxy_param = cast(ProxyEmbeddingParameters, param)
+                openapi_param = {}
+                if proxy_param.proxy_server_url:
+                    openapi_param["api_url"] = proxy_param.proxy_server_url
+                if proxy_param.proxy_api_key:
+                    openapi_param["api_key"] = proxy_param.proxy_api_key
+                if proxy_param.proxy_backend:
+                    openapi_param["model_name"] = proxy_param.proxy_backend
+                return OpenAPIRerankEmbeddings(**openapi_param)
+            else:
+                from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings
+
+                kwargs = param.build_kwargs(model_name=param.model_path)
+                return CrossEncoderRerankEmbeddings(**kwargs)
+
 
 def _parse_embedding_params(
     model_name: Optional[str] = None,
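For context, the default branch of `load_rerank_model` wraps a local cross-encoder. Below is a minimal sketch of the underlying scoring flow, assuming `sentence-transformers` is installed; the model name and `max_length` are illustrative placeholders, not values taken from this commit:

```python
# Sketch: roughly the kind of scoring CrossEncoderRerankEmbeddings delegates to.
from sentence_transformers import CrossEncoder

model = CrossEncoder("BAAI/bge-reranker-base", max_length=512)  # placeholder model
query = "What is DB-GPT?"
candidates = [
    "DB-GPT is an AI native data app development framework.",
    "The weather is nice today.",
]
# CrossEncoder.predict scores each (query, candidate) pair; higher is more relevant.
scores = model.predict([(query, c) for c in candidates])
```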
@@ -147,15 +147,14 @@ def _dynamic_model_parser() -> Optional[List[Type[BaseModelParameters]]]:
     model_path = pre_args.get("model_path")
     worker_type = pre_args.get("worker_type")
     model_type = pre_args.get("model_type")
+    if model_name is None and model_type != ModelType.VLLM:
+        return None
     if worker_type == WorkerType.TEXT2VEC:
         return [
             EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
                 model_name, EmbeddingModelParameters
             )
         ]
-
-    if model_name is None and model_type != ModelType.VLLM:
-        return None
     llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type)
     param_class = llm_adapter.model_param_class()
     return [param_class]
@@ -3,6 +3,7 @@
 Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
 """
 
+import asyncio
 import json
 import logging
@@ -34,6 +35,8 @@ from dbgpt.core.schema.api import (
     ModelCard,
     ModelList,
     ModelPermission,
+    RelevanceRequest,
+    RelevanceResponse,
     UsageInfo,
 )
 from dbgpt.model.base import ModelInstance
@@ -368,6 +371,28 @@ class APIServer(BaseComponent):
         }
         return await worker_manager.embeddings(params)
 
+    async def relevance_generate(
+        self, model: str, query: str, texts: List[str]
+    ) -> List[float]:
+        """Generate relevance scores for a query against a list of texts.
+
+        Args:
+            model (str): Model name
+            query (str): Query text
+            texts (List[str]): Candidate texts to score
+
+        Returns:
+            List[float]: The relevance score of each text
+        """
+        worker_manager: WorkerManager = self.get_worker_manager()
+        params = {
+            "input": texts,
+            "model": model,
+            "query": query,
+        }
+        scores = await worker_manager.embeddings(params)
+        return scores[0]
+
 
 def get_api_server() -> APIServer:
     api_server = global_system_app.get_component(
@@ -456,6 +481,26 @@ async def create_embeddings(
     )
 
 
+@router.post(
+    "/v1/beta/relevance",
+    dependencies=[Depends(check_api_key)],
+    response_model=RelevanceResponse,
+)
+async def create_relevance(
+    request: RelevanceRequest, api_server: APIServer = Depends(get_api_server)
+):
+    """Generate relevance scores for a query and a list of documents."""
+    await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
+
+    scores = await api_server.relevance_generate(
+        request.model, request.query, request.documents
+    )
+    return model_to_dict(
+        RelevanceResponse(data=scores, model=request.model, usage=UsageInfo()),
+        exclude_none=True,
+    )
+
+
 def _initialize_all(controller_addr: str, system_app: SystemApp):
     from dbgpt.model.cluster.controller.controller import ModelRegistryClient
     from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
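A usage sketch for the new endpoint. The path `/v1/beta/relevance`, the payload fields (`model`, `query`, `documents`), and the response shape (`data`, `model`, `usage`) follow the code above; the host, port, URL prefix, API key, and model name are assumptions about a particular deployment:

```python
# Sketch: calling the relevance endpoint of a running DB-GPT model API server.
import requests

resp = requests.post(
    "http://127.0.0.1:8100/api/v1/beta/relevance",  # host/port/prefix assumed
    headers={"Authorization": "Bearer YOUR_API_KEY"},  # placeholder key
    json={
        "model": "bge-reranker-base",  # name of a deployed rerank model (assumed)
        "query": "What is DB-GPT?",
        "documents": [
            "DB-GPT is an AI native data app development framework.",
            "The weather is nice today.",
        ],
    },
)
resp.raise_for_status()
print(resp.json()["data"])  # one relevance score per document
```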
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 from dbgpt._private.pydantic import BaseModel
 from dbgpt.core.interface.message import ModelMessage
@@ -31,7 +31,9 @@ class PromptRequest(BaseModel):
 class EmbeddingsRequest(BaseModel):
     model: str
     input: List[str]
-    span_id: str = None
+    span_id: Optional[str] = None
+    query: Optional[str] = None
+    """For rerank models, query is required."""
 
 
 class CountTokenRequest(BaseModel):
@@ -1,6 +1,6 @@
 from typing import List
 
-from dbgpt.core import Embeddings
+from dbgpt.core import Embeddings, RerankEmbeddings
 from dbgpt.model.cluster.manager_base import WorkerManager
@@ -26,3 +26,30 @@ class RemoteEmbeddings(Embeddings):
     async def aembed_query(self, text: str) -> List[float]:
         """Asynchronously embed query text."""
         return (await self.aembed_documents([text]))[0]
+
+
+class RemoteRerankEmbeddings(RerankEmbeddings):
+    def __init__(self, model_name: str, worker_manager: WorkerManager) -> None:
+        self.model_name = model_name
+        self.worker_manager = worker_manager
+
+    def predict(self, query: str, candidates: List[str]) -> List[float]:
+        """Predict the scores of the candidates."""
+        params = {
+            "model": self.model_name,
+            "input": candidates,
+            "query": query,
+        }
+        return self.worker_manager.sync_embeddings(params)[0]
+
+    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
+        """Asynchronously predict the scores of the candidates."""
+        params = {
+            "model": self.model_name,
+            "input": candidates,
+            "query": query,
+        }
+        # Use the embeddings interface to fetch the reranker's scores
+        scores = await self.worker_manager.embeddings(params)
+        # The first element holds the scores for this query
+        return scores[0]
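A short sketch of using the new wrapper inside the cluster, assuming an already initialized `worker_manager` and a deployed rerank model; the model name is a placeholder:

```python
# Sketch: scoring candidates through RemoteRerankEmbeddings (inside an async function).
rerank = RemoteRerankEmbeddings("bge-reranker-base", worker_manager)
scores = await rerank.apredict(
    "What is DB-GPT?",
    ["DB-GPT is an AI native data app development framework.", "Unrelated text."],
)
# scores is one float per candidate, aligned with the input order.
```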
@@ -1,7 +1,7 @@
 import logging
-from typing import Dict, List, Type
+from typing import Dict, List, Type, Union
 
-from dbgpt.core import ModelMetadata
+from dbgpt.core import Embeddings, ModelMetadata, RerankEmbeddings
 from dbgpt.model.adapter.embeddings_loader import (
     EmbeddingLoader,
     _parse_embedding_params,
@@ -20,13 +20,12 @@ logger = logging.getLogger(__name__)
 
 
 class EmbeddingsModelWorker(ModelWorker):
-    def __init__(self) -> None:
-        from dbgpt.rag.embedding import Embeddings
-
-        self._embeddings_impl: Embeddings = None
+    def __init__(self, rerank_model: bool = False) -> None:
+        self._embeddings_impl: Union[Embeddings, RerankEmbeddings, None] = None
         self._model_params = None
         self.model_name = None
         self.model_path = None
+        self._rerank_model = rerank_model
         self._loader = EmbeddingLoader()
 
     def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
@@ -64,8 +63,17 @@ class EmbeddingsModelWorker(ModelWorker):
         """Start model worker"""
         if not model_params:
             model_params = self.parse_parameters(command_args)
+            if self._rerank_model:
+                model_params.rerank = True  # type: ignore
         self._model_params = model_params
-        self._embeddings_impl = self._loader.load(self.model_name, model_params)
+        if model_params.is_rerank_model():
+            logger.info(f"Load rerank embeddings model: {self.model_name}")
+            self._embeddings_impl = self._loader.load_rerank_model(
+                self.model_name, model_params
+            )
+        else:
+            logger.info(f"Load embeddings model: {self.model_name}")
+            self._embeddings_impl = self._loader.load(self.model_name, model_params)
 
     def __del__(self):
         self.stop()
@@ -96,5 +104,10 @@ class EmbeddingsModelWorker(ModelWorker):
     def embeddings(self, params: Dict) -> List[List[float]]:
         model = params.get("model")
         logger.info(f"Receive embeddings request, model: {model}")
-        input: List[str] = params["input"]
-        return self._embeddings_impl.embed_documents(input)
+        texts: List[str] = params["input"]
+        if isinstance(self._embeddings_impl, RerankEmbeddings):
+            query = params["query"]
+            scores: List[float] = self._embeddings_impl.predict(query, texts)
+            return [scores]
+        else:
+            return self._embeddings_impl.embed_documents(texts)
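Note the convention this branch relies on: rerank scores travel over the existing embeddings interface, so the worker returns a single row (`[scores]`) and callers take element `[0]`. A hedged illustration of the round trip, with placeholder names:

```python
# Sketch: the worker-side shape for a rerank request.
params = {
    "model": "bge-reranker-base",  # placeholder model name
    "input": ["doc a", "doc b"],   # candidates reuse the embeddings "input" field
    "query": "some query",         # extra field consumed only by rerank models
}
result = worker.embeddings(params)  # e.g. [[0.92, 0.15]]: one "row" of scores
scores = result[0]                  # [0.92, 0.15], one score per candidate
```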
@@ -952,7 +952,10 @@ def _create_local_model_manager(
     )
 
 
-def _build_worker(worker_params: ModelWorkerParameters):
+def _build_worker(
+    worker_params: ModelWorkerParameters,
+    ext_worker_kwargs: Optional[Dict[str, Any]] = None,
+):
     worker_class = worker_params.worker_class
     if worker_class:
         from dbgpt.util.module_utils import import_from_checked_string
@@ -976,11 +979,16 @@ def _build_worker(worker_params: ModelWorkerParameters):
     else:
         raise Exception(f"Unsupported worker type: {worker_params.worker_type}")
 
-    return worker_cls()
+    if ext_worker_kwargs:
+        return worker_cls(**ext_worker_kwargs)
+    else:
+        return worker_cls()
 
 
 def _start_local_worker(
-    worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters
+    worker_manager: WorkerManagerAdapter,
+    worker_params: ModelWorkerParameters,
+    ext_worker_kwargs: Optional[Dict[str, Any]] = None,
 ):
     with root_tracer.start_span(
         "WorkerManager._start_local_worker",
@@ -991,7 +999,7 @@ def _start_local_worker(
             "sys_infos": _get_dict_from_obj(get_system_info()),
         },
     ):
-        worker = _build_worker(worker_params)
+        worker = _build_worker(worker_params, ext_worker_kwargs=ext_worker_kwargs)
         if not worker_manager.worker_manager:
             worker_manager.worker_manager = _create_local_model_manager(worker_params)
         worker_manager.worker_manager.add_worker(worker, worker_params)
@@ -1001,6 +1009,7 @@ def _start_local_embedding_worker(
     worker_manager: WorkerManagerAdapter,
     embedding_model_name: str = None,
     embedding_model_path: str = None,
+    ext_worker_kwargs: Optional[Dict[str, Any]] = None,
 ):
     if not embedding_model_name or not embedding_model_path:
         return
@@ -1013,21 +1022,25 @@
     logger.info(
         f"Start local embedding worker with embedding parameters\n{embedding_worker_params}"
     )
-    _start_local_worker(worker_manager, embedding_worker_params)
+    _start_local_worker(
+        worker_manager, embedding_worker_params, ext_worker_kwargs=ext_worker_kwargs
+    )
 
 
 def initialize_worker_manager_in_client(
     app=None,
     include_router: bool = True,
-    model_name: str = None,
-    model_path: str = None,
+    model_name: Optional[str] = None,
+    model_path: Optional[str] = None,
     run_locally: bool = True,
-    controller_addr: str = None,
+    controller_addr: Optional[str] = None,
     local_port: int = 5670,
-    embedding_model_name: str = None,
-    embedding_model_path: str = None,
-    start_listener: Callable[["WorkerManager"], None] = None,
-    system_app: SystemApp = None,
+    embedding_model_name: Optional[str] = None,
+    embedding_model_path: Optional[str] = None,
+    rerank_model_name: Optional[str] = None,
+    rerank_model_path: Optional[str] = None,
+    start_listener: Optional[Callable[["WorkerManager"], None]] = None,
+    system_app: Optional[SystemApp] = None,
 ):
     """Initialize WorkerManager in client.
 
     If run_locally is True:
@@ -1063,6 +1076,12 @@
         _start_local_embedding_worker(
             worker_manager, embedding_model_name, embedding_model_path
         )
+        _start_local_embedding_worker(
+            worker_manager,
+            rerank_model_name,
+            rerank_model_path,
+            ext_worker_kwargs={"rerank_model": True},
+        )
     else:
         from dbgpt.model.cluster.controller.controller import (
             ModelRegistryClient,
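A sketch of a local client start-up using the new arguments; the import path, app object, and all model names and paths are assumptions for illustration:

```python
# Sketch: starting an LLM, an embedding model, and a rerank model in one process.
from dbgpt.model.cluster.worker.manager import (  # import path assumed
    initialize_worker_manager_in_client,
)

initialize_worker_manager_in_client(
    app=app,  # an existing FastAPI app (assumed)
    model_name="chatglm3-6b",  # placeholder LLM
    model_path="/models/chatglm3-6b",
    embedding_model_name="text2vec",
    embedding_model_path="/models/text2vec-large-chinese",
    rerank_model_name="bge-reranker-base",  # routed to an EmbeddingsModelWorker
    rerank_model_path="/models/bge-reranker-base",  # created with rerank_model=True
)
```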
@@ -1072,7 +1091,6 @@ def initialize_worker_manager_in_client(
 
         if not worker_params.controller_addr:
             raise ValueError("Controller can't be None")
-        controller_addr = worker_params.controller_addr
         logger.info(f"Worker params: {worker_params}")
         client = ModelRegistryClient(worker_params.controller_addr)
         worker_manager.worker_manager = RemoteWorkerManager(client)
@@ -255,6 +255,10 @@ class BaseEmbeddingModelParameters(BaseModelParameters):
     def build_kwargs(self, **kwargs) -> Dict:
         pass
 
+    def is_rerank_model(self) -> bool:
+        """Check if the model is a rerank model."""
+        return False
+
 
 @dataclass
 class EmbeddingModelParameters(BaseEmbeddingModelParameters):
@@ -272,6 +276,19 @@ class EmbeddingModelParameters(BaseEmbeddingModelParameters):
         },
     )
 
+    rerank: Optional[bool] = field(
+        default=False, metadata={"help": "Whether the model is a rerank model"}
+    )
+
+    max_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "Max length for input sequences. Longer sequences will be "
+            "truncated. If None, the model's max length is used; currently this "
+            "only applies to rerank models."
+        },
+    )
+
     def build_kwargs(self, **kwargs) -> Dict:
         model_kwargs, encode_kwargs = None, None
         if self.device:
@@ -280,10 +297,16 @@ class EmbeddingModelParameters(BaseEmbeddingModelParameters):
         encode_kwargs = {"normalize_embeddings": self.normalize_embeddings}
         if model_kwargs:
             kwargs["model_kwargs"] = model_kwargs
+        if self.is_rerank_model():
+            kwargs["max_length"] = self.max_length
         if encode_kwargs:
             kwargs["encode_kwargs"] = encode_kwargs
         return kwargs
 
+    def is_rerank_model(self) -> bool:
+        """Check if the model is a rerank model."""
+        return self.rerank
+
 
 @dataclass
 class ModelParameters(BaseModelParameters):
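To make the flag's effect concrete, a sketch of how `rerank` and `max_length` flow through `build_kwargs`; any constructor fields other than `rerank` and `max_length` are assumptions about the dataclass:

```python
# Sketch: a local rerank model configuration.
params = EmbeddingModelParameters(
    model_name="bge-reranker-base",  # assumed field name
    model_path="/models/bge-reranker-base",  # assumed field name
    rerank=True,
    max_length=512,
)
kwargs = params.build_kwargs(model_name=params.model_path)
# Because is_rerank_model() is True, kwargs now carries max_length=512,
# which the loader forwards to CrossEncoderRerankEmbeddings.
```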
@@ -537,26 +560,35 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
         metadata={"help": "To support Azure OpenAI Service custom deployment names"},
     )
 
+    rerank: Optional[bool] = field(
+        default=False, metadata={"help": "Whether the model is a rerank model"}
+    )
+
     def build_kwargs(self, **kwargs) -> Dict:
         params = {
             "openai_api_base": self.proxy_server_url,
             "openai_api_key": self.proxy_api_key,
             "openai_api_type": self.proxy_api_type if self.proxy_api_type else None,
-            "openai_api_version": self.proxy_api_version
-            if self.proxy_api_version
-            else None,
+            "openai_api_version": (
+                self.proxy_api_version if self.proxy_api_version else None
+            ),
             "model": self.proxy_backend,
-            "deployment": self.proxy_deployment
-            if self.proxy_deployment
-            else self.proxy_backend,
+            "deployment": (
+                self.proxy_deployment if self.proxy_deployment else self.proxy_backend
+            ),
         }
         for k, v in kwargs.items():
             params[k] = v
         return params
 
+    def is_rerank_model(self) -> bool:
+        """Check if the model is a rerank model."""
+        return self.rerank
+
 
 _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
-    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama",
+    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,"
+    "proxy_ollama,rerank_proxy_http_openapi",
 }
 
 EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
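Finally, a sketch of wiring the new `rerank_proxy_http_openapi` name end to end. The fields shown match the proxy branch of `load_rerank_model` above; the URL, API key, and backend name are placeholders, and any other required fields of the dataclass are omitted for brevity:

```python
# Sketch: deploying a rerank model that proxies a remote OpenAPI-style service.
param = ProxyEmbeddingParameters(
    proxy_server_url="http://127.0.0.1:8100/api/v1/beta/relevance",  # placeholder
    proxy_api_key="YOUR_API_KEY",  # placeholder
    proxy_backend="bge-reranker-base",  # model name on the remote service
    rerank=True,
)
loader = EmbeddingLoader()
# Takes the "rerank_proxy_http_openapi" branch of load_rerank_model above.
rerank = loader.load_rerank_model("rerank_proxy_http_openapi", param)
scores = rerank.predict("What is DB-GPT?", ["candidate one", "candidate two"])
```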