DB-GPT/dbgpt/rag/embedding/rerank.py
2024-07-11 19:39:02 +08:00

148 lines
5.4 KiB
Python

"""Re-rank embeddings."""
from typing import Any, Dict, List, Optional, cast
import aiohttp
import numpy as np
import requests
from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
from dbgpt.core import RerankEmbeddings
from dbgpt.util.tracer import DBGPT_TRACER_SPAN_ID, root_tracer
class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
    """CrossEncoder Rerank Embeddings.

    Wraps a ``sentence_transformers.CrossEncoder`` instance (stored in
    ``client``) and uses it to score ``(query, candidate)`` text pairs.
    """

    model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())

    client: Any  #: :meta private:
    model_name: str = "BAAI/bge-reranker-base"
    """Model name to use."""
    max_length: Optional[int] = None
    """Max length for input sequences. Longer sequences will be truncated. If None, max
    length of the model will be used"""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the model."""

    def __init__(self, **kwargs: Any):
        """Initialize the underlying CrossEncoder client.

        Args:
            **kwargs: Field values for this model. ``model_name``,
                ``max_length`` and ``model_kwargs`` are also forwarded to the
                ``CrossEncoder`` constructor; any ``client`` value supplied by
                the caller is overwritten by the newly built encoder.

        Raises:
            ImportError: If ``sentence-transformers`` is not installed.
        """
        try:
            from sentence_transformers import CrossEncoder
        except ImportError as exc:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "please `pip install sentence-transformers`",
            ) from exc

        kwargs["client"] = CrossEncoder(
            kwargs.get("model_name", "BAAI/bge-reranker-base"),
            max_length=kwargs.get("max_length"),  # type: ignore
            **(kwargs.get("model_kwargs") or {}),
        )
        super().__init__(**kwargs)

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates.

        Args:
            query: The query text.
            candidates: The list of candidate texts.

        Returns:
            List[float]: The rank scores of the candidates, one per candidate.
                An empty list is returned when ``candidates`` is empty.
        """
        if not candidates:
            # Nothing to score; avoid handing an empty batch to the model.
            return []

        # Imported only so the cast below can name the concrete client type.
        from sentence_transformers import CrossEncoder

        query_content_pairs = [[query, candidate] for candidate in candidates]
        _model = cast(CrossEncoder, self.client)
        rank_scores = _model.predict(sentences=query_content_pairs)
        if isinstance(rank_scores, np.ndarray):
            # Convert to plain Python floats to honour the declared return type.
            rank_scores = rank_scores.tolist()
        return rank_scores  # type: ignore
class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
    """OpenAPI Rerank Embeddings.

    Scores candidates by POSTing ``{"model", "query", "documents"}`` as JSON to
    ``api_url`` and reading the ``"data"`` field of the JSON response.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

    api_url: str = Field(
        default="http://localhost:8100/v1/beta/relevance",
        description="The URL of the embeddings API.",
    )
    api_key: Optional[str] = Field(
        default=None, description="The API key for the embeddings API."
    )
    model_name: str = Field(
        default="bge-reranker-base", description="The name of the model to use."
    )
    timeout: int = Field(
        default=60, description="The timeout for the request in seconds."
    )
    pass_trace_id: bool = Field(
        default=True, description="Whether to pass the trace ID to the API."
    )

    # Session used by the synchronous ``predict``; created in ``__init__`` when
    # the caller does not supply one.
    session: Optional[requests.Session] = None

    def __init__(self, **kwargs):
        """Initialize the OpenAPIRerankEmbeddings.

        Args:
            **kwargs: Field values for this model. If ``session`` is missing or
                ``None`` a new ``requests.Session`` is created. If ``api_key``
                is set, a ``Bearer`` authorization header is added to the
                session's default headers.

        Raises:
            ValueError: If the ``requests`` package is not installed.
        """
        try:
            import requests
        except ImportError as exc:
            raise ValueError(
                "The requests python package is not installed. "
                "Please install it with `pip install requests`"
            ) from exc

        session = kwargs.get("session")
        if session is None:
            # Covers both "not provided" and an explicit ``session=None``.
            session = requests.Session()

        api_key = kwargs.get("api_key")
        if api_key:
            session.headers.update({"Authorization": f"Bearer {api_key}"})
        kwargs["session"] = session
        super().__init__(**kwargs)

    def _trace_headers(self) -> Dict[str, str]:
        """Return the tracing header for the current span, if enabled and set."""
        headers: Dict[str, str] = {}
        current_span_id = root_tracer.get_current_span_id()
        if self.pass_trace_id and current_span_id:
            # Set the trace ID if available
            headers[DBGPT_TRACER_SPAN_ID] = current_span_id
        return headers

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates.

        Args:
            query: The query text.
            candidates: The list of candidate texts.

        Returns:
            List[float]: The rank scores of the candidates. An empty list is
                returned without calling the API when ``candidates`` is empty.

        Raises:
            requests.HTTPError: If the API responds with an error status.
        """
        if not candidates:
            return []
        headers = self._trace_headers()
        data = {"model": self.model_name, "query": query, "documents": candidates}
        response = self.session.post(  # type: ignore
            self.api_url, json=data, timeout=self.timeout, headers=headers
        )
        response.raise_for_status()
        return response.json()["data"]

    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates asynchronously.

        A fresh ``aiohttp.ClientSession`` is opened for each call, so the
        authorization header is built here rather than taken from ``session``.

        Args:
            query: The query text.
            candidates: The list of candidate texts.

        Returns:
            List[float]: The rank scores of the candidates. An empty list is
                returned without calling the API when ``candidates`` is empty.

        Raises:
            aiohttp.ClientResponseError: If the API responds with an error
                status.
            RuntimeError: If the response JSON has no ``"data"`` field.
        """
        if not candidates:
            # Mirror the synchronous path: nothing to score, skip the request.
            return []

        headers = self._trace_headers()
        if self.api_key:
            # Only authenticate when a key is configured; otherwise the header
            # would carry the literal string "Bearer None".
            headers["Authorization"] = f"Bearer {self.api_key}"

        async with aiohttp.ClientSession(
            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
        ) as session:
            data = {"model": self.model_name, "query": query, "documents": candidates}
            async with session.post(self.api_url, json=data) as resp:
                resp.raise_for_status()
                response_data = await resp.json()
                if "data" not in response_data:
                    # Prefer the server-provided detail; fall back to the whole
                    # payload so a missing "detail" key does not mask the error.
                    raise RuntimeError(response_data.get("detail", response_data))
                return response_data["data"]