mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-29 23:01:38 +00:00
148 lines
5.4 KiB
Python
148 lines
5.4 KiB
Python
"""Re-rank embeddings."""
|
|
|
|
from typing import Any, Dict, List, Optional, cast
|
|
|
|
import aiohttp
|
|
import numpy as np
|
|
import requests
|
|
|
|
from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
|
|
from dbgpt.core import RerankEmbeddings
|
|
from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer
|
|
|
|
|
|
class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
    """CrossEncoder Rerank Embeddings.

    Scores (query, candidate) pairs with a ``sentence_transformers``
    ``CrossEncoder`` model that is created when the instance is constructed.
    """

    model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())

    client: Any  #: :meta private:
    model_name: str = "BAAI/bge-reranker-base"
    """Model name to use."""
    max_length: Optional[int] = None
    """Max length for input sequences. Longer sequences will be truncated. If None, max
    length of the model will be used"""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the model."""

    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformer.

        Args:
            **kwargs: Field values for this model. ``model_name``,
                ``max_length`` and ``model_kwargs`` are also forwarded to the
                ``CrossEncoder`` constructor. Any ``client`` entry is
                overwritten with the newly created ``CrossEncoder``.

        Raises:
            ImportError: If ``sentence_transformers`` is not installed.
        """
        try:
            from sentence_transformers import CrossEncoder
        except ImportError as e:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "please `pip install sentence-transformers`",
            ) from e

        kwargs["client"] = CrossEncoder(
            kwargs.get("model_name", "BAAI/bge-reranker-base"),
            max_length=kwargs.get("max_length"),  # type: ignore
            **(kwargs.get("model_kwargs") or {}),
        )
        super().__init__(**kwargs)

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates.

        Args:
            query: The query text.
            candidates: The list of candidate texts.

        Returns:
            List[float]: The rank scores of the candidates. An empty list is
                returned when ``candidates`` is empty.
        """
        if not candidates:
            # Nothing to score; avoid invoking the model with an empty batch.
            return []

        # Imported here only so the client can be narrowed for type checkers.
        from sentence_transformers import CrossEncoder

        query_content_pairs = [[query, candidate] for candidate in candidates]
        _model = cast(CrossEncoder, self.client)
        rank_scores = _model.predict(sentences=query_content_pairs)
        if isinstance(rank_scores, np.ndarray):
            # Convert numpy scores to plain Python floats.
            rank_scores = rank_scores.tolist()
        return rank_scores  # type: ignore
|
|
|
|
|
|
class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
    """OpenAPI Rerank Embeddings.

    Sends the query and candidate documents as JSON to ``api_url`` and reads
    the rank scores from the ``"data"`` key of the JSON response.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

    api_url: str = Field(
        default="http://localhost:8100/v1/beta/relevance",
        description="The URL of the embeddings API.",
    )
    api_key: Optional[str] = Field(
        default=None, description="The API key for the embeddings API."
    )
    model_name: str = Field(
        default="bge-reranker-base", description="The name of the model to use."
    )
    timeout: int = Field(
        default=60, description="The timeout for the request in seconds."
    )
    pass_trace_id: bool = Field(
        default=True, description="Whether to pass the trace ID to the API."
    )

    # Session used by the synchronous ``predict``; created in ``__init__``
    # when the caller does not supply one.
    session: Optional[requests.Session] = None

    def __init__(self, **kwargs):
        """Initialize the OpenAPIEmbeddings.

        Args:
            **kwargs: Field values for this model. If ``session`` is missing
                or ``None`` a new ``requests.Session`` is created. If
                ``api_key`` is set, a ``Bearer`` ``Authorization`` header is
                added to the session's headers (this mutates a
                caller-supplied session).

        Raises:
            ValueError: If the ``requests`` package is not installed.
        """
        try:
            import requests
        except ImportError as e:
            raise ValueError(
                "The requests python package is not installed. "
                "Please install it with `pip install requests`"
            ) from e

        # Treat an explicit ``session=None`` the same as an omitted session so
        # the header update below and ``predict`` always have a session.
        session = kwargs.get("session")
        if session is None:
            session = requests.Session()

        api_key = kwargs.get("api_key")
        if api_key:
            session.headers.update({"Authorization": f"Bearer {api_key}"})
        kwargs["session"] = session
        super().__init__(**kwargs)

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates.

        Args:
            query: The query text.
            candidates: The list of candidate texts.

        Returns:
            List[float]: The rank scores of the candidates. An empty list is
                returned without calling the API when ``candidates`` is empty.

        Raises:
            requests.HTTPError: If the API responds with an error status.
        """
        if not candidates:
            return []

        headers: Dict[str, str] = {}
        current_span_id = root_tracer.get_current_span_id()
        if self.pass_trace_id and current_span_id:
            # Set the trace ID if available
            headers[DBGPT_TRACER_SPAN_ID] = current_span_id

        data = {"model": self.model_name, "query": query, "documents": candidates}
        response = self.session.post(  # type: ignore
            self.api_url, json=data, timeout=self.timeout, headers=headers
        )
        response.raise_for_status()
        return response.json()["data"]

    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates asynchronously.

        Args:
            query: The query text.
            candidates: The list of candidate texts.

        Returns:
            List[float]: The rank scores of the candidates. An empty list is
                returned without calling the API when ``candidates`` is empty.

        Raises:
            aiohttp.ClientResponseError: If the API responds with an error
                status.
            RuntimeError: If the response JSON has no ``"data"`` key.
        """
        if not candidates:
            # Same short-circuit as the synchronous ``predict``.
            return []

        headers: Dict[str, str] = {}
        if self.api_key:
            # Only authenticate when a key is configured; otherwise the header
            # would carry the literal string "Bearer None".
            headers["Authorization"] = f"Bearer {self.api_key}"
        current_span_id = root_tracer.get_current_span_id()
        if self.pass_trace_id and current_span_id:
            # Set the trace ID if available
            headers[DBGPT_TRACER_SPAN_ID] = current_span_id

        data = {"model": self.model_name, "query": query, "documents": candidates}
        async with aiohttp.ClientSession(
            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
        ) as session:
            async with session.post(self.api_url, json=data) as resp:
                resp.raise_for_status()
                response_data = await resp.json()
                if "data" not in response_data:
                    # Prefer the server's "detail" message; fall back to the
                    # whole payload so a missing "detail" does not surface as
                    # an unrelated KeyError.
                    raise RuntimeError(response_data.get("detail", response_data))
                return response_data["data"]
|