feat(model): Support deploy rerank model (#1522)
Repository: https://github.com/csunny/DB-GPT.git
dbgpt/rag/embedding/__init__.py

@@ -15,6 +15,7 @@ from .embeddings import (  # noqa: F401
     OllamaEmbeddings,
     OpenAPIEmbeddings,
 )
+from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings  # noqa: F401

 __ALL__ = [
     "Embeddings",
@@ -28,4 +29,6 @@ __ALL__ = [
     "DefaultEmbeddingFactory",
     "EmbeddingFactory",
     "WrappedEmbeddingFactory",
+    "CrossEncoderRerankEmbeddings",
+    "OpenAPIRerankEmbeddings",
 ]
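With the expanded `__ALL__`, both rerank implementations are importable directly from the package root. A minimal sanity-check sketch (illustrative only; it assumes an environment with this change installed):

    from dbgpt.rag.embedding import (
        CrossEncoderRerankEmbeddings,
        OpenAPIRerankEmbeddings,
    )

    # Both classes implement the dbgpt.core.RerankEmbeddings interface.
    print(CrossEncoderRerankEmbeddings.__name__, OpenAPIRerankEmbeddings.__name__)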
dbgpt/rag/embedding/embedding_factory.py

@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
 from typing import Any, List, Optional, Type

 from dbgpt.component import BaseComponent, SystemApp
-from dbgpt.core import Embeddings
+from dbgpt.core import Embeddings, RerankEmbeddings
 from dbgpt.core.awel import DAGVar
 from dbgpt.core.awel.flow import ResourceCategory, register_resource
 from dbgpt.util.i18n_utils import _
@@ -34,6 +34,26 @@ class EmbeddingFactory(BaseComponent, ABC):
         """


+class RerankEmbeddingFactory(BaseComponent, ABC):
+    """Abstract factory for rerank embeddings."""
+
+    name = "rerank_embedding_factory"
+
+    @abstractmethod
+    def create(
+        self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
+    ) -> RerankEmbeddings:
+        """Create a rerank embedding instance.
+
+        Args:
+            model_name (str): The model name.
+            embedding_cls (Type): The embedding class.
+
+        Returns:
+            RerankEmbeddings: The embedding instance.
+        """
+
+
 class DefaultEmbeddingFactory(EmbeddingFactory):
     """The default embedding factory."""
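The new RerankEmbeddingFactory only declares the component name and the abstract create() contract; concrete factories are wired up elsewhere in the deployment. As a rough illustration of how a subclass could satisfy that contract, here is a minimal sketch. It is not part of this commit: the RemoteRerankEmbeddingFactory name, its constructor arguments, the init_app body, and the module path used in the imports are assumptions.

    from typing import Optional, Type

    from dbgpt.component import SystemApp
    from dbgpt.core import RerankEmbeddings
    from dbgpt.rag.embedding.embedding_factory import RerankEmbeddingFactory
    from dbgpt.rag.embedding.rerank import OpenAPIRerankEmbeddings


    class RemoteRerankEmbeddingFactory(RerankEmbeddingFactory):
        """Hypothetical factory that builds OpenAPI-backed rerank embeddings."""

        def __init__(self, api_url: str = "http://localhost:8100/v1/beta/relevance"):
            super().__init__()
            self._api_url = api_url

        def init_app(self, system_app: SystemApp) -> None:
            # BaseComponent subclasses must implement init_app; nothing extra to wire here.
            self.system_app = system_app

        def create(
            self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
        ) -> RerankEmbeddings:
            kwargs = {"model_name": model_name} if model_name else {}
            if embedding_cls is not None:
                # Caller-supplied implementation, e.g. CrossEncoderRerankEmbeddings.
                return embedding_cls(**kwargs)
            return OpenAPIRerankEmbeddings(api_url=self._api_url, **kwargs)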
dbgpt/rag/embedding/embeddings.py

@@ -1,6 +1,5 @@
 """Embedding implementations."""

-
 from typing import Any, Dict, List, Optional

 import aiohttp
dbgpt/rag/embedding/rerank.py (new file, 131 lines)

@@ -0,0 +1,131 @@
"""Re-rank embeddings."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
|
||||
from dbgpt.core import RerankEmbeddings
|
||||
|
||||
|
||||
class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
    """CrossEncoder Rerank Embeddings."""

    model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())

    client: Any  #: :meta private:
    model_name: str = "BAAI/bge-reranker-base"
    """Model name to use."""
    max_length: Optional[int] = None
    """Max length for input sequences. Longer sequences will be truncated. If None,
    the max length of the model will be used."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the model."""

    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformers CrossEncoder client."""
        try:
            from sentence_transformers import CrossEncoder
        except ImportError:
            raise ImportError(
                "Please `pip install sentence-transformers`",
            )

        # Build the CrossEncoder before pydantic validation so the `client` field
        # is populated; fall back to the declared defaults when the corresponding
        # kwargs are not supplied.
        kwargs["client"] = CrossEncoder(
            kwargs.get("model_name", "BAAI/bge-reranker-base"),
            max_length=kwargs.get("max_length"),
            **(kwargs.get("model_kwargs") or {}),
        )
        super().__init__(**kwargs)

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates.

        Args:
            query: The query text.
            candidates: The list of candidate texts.

        Returns:
            List[float]: The rank scores of the candidates.
        """
        from sentence_transformers import CrossEncoder

        query_content_pairs = [[query, candidate] for candidate in candidates]
        _model = cast(CrossEncoder, self.client)
        rank_scores = _model.predict(sentences=query_content_pairs)
        return rank_scores.tolist()

class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
    """OpenAPI Rerank Embeddings."""

    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

    api_url: str = Field(
        default="http://localhost:8100/v1/beta/relevance",
        description="The URL of the rerank API.",
    )
    api_key: Optional[str] = Field(
        default=None, description="The API key for the rerank API."
    )
    model_name: str = Field(
        default="bge-reranker-base", description="The name of the model to use."
    )
    timeout: int = Field(
        default=60, description="The timeout for the request in seconds."
    )

    session: Optional[requests.Session] = None

    def __init__(self, **kwargs):
        """Initialize the OpenAPIRerankEmbeddings."""
        try:
            import requests
        except ImportError:
            raise ValueError(
                "The requests python package is not installed. "
                "Please install it with `pip install requests`"
            )
        if "session" not in kwargs:  # noqa: SIM401
            session = requests.Session()
        else:
            session = kwargs["session"]
        api_key = kwargs.get("api_key")
        if api_key:
            # Reuse the bearer token on every request made through this session.
            session.headers.update({"Authorization": f"Bearer {api_key}"})
        kwargs["session"] = session
        super().__init__(**kwargs)

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates.

        Args:
            query: The query text.
            candidates: The list of candidate texts.

        Returns:
            List[float]: The rank scores of the candidates.
        """
        if not candidates:
            return []
        data = {"model": self.model_name, "query": query, "documents": candidates}
        response = self.session.post(  # type: ignore
            self.api_url, json=data, timeout=self.timeout
        )
        response.raise_for_status()
        return response.json()["data"]

    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates asynchronously."""
        # Only send an Authorization header when an API key is configured.
        headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
        async with aiohttp.ClientSession(
            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
        ) as session:
            data = {"model": self.model_name, "query": query, "documents": candidates}
            async with session.post(self.api_url, json=data) as resp:
                resp.raise_for_status()
                response_data = await resp.json()
                if "data" not in response_data:
                    raise RuntimeError(response_data["detail"])
                return response_data["data"]
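Taken together, the two classes give a local and a remote rerank path. A usage sketch for the cross-encoder variant (illustrative; it assumes `sentence-transformers` is installed and will download the `BAAI/bge-reranker-base` weights on first use):

    from dbgpt.rag.embedding.rerank import CrossEncoderRerankEmbeddings

    reranker = CrossEncoderRerankEmbeddings(model_name="BAAI/bge-reranker-base")
    scores = reranker.predict(
        query="What is DB-GPT?",
        candidates=[
            "DB-GPT is an AI native data app development framework.",
            "The weather is nice today.",
        ],
    )
    # One relevance score per candidate; higher means more relevant.
    print(scores)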
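And a sketch for the OpenAPI variant, which posts `{model, query, documents}` to the relevance endpoint and expects the scores back under `data` (illustrative; the URL and the absence of an API key assume a locally deployed rerank worker):

    import asyncio

    from dbgpt.rag.embedding.rerank import OpenAPIRerankEmbeddings

    reranker = OpenAPIRerankEmbeddings(
        api_url="http://localhost:8100/v1/beta/relevance",
        api_key=None,  # set this if the worker requires authentication
        model_name="bge-reranker-base",
    )

    docs = ["DB-GPT is a data app framework.", "Unrelated text."]

    # Synchronous call via requests.Session.
    print(reranker.predict("What is DB-GPT?", docs))

    # Asynchronous call via aiohttp.
    print(asyncio.run(reranker.apredict("What is DB-GPT?", docs)))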