mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-13 05:01:25 +00:00
feat(model): Support deploy rerank model (#1522)
This commit is contained in:
@@ -15,6 +15,7 @@ from .embeddings import ( # noqa: F401
|
||||
OllamaEmbeddings,
|
||||
OpenAPIEmbeddings,
|
||||
)
|
||||
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings # noqa: F401
|
||||
|
||||
__ALL__ = [
|
||||
"Embeddings",
|
||||
@@ -28,4 +29,6 @@ __ALL__ = [
|
||||
"DefaultEmbeddingFactory",
|
||||
"EmbeddingFactory",
|
||||
"WrappedEmbeddingFactory",
|
||||
"CrossEncoderRerankEmbeddings",
|
||||
"OpenAPIRerankEmbeddings",
|
||||
]
|
||||
|
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional, Type
|
||||
|
||||
from dbgpt.component import BaseComponent, SystemApp
|
||||
from dbgpt.core import Embeddings
|
||||
from dbgpt.core import Embeddings, RerankEmbeddings
|
||||
from dbgpt.core.awel import DAGVar
|
||||
from dbgpt.core.awel.flow import ResourceCategory, register_resource
|
||||
from dbgpt.util.i18n_utils import _
|
||||
@@ -34,6 +34,26 @@ class EmbeddingFactory(BaseComponent, ABC):
|
||||
"""
|
||||
|
||||
|
||||
class RerankEmbeddingFactory(BaseComponent, ABC):
|
||||
"""Class for RerankEmbeddingFactory."""
|
||||
|
||||
name = "rerank_embedding_factory"
|
||||
|
||||
@abstractmethod
|
||||
def create(
|
||||
self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
|
||||
) -> RerankEmbeddings:
|
||||
"""Create an embedding instance.
|
||||
|
||||
Args:
|
||||
model_name (str): The model name.
|
||||
embedding_cls (Type): The embedding class.
|
||||
|
||||
Returns:
|
||||
RerankEmbeddings: The embedding instance.
|
||||
"""
|
||||
|
||||
|
||||
class DefaultEmbeddingFactory(EmbeddingFactory):
|
||||
"""The default embedding factory."""
|
||||
|
||||
|
@@ -1,6 +1,5 @@
|
||||
"""Embedding implementations."""
|
||||
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
|
131
dbgpt/rag/embedding/rerank.py
Normal file
131
dbgpt/rag/embedding/rerank.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Re-rank embeddings."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
|
||||
from dbgpt.core import RerankEmbeddings
|
||||
|
||||
|
||||
class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
|
||||
"""CrossEncoder Rerank Embeddings."""
|
||||
|
||||
model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())
|
||||
|
||||
client: Any #: :meta private:
|
||||
model_name: str = "BAAI/bge-reranker-base"
|
||||
max_length: Optional[int] = None
|
||||
"""Max length for input sequences. Longer sequences will be truncated. If None, max
|
||||
length of the model will be used"""
|
||||
"""Model name to use."""
|
||||
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass to the model."""
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
"""Initialize the sentence_transformer."""
|
||||
try:
|
||||
from sentence_transformers import CrossEncoder
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"please `pip install sentence-transformers`",
|
||||
)
|
||||
|
||||
kwargs["client"] = CrossEncoder(
|
||||
kwargs.get("model_name"),
|
||||
max_length=kwargs.get("max_length"),
|
||||
**kwargs.get("model_kwargs"),
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def predict(self, query: str, candidates: List[str]) -> List[float]:
|
||||
"""Predict the rank scores of the candidates.
|
||||
|
||||
Args:
|
||||
query: The query text.
|
||||
candidates: The list of candidate texts.
|
||||
|
||||
Returns:
|
||||
List[float]: The rank scores of the candidates.
|
||||
"""
|
||||
from sentence_transformers import CrossEncoder
|
||||
|
||||
query_content_pairs = [[query, candidate] for candidate in candidates]
|
||||
_model = cast(CrossEncoder, self.client)
|
||||
rank_scores = _model.predict(sentences=query_content_pairs)
|
||||
return rank_scores.tolist()
|
||||
|
||||
|
||||
class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
|
||||
"""OpenAPI Rerank Embeddings."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())
|
||||
|
||||
api_url: str = Field(
|
||||
default="http://localhost:8100/v1/beta/relevance",
|
||||
description="The URL of the embeddings API.",
|
||||
)
|
||||
api_key: Optional[str] = Field(
|
||||
default=None, description="The API key for the embeddings API."
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="bge-reranker-base", description="The name of the model to use."
|
||||
)
|
||||
timeout: int = Field(
|
||||
default=60, description="The timeout for the request in seconds."
|
||||
)
|
||||
|
||||
session: Optional[requests.Session] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize the OpenAPIEmbeddings."""
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"The requests python package is not installed. "
|
||||
"Please install it with `pip install requests`"
|
||||
)
|
||||
if "session" not in kwargs: # noqa: SIM401
|
||||
session = requests.Session()
|
||||
else:
|
||||
session = kwargs["session"]
|
||||
api_key = kwargs.get("api_key")
|
||||
if api_key:
|
||||
session.headers.update({"Authorization": f"Bearer {api_key}"})
|
||||
kwargs["session"] = session
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def predict(self, query: str, candidates: List[str]) -> List[float]:
|
||||
"""Predict the rank scores of the candidates.
|
||||
|
||||
Args:
|
||||
query: The query text.
|
||||
candidates: The list of candidate texts.
|
||||
|
||||
Returns:
|
||||
List[float]: The rank scores of the candidates.
|
||||
"""
|
||||
if not candidates:
|
||||
return []
|
||||
data = {"model": self.model_name, "query": query, "documents": candidates}
|
||||
response = self.session.post( # type: ignore
|
||||
self.api_url, json=data, timeout=self.timeout
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()["data"]
|
||||
|
||||
async def apredict(self, query: str, candidates: List[str]) -> List[float]:
|
||||
"""Predict the rank scores of the candidates asynchronously."""
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
async with aiohttp.ClientSession(
|
||||
headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
|
||||
) as session:
|
||||
data = {"model": self.model_name, "query": query, "documents": candidates}
|
||||
async with session.post(self.api_url, json=data) as resp:
|
||||
resp.raise_for_status()
|
||||
response_data = await resp.json()
|
||||
if "data" not in response_data:
|
||||
raise RuntimeError(response_data["detail"])
|
||||
return response_data["data"]
|
@@ -1,4 +1,5 @@
|
||||
"""Embedding retriever."""
|
||||
|
||||
from functools import reduce
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
@@ -207,7 +208,7 @@ class EmbeddingRetriever(BaseRetriever):
|
||||
"rerank_cls": self._rerank.__class__.__name__,
|
||||
},
|
||||
):
|
||||
new_candidates_with_score = self._rerank.rank(
|
||||
new_candidates_with_score = await self._rerank.arank(
|
||||
new_candidates_with_score, query
|
||||
)
|
||||
return new_candidates_with_score
|
||||
|
@@ -1,9 +1,10 @@
|
||||
"""Rerank module for RAG retriever."""
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from dbgpt.core import Chunk
|
||||
from dbgpt.core import Chunk, RerankEmbeddings
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
@@ -39,6 +40,24 @@ class Ranker(ABC):
|
||||
List[Chunk]
|
||||
"""
|
||||
|
||||
async def arank(
|
||||
self, candidates_with_scores: List[Chunk], query: Optional[str] = None
|
||||
) -> List[Chunk]:
|
||||
"""Return top k chunks after ranker.
|
||||
|
||||
Rank algorithm implementation return topk documents by candidates
|
||||
similarity score
|
||||
|
||||
Args:
|
||||
candidates_with_scores: List[Tuple]
|
||||
query: Optional[str]
|
||||
Return:
|
||||
List[Chunk]
|
||||
"""
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, self.rank, candidates_with_scores, query
|
||||
)
|
||||
|
||||
def _filter(self, candidates_with_scores: List) -> List[Chunk]:
|
||||
"""Filter duplicate candidates documents."""
|
||||
candidates_with_scores = sorted(
|
||||
@@ -52,6 +71,18 @@ class Ranker(ABC):
|
||||
visited_docs.add(candidate_chunk.content)
|
||||
return new_candidates
|
||||
|
||||
def _rerank_with_scores(
|
||||
self, candidates_with_scores: List[Chunk], rank_scores: List[float]
|
||||
) -> List[Chunk]:
|
||||
"""Rerank candidates with scores."""
|
||||
for candidate, score in zip(candidates_with_scores, rank_scores):
|
||||
candidate.score = float(score)
|
||||
|
||||
new_candidates_with_scores = sorted(
|
||||
candidates_with_scores, key=lambda x: x.score, reverse=True
|
||||
)
|
||||
return new_candidates_with_scores
|
||||
|
||||
|
||||
@register_resource(
|
||||
_("Default Ranker"),
|
||||
@@ -225,3 +256,59 @@ class CrossEncoderRanker(Ranker):
|
||||
candidates_with_scores, key=lambda x: x.score, reverse=True
|
||||
)
|
||||
return new_candidates_with_scores[: self.topk]
|
||||
|
||||
|
||||
class RerankEmbeddingsRanker(Ranker):
|
||||
"""Rerank Embeddings Ranker."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rerank_embeddings: RerankEmbeddings,
|
||||
topk: int = 4,
|
||||
rank_fn: Optional[RANK_FUNC] = None,
|
||||
):
|
||||
"""Rerank Embeddings rank algorithm implementation."""
|
||||
self._model = rerank_embeddings
|
||||
super().__init__(topk, rank_fn)
|
||||
|
||||
def rank(
|
||||
self, candidates_with_scores: List[Chunk], query: Optional[str] = None
|
||||
) -> List[Chunk]:
|
||||
"""Rerank Embeddings rank algorithm implementation.
|
||||
|
||||
Args:
|
||||
candidates_with_scores: List[Chunk], candidates with scores
|
||||
query: Optional[str], query text
|
||||
Returns:
|
||||
List[Chunk], reranked candidates
|
||||
"""
|
||||
if not candidates_with_scores or not query:
|
||||
return candidates_with_scores
|
||||
|
||||
contents = [candidate.content for candidate in candidates_with_scores]
|
||||
rank_scores = self._model.predict(query, contents)
|
||||
new_candidates_with_scores = self._rerank_with_scores(
|
||||
candidates_with_scores, rank_scores
|
||||
)
|
||||
return new_candidates_with_scores[: self.topk]
|
||||
|
||||
async def arank(
|
||||
self, candidates_with_scores: List[Chunk], query: Optional[str] = None
|
||||
) -> List[Chunk]:
|
||||
"""Rerank Embeddings rank algorithm implementation.
|
||||
|
||||
Args:
|
||||
candidates_with_scores: List[Chunk], candidates with scores
|
||||
query: Optional[str], query text
|
||||
Returns:
|
||||
List[Chunk], reranked candidates
|
||||
"""
|
||||
if not candidates_with_scores or not query:
|
||||
return candidates_with_scores
|
||||
|
||||
contents = [candidate.content for candidate in candidates_with_scores]
|
||||
rank_scores = await self._model.apredict(query, contents)
|
||||
new_candidates_with_scores = self._rerank_with_scores(
|
||||
candidates_with_scores, rank_scores
|
||||
)
|
||||
return new_candidates_with_scores[: self.topk]
|
||||
|
Reference in New Issue
Block a user