feat(model): Support deploy rerank model (#1522)

Fangyin Cheng
2024-05-16 14:50:16 +08:00
committed by GitHub
parent 559affe87d
commit 593e974405
29 changed files with 814 additions and 75 deletions

View File

@@ -4,7 +4,7 @@ import logging
from typing import List, Optional, Type, cast
from dbgpt.configs.model_config import get_device
from dbgpt.core import Embeddings
from dbgpt.core import Embeddings, RerankEmbeddings
from dbgpt.model.parameter import (
BaseEmbeddingModelParameters,
EmbeddingModelParameters,
@@ -66,6 +66,38 @@ class EmbeddingLoader:
kwargs = param.build_kwargs(model_name=param.model_path)
return HuggingFaceEmbeddings(**kwargs)
def load_rerank_model(
self, model_name: str, param: BaseEmbeddingModelParameters
) -> RerankEmbeddings:
metadata = {
"model_name": model_name,
"run_service": SpanTypeRunName.EMBEDDING_MODEL.value,
"params": _get_dict_from_obj(param),
"sys_infos": _get_dict_from_obj(get_system_info()),
}
with root_tracer.start_span(
"EmbeddingLoader.load_rerank_model",
span_type=SpanType.RUN,
metadata=metadata,
):
if model_name in ["rerank_proxy_http_openapi"]:
from dbgpt.rag.embedding.rerank import OpenAPIRerankEmbeddings
proxy_param = cast(ProxyEmbeddingParameters, param)
openapi_param = {}
if proxy_param.proxy_server_url:
openapi_param["api_url"] = proxy_param.proxy_server_url
if proxy_param.proxy_api_key:
openapi_param["api_key"] = proxy_param.proxy_api_key
if proxy_param.proxy_backend:
openapi_param["model_name"] = proxy_param.proxy_backend
return OpenAPIRerankEmbeddings(**openapi_param)
else:
from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings
kwargs = param.build_kwargs(model_name=param.model_path)
return CrossEncoderRerankEmbeddings(**kwargs)
def _parse_embedding_params(
model_name: Optional[str] = None,
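For the local path, `CrossEncoderRerankEmbeddings` receives the same `build_kwargs` output as the regular embedding loader. Below is a minimal sketch of what a cross-encoder reranker wrapper can look like with `sentence_transformers`; the class body is an assumption for illustration, not the actual dbgpt implementation:

```python
from typing import List, Optional

from sentence_transformers import CrossEncoder


class CrossEncoderRerankSketch:
    """Hypothetical stand-in for dbgpt's CrossEncoderRerankEmbeddings."""

    def __init__(self, model_name: str, max_length: Optional[int] = None) -> None:
        # A cross-encoder scores each (query, passage) pair jointly instead of
        # embedding query and passage separately and comparing vectors.
        self._model = CrossEncoder(model_name, max_length=max_length)

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        # One (query, candidate) pair per candidate; higher score = more relevant.
        pairs = [(query, candidate) for candidate in candidates]
        return self._model.predict(pairs).tolist()
```

For example, `CrossEncoderRerankSketch("BAAI/bge-reranker-base").predict(query, docs)` returns one relevance score per document.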

View File

@@ -147,15 +147,14 @@ def _dynamic_model_parser() -> Optional[List[Type[BaseModelParameters]]]:
model_path = pre_args.get("model_path")
worker_type = pre_args.get("worker_type")
model_type = pre_args.get("model_type")
if model_name is None and model_type != ModelType.VLLM:
return None
if worker_type == WorkerType.TEXT2VEC:
return [
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
model_name, EmbeddingModelParameters
)
]
if model_name is None and model_type != ModelType.VLLM:
return None
llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type)
param_class = llm_adapter.model_param_class()
return [param_class]
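Moving the `WorkerType.TEXT2VEC` branch above the `model_name` guard lets embedding and rerank workers resolve their parameter class without going through the LLM adapter. A hypothetical illustration of the lookup, using the name registered at the bottom of this commit:

```python
# "rerank_proxy_http_openapi" is mapped to ProxyEmbeddingParameters in
# _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG (see the last hunk), so the
# proxy flags (server URL, API key, backend) become valid for rerank workers.
param_cls = EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
    "rerank_proxy_http_openapi", EmbeddingModelParameters
)
assert param_cls is ProxyEmbeddingParameters
```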

View File

@@ -3,6 +3,7 @@
Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/openai_api_server.py
"""
import asyncio
import json
import logging
@@ -34,6 +35,8 @@ from dbgpt.core.schema.api import (
ModelCard,
ModelList,
ModelPermission,
RelevanceRequest,
RelevanceResponse,
UsageInfo,
)
from dbgpt.model.base import ModelInstance
@@ -368,6 +371,28 @@ class APIServer(BaseComponent):
}
return await worker_manager.embeddings(params)
async def relevance_generate(
self, model: str, query: str, texts: List[str]
) -> List[float]:
"""Generate embeddings
Args:
model (str): Model name
query (str): Query text
texts (List[str]): Texts to embed
Returns:
List[List[float]]: The embeddings of texts
"""
worker_manager: WorkerManager = self.get_worker_manager()
params = {
"input": texts,
"model": model,
"query": query,
}
scores = await worker_manager.embeddings(params)
return scores[0]
def get_api_server() -> APIServer:
api_server = global_system_app.get_component(
@@ -456,6 +481,26 @@ async def create_embeddings(
)
@router.post(
"/v1/beta/relevance",
dependencies=[Depends(check_api_key)],
response_model=RelevanceResponse,
)
async def create_relevance(
request: RelevanceRequest, api_server: APIServer = Depends(get_api_server)
):
"""Generate relevance scores for a query and a list of documents."""
await api_server.get_model_instances_or_raise(request.model, worker_type="text2vec")
scores = await api_server.relevance_generate(
request.model, request.query, request.documents
)
return model_to_dict(
RelevanceResponse(data=scores, model=request.model, usage=UsageInfo()),
exclude_none=True,
)
def _initialize_all(controller_addr: str, system_app: SystemApp):
from dbgpt.model.cluster.controller.controller import ModelRegistryClient
from dbgpt.model.cluster.worker.manager import _DefaultWorkerManagerFactory
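The new endpoint keeps the OpenAI-style surface of the API server, so any HTTP client can call it. A sketch of a request (host, port, prefix, API key, and model name are assumptions; the `model`, `query`, and `documents` fields come from `RelevanceRequest`):

```python
import requests

# Assumed local deployment; adjust host, port, and prefix to your apiserver.
resp = requests.post(
    "http://127.0.0.1:8100/api/v1/beta/relevance",
    headers={"Authorization": "Bearer EMPTY"},
    json={
        "model": "bge-reranker-base",
        "query": "what is deep learning?",
        "documents": [
            "Deep learning is a subset of machine learning.",
            "Paris is the capital of France.",
        ],
    },
)
resp.raise_for_status()
# RelevanceResponse: "data" carries one relevance score per document.
print(resp.json()["data"])
```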

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from dbgpt._private.pydantic import BaseModel
from dbgpt.core.interface.message import ModelMessage
@@ -31,7 +31,9 @@ class PromptRequest(BaseModel):
class EmbeddingsRequest(BaseModel):
model: str
input: List[str]
span_id: str = None
span_id: Optional[str] = None
query: Optional[str] = None
"""For rerank model, query is required"""
class CountTokenRequest(BaseModel):
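With the optional `query` field, the same request model now serves both call styles. A small sketch (model names are placeholders):

```python
# Plain embedding request: no query needed.
emb = EmbeddingsRequest(model="text2vec-large-chinese", input=["hello world"])

# Rerank request: input holds the candidates, query is required.
rerank = EmbeddingsRequest(
    model="bge-reranker-base",
    input=["doc one", "doc two"],
    query="which doc matches?",
)
```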

View File

@@ -1,6 +1,6 @@
from typing import List
from dbgpt.core import Embeddings
from dbgpt.core import Embeddings, RerankEmbeddings
from dbgpt.model.cluster.manager_base import WorkerManager
@@ -26,3 +26,30 @@ class RemoteEmbeddings(Embeddings):
async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text."""
return (await self.aembed_documents([text]))[0]
class RemoteRerankEmbeddings(RerankEmbeddings):
def __init__(self, model_name: str, worker_manager: WorkerManager) -> None:
self.model_name = model_name
self.worker_manager = worker_manager
def predict(self, query: str, candidates: List[str]) -> List[float]:
"""Predict the scores of the candidates."""
params = {
"model": self.model_name,
"input": candidates,
"query": query,
}
return self.worker_manager.sync_embeddings(params)[0]
async def apredict(self, query: str, candidates: List[str]) -> List[float]:
"""Asynchronously predict the scores of the candidates."""
params = {
"model": self.model_name,
"input": candidates,
"query": query,
}
# Reuse the embeddings interface to fetch the reranker scores
scores = await self.worker_manager.embeddings(params)
# The result has one row of scores per query; return the row for this query
return scores[0]
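A usage sketch for the new remote reranker (the model name is a placeholder and `worker_manager` is assumed to come from a running cluster):

```python
async def rank_candidates(worker_manager) -> None:
    reranker = RemoteRerankEmbeddings("bge-reranker-base", worker_manager)
    scores = await reranker.apredict(
        query="what is vector search?",
        candidates=["An index over embeddings.", "A recipe for pancakes."],
    )
    # One float per candidate, in the same order as the input.
    print(scores)
```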

View File

@@ -1,7 +1,7 @@
import logging
from typing import Dict, List, Type
from typing import Dict, List, Type, Union
from dbgpt.core import ModelMetadata
from dbgpt.core import Embeddings, ModelMetadata, RerankEmbeddings
from dbgpt.model.adapter.embeddings_loader import (
EmbeddingLoader,
_parse_embedding_params,
@@ -20,13 +20,12 @@ logger = logging.getLogger(__name__)
class EmbeddingsModelWorker(ModelWorker):
def __init__(self) -> None:
from dbgpt.rag.embedding import Embeddings
self._embeddings_impl: Embeddings = None
def __init__(self, rerank_model: bool = False) -> None:
self._embeddings_impl: Union[Embeddings, RerankEmbeddings, None] = None
self._model_params = None
self.model_name = None
self.model_path = None
self._rerank_model = rerank_model
self._loader = EmbeddingLoader()
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
@@ -64,8 +63,17 @@ class EmbeddingsModelWorker(ModelWorker):
"""Start model worker"""
if not model_params:
model_params = self.parse_parameters(command_args)
if self._rerank_model:
model_params.rerank = True # type: ignore
self._model_params = model_params
self._embeddings_impl = self._loader.load(self.model_name, model_params)
if model_params.is_rerank_model():
logger.info(f"Load rerank embeddings model: {self.model_name}")
self._embeddings_impl = self._loader.load_rerank_model(
self.model_name, model_params
)
else:
logger.info(f"Load embeddings model: {self.model_name}")
self._embeddings_impl = self._loader.load(self.model_name, model_params)
def __del__(self):
self.stop()
@@ -96,5 +104,10 @@ class EmbeddingsModelWorker(ModelWorker):
def embeddings(self, params: Dict) -> List[List[float]]:
model = params.get("model")
logger.info(f"Receive embeddings request, model: {model}")
input: List[str] = params["input"]
return self._embeddings_impl.embed_documents(input)
texts: List[str] = params["input"]
if isinstance(self._embeddings_impl, RerankEmbeddings):
query = params["query"]
scores: List[float] = self._embeddings_impl.predict(query, texts)
# Wrap the scores in one row to keep the List[List[float]] contract
return [scores]
else:
return self._embeddings_impl.embed_documents(texts)
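The wrapping above preserves the `List[List[float]]` contract of the embeddings interface, which is why both `RemoteRerankEmbeddings` and `relevance_generate` unwrap `result[0]`. A sketch of the shape convention (worker instance and model name are hypothetical):

```python
result = worker.embeddings(
    {"model": "bge-reranker-base", "input": ["a", "b"], "query": "q"}
)
assert len(result) == 1  # one row for the single query
scores = result[0]       # one float per input text
```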

View File

@@ -952,7 +952,10 @@ def _create_local_model_manager(
)
def _build_worker(worker_params: ModelWorkerParameters):
def _build_worker(
worker_params: ModelWorkerParameters,
ext_worker_kwargs: Optional[Dict[str, Any]] = None,
):
worker_class = worker_params.worker_class
if worker_class:
from dbgpt.util.module_utils import import_from_checked_string
@@ -976,11 +979,16 @@ def _build_worker(worker_params: ModelWorkerParameters):
else:
raise Exception(f"Unsupported worker type: {worker_params.worker_type}")
return worker_cls()
if ext_worker_kwargs:
return worker_cls(**ext_worker_kwargs)
else:
return worker_cls()
def _start_local_worker(
worker_manager: WorkerManagerAdapter, worker_params: ModelWorkerParameters
worker_manager: WorkerManagerAdapter,
worker_params: ModelWorkerParameters,
ext_worker_kwargs: Optional[Dict[str, Any]] = None,
):
with root_tracer.start_span(
"WorkerManager._start_local_worker",
@@ -991,7 +999,7 @@ def _start_local_worker(
"sys_infos": _get_dict_from_obj(get_system_info()),
},
):
worker = _build_worker(worker_params)
worker = _build_worker(worker_params, ext_worker_kwargs=ext_worker_kwargs)
if not worker_manager.worker_manager:
worker_manager.worker_manager = _create_local_model_manager(worker_params)
worker_manager.worker_manager.add_worker(worker, worker_params)
@@ -1001,6 +1009,7 @@ def _start_local_embedding_worker(
worker_manager: WorkerManagerAdapter,
embedding_model_name: str = None,
embedding_model_path: str = None,
ext_worker_kwargs: Optional[Dict[str, Any]] = None,
):
if not embedding_model_name or not embedding_model_path:
return
@@ -1013,21 +1022,25 @@ def _start_local_embedding_worker(
logger.info(
f"Start local embedding worker with embedding parameters\n{embedding_worker_params}"
)
_start_local_worker(worker_manager, embedding_worker_params)
_start_local_worker(
worker_manager, embedding_worker_params, ext_worker_kwargs=ext_worker_kwargs
)
def initialize_worker_manager_in_client(
app=None,
include_router: bool = True,
model_name: str = None,
model_path: str = None,
model_name: Optional[str] = None,
model_path: Optional[str] = None,
run_locally: bool = True,
controller_addr: str = None,
controller_addr: Optional[str] = None,
local_port: int = 5670,
embedding_model_name: str = None,
embedding_model_path: str = None,
start_listener: Callable[["WorkerManager"], None] = None,
system_app: SystemApp = None,
embedding_model_name: Optional[str] = None,
embedding_model_path: Optional[str] = None,
rerank_model_name: Optional[str] = None,
rerank_model_path: Optional[str] = None,
start_listener: Optional[Callable[["WorkerManager"], None]] = None,
system_app: Optional[SystemApp] = None,
):
"""Initialize WorkerManager in client.
If run_locally is True:
@@ -1063,6 +1076,12 @@ def initialize_worker_manager_in_client(
_start_local_embedding_worker(
worker_manager, embedding_model_name, embedding_model_path
)
_start_local_embedding_worker(
worker_manager,
rerank_model_name,
rerank_model_path,
ext_worker_kwargs={"rerank_model": True},
)
else:
from dbgpt.model.cluster.controller.controller import (
ModelRegistryClient,
@@ -1072,7 +1091,6 @@ def initialize_worker_manager_in_client(
if not worker_params.controller_addr:
raise ValueError("Controller address can't be None")
controller_addr = worker_params.controller_addr
logger.info(f"Worker params: {worker_params}")
client = ModelRegistryClient(worker_params.controller_addr)
worker_manager.worker_manager = RemoteWorkerManager(client)
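The new `rerank_model_name` / `rerank_model_path` arguments reuse the embedding worker machinery, with `ext_worker_kwargs={"rerank_model": True}` flipping the worker into rerank mode. A sketch of client-side startup (all names and paths are placeholders, and `app` / `system_app` are assumed to exist):

```python
initialize_worker_manager_in_client(
    app=app,  # your FastAPI app
    model_name="vicuna-13b-v1.5",
    model_path="/models/vicuna-13b-v1.5",
    embedding_model_name="text2vec",
    embedding_model_path="/models/text2vec-large-chinese",
    rerank_model_name="bge-reranker-base",
    rerank_model_path="/models/bge-reranker-base",
    system_app=system_app,
)
```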

View File

@@ -255,6 +255,10 @@ class BaseEmbeddingModelParameters(BaseModelParameters):
def build_kwargs(self, **kwargs) -> Dict:
pass
def is_rerank_model(self) -> bool:
"""Check if the model is a rerank model"""
return False
@dataclass
class EmbeddingModelParameters(BaseEmbeddingModelParameters):
@@ -272,6 +276,19 @@ class EmbeddingModelParameters(BaseEmbeddingModelParameters):
},
)
rerank: Optional[bool] = field(
default=False, metadata={"help": "Whether the model is a rerank model"}
)
max_length: Optional[int] = field(
default=None,
metadata={
"help": "Max length for input sequences. Longer sequences will be "
"truncated. If None, max length of the model will be used, just for rerank"
" model now."
},
)
def build_kwargs(self, **kwargs) -> Dict:
model_kwargs, encode_kwargs = None, None
if self.device:
@@ -280,10 +297,16 @@ class EmbeddingModelParameters(BaseEmbeddingModelParameters):
encode_kwargs = {"normalize_embeddings": self.normalize_embeddings}
if model_kwargs:
kwargs["model_kwargs"] = model_kwargs
if self.is_rerank_model():
kwargs["max_length"] = self.max_length
if encode_kwargs:
kwargs["encode_kwargs"] = encode_kwargs
return kwargs
def is_rerank_model(self) -> bool:
"""Check if the model is a rerank model"""
return self.rerank
@dataclass
class ModelParameters(BaseModelParameters):
@@ -537,26 +560,35 @@ class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
metadata={"help": "Tto support Azure OpenAI Service custom deployment names"},
)
rerank: Optional[bool] = field(
default=False, metadata={"help": "Whether the model is a rerank model"}
)
def build_kwargs(self, **kwargs) -> Dict:
params = {
"openai_api_base": self.proxy_server_url,
"openai_api_key": self.proxy_api_key,
"openai_api_type": self.proxy_api_type if self.proxy_api_type else None,
"openai_api_version": self.proxy_api_version
if self.proxy_api_version
else None,
"openai_api_version": (
self.proxy_api_version if self.proxy_api_version else None
),
"model": self.proxy_backend,
"deployment": self.proxy_deployment
if self.proxy_deployment
else self.proxy_backend,
"deployment": (
self.proxy_deployment if self.proxy_deployment else self.proxy_backend
),
}
for k, v in kwargs.items():
params[k] = v
return params
def is_rerank_model(self) -> bool:
"""Check if the model is a rerank model"""
return self.rerank
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama",
ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,"
"proxy_ollama,rerank_proxy_http_openapi",
}
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}
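Putting the parameter changes together: for a local rerank model, `build_kwargs` now forwards `max_length`, and `is_rerank_model()` steers the loader to `load_rerank_model`. A sketch of the resulting kwargs (field values are placeholders; `model_name` / `model_path` are assumed to come from the base parameter class):

```python
params = EmbeddingModelParameters(
    model_name="bge-reranker-base",
    model_path="/models/bge-reranker-base",
    device="cuda",
    rerank=True,
    max_length=512,
)
kwargs = params.build_kwargs(model_name=params.model_path)
# Roughly: {"model_kwargs": {"device": "cuda"}, "max_length": 512,
#           "encode_kwargs": {...}, "model_name": "/models/bge-reranker-base"}
```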