feat(model): Support deploy rerank model (#1522)

This commit is contained in:
Fangyin Cheng
2024-05-16 14:50:16 +08:00
committed by GitHub
parent 559affe87d
commit 593e974405
29 changed files with 814 additions and 75 deletions

View File

@@ -15,6 +15,7 @@ from .embeddings import ( # noqa: F401
OllamaEmbeddings,
OpenAPIEmbeddings,
)
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
__ALL__ = [
"Embeddings",
@@ -28,4 +29,6 @@ __ALL__ = [
"DefaultEmbeddingFactory",
"EmbeddingFactory",
"WrappedEmbeddingFactory",
"CrossEncoderRerankEmbeddings",
"OpenAPIRerankEmbeddings",
]

View File

@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
from typing import Any, List, Optional, Type
from dbgpt.component import BaseComponent, SystemApp
from dbgpt.core import Embeddings
from dbgpt.core import Embeddings, RerankEmbeddings
from dbgpt.core.awel import DAGVar
from dbgpt.core.awel.flow import ResourceCategory, register_resource
from dbgpt.util.i18n_utils import _
@@ -34,6 +34,26 @@ class EmbeddingFactory(BaseComponent, ABC):
"""
class RerankEmbeddingFactory(BaseComponent, ABC):
"""Class for RerankEmbeddingFactory."""
name = "rerank_embedding_factory"
@abstractmethod
def create(
self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
) -> RerankEmbeddings:
"""Create an embedding instance.
Args:
model_name (str): The model name.
embedding_cls (Type): The embedding class.
Returns:
RerankEmbeddings: The embedding instance.
"""
class DefaultEmbeddingFactory(EmbeddingFactory):
"""The default embedding factory."""

View File

@@ -1,6 +1,5 @@
"""Embedding implementations."""
from typing import Any, Dict, List, Optional
import aiohttp

View File

@@ -0,0 +1,131 @@
"""Re-rank embeddings."""
from typing import Any, Dict, List, Optional, cast
import aiohttp
import requests
from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
from dbgpt.core import RerankEmbeddings
class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
"""CrossEncoder Rerank Embeddings."""
model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())
client: Any #: :meta private:
model_name: str = "BAAI/bge-reranker-base"
max_length: Optional[int] = None
"""Max length for input sequences. Longer sequences will be truncated. If None, max
length of the model will be used"""
"""Model name to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Keyword arguments to pass to the model."""
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
try:
from sentence_transformers import CrossEncoder
except ImportError:
raise ImportError(
"please `pip install sentence-transformers`",
)
kwargs["client"] = CrossEncoder(
kwargs.get("model_name"),
max_length=kwargs.get("max_length"),
**kwargs.get("model_kwargs"),
)
super().__init__(**kwargs)
def predict(self, query: str, candidates: List[str]) -> List[float]:
"""Predict the rank scores of the candidates.
Args:
query: The query text.
candidates: The list of candidate texts.
Returns:
List[float]: The rank scores of the candidates.
"""
from sentence_transformers import CrossEncoder
query_content_pairs = [[query, candidate] for candidate in candidates]
_model = cast(CrossEncoder, self.client)
rank_scores = _model.predict(sentences=query_content_pairs)
return rank_scores.tolist()
class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
"""OpenAPI Rerank Embeddings."""
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
api_url: str = Field(
default="http://localhost:8100/v1/beta/relevance",
description="The URL of the embeddings API.",
)
api_key: Optional[str] = Field(
default=None, description="The API key for the embeddings API."
)
model_name: str = Field(
default="bge-reranker-base", description="The name of the model to use."
)
timeout: int = Field(
default=60, description="The timeout for the request in seconds."
)
session: Optional[requests.Session] = None
def __init__(self, **kwargs):
"""Initialize the OpenAPIEmbeddings."""
try:
import requests
except ImportError:
raise ValueError(
"The requests python package is not installed. "
"Please install it with `pip install requests`"
)
if "session" not in kwargs: # noqa: SIM401
session = requests.Session()
else:
session = kwargs["session"]
api_key = kwargs.get("api_key")
if api_key:
session.headers.update({"Authorization": f"Bearer {api_key}"})
kwargs["session"] = session
super().__init__(**kwargs)
def predict(self, query: str, candidates: List[str]) -> List[float]:
"""Predict the rank scores of the candidates.
Args:
query: The query text.
candidates: The list of candidate texts.
Returns:
List[float]: The rank scores of the candidates.
"""
if not candidates:
return []
data = {"model": self.model_name, "query": query, "documents": candidates}
response = self.session.post( # type: ignore
self.api_url, json=data, timeout=self.timeout
)
response.raise_for_status()
return response.json()["data"]
async def apredict(self, query: str, candidates: List[str]) -> List[float]:
"""Predict the rank scores of the candidates asynchronously."""
headers = {"Authorization": f"Bearer {self.api_key}"}
async with aiohttp.ClientSession(
headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
) as session:
data = {"model": self.model_name, "query": query, "documents": candidates}
async with session.post(self.api_url, json=data) as resp:
resp.raise_for_status()
response_data = await resp.json()
if "data" not in response_data:
raise RuntimeError(response_data["detail"])
return response_data["data"]