feat(model): Support deploy rerank model (#1522)

Fangyin Cheng
2024-05-16 14:50:16 +08:00
committed by GitHub
parent 559affe87d
commit 593e974405
29 changed files with 814 additions and 75 deletions

View File

@@ -15,6 +15,7 @@ from .embeddings import ( # noqa: F401
    OllamaEmbeddings,
    OpenAPIEmbeddings,
)
from .rerank import CrossEncoderRerankEmbeddings, OpenAPIRerankEmbeddings  # noqa: F401

__ALL__ = [
    "Embeddings",
@@ -28,4 +29,6 @@ __ALL__ = [
    "DefaultEmbeddingFactory",
    "EmbeddingFactory",
    "WrappedEmbeddingFactory",
    "CrossEncoderRerankEmbeddings",
    "OpenAPIRerankEmbeddings",
]

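Usage note (not part of the diff): a minimal sketch of the newly exported cross-encoder reranker, assuming the package resolves to dbgpt.rag.embeddings and that sentence-transformers is installed:

from dbgpt.rag.embeddings import CrossEncoderRerankEmbeddings

# Loads the BAAI/bge-reranker-base cross-encoder via sentence-transformers.
reranker = CrossEncoderRerankEmbeddings(
    model_name="BAAI/bge-reranker-base",
    model_kwargs={},  # forwarded to sentence_transformers.CrossEncoder
)
scores = reranker.predict(
    query="What is DB-GPT?",
    candidates=["DB-GPT is an AI-native data app framework.", "Bananas are yellow."],
)
print(scores)  # one relevance score per candidate; higher means more relevant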
View File

@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
from typing import Any, List, Optional, Type
from dbgpt.component import BaseComponent, SystemApp
-from dbgpt.core import Embeddings
+from dbgpt.core import Embeddings, RerankEmbeddings
from dbgpt.core.awel import DAGVar
from dbgpt.core.awel.flow import ResourceCategory, register_resource
from dbgpt.util.i18n_utils import _
@@ -34,6 +34,26 @@ class EmbeddingFactory(BaseComponent, ABC):
"""
class RerankEmbeddingFactory(BaseComponent, ABC):
    """Class for RerankEmbeddingFactory."""

    name = "rerank_embedding_factory"

    @abstractmethod
    def create(
        self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
    ) -> RerankEmbeddings:
        """Create an embedding instance.

        Args:
            model_name (str): The model name.
            embedding_cls (Type): The embedding class.

        Returns:
            RerankEmbeddings: The embedding instance.
        """


class DefaultEmbeddingFactory(EmbeddingFactory):
    """The default embedding factory."""

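Illustration (not from this commit): a concrete factory only has to return a configured RerankEmbeddings object. The class name below is hypothetical, and the module path dbgpt.rag.embeddings.embedding_factory is an assumption about where this factory lives:

from typing import Optional, Type

from dbgpt.core import RerankEmbeddings
from dbgpt.rag.embeddings import OpenAPIRerankEmbeddings
from dbgpt.rag.embeddings.embedding_factory import RerankEmbeddingFactory


class RemoteRerankEmbeddingFactory(RerankEmbeddingFactory):
    """Hypothetical factory that always builds an OpenAPI-based reranker."""

    def init_app(self, system_app):
        """No app-level wiring needed for this sketch."""

    def create(
        self, model_name: Optional[str] = None, embedding_cls: Optional[Type] = None
    ) -> RerankEmbeddings:
        # embedding_cls is ignored here; a real factory could use it to pick a class.
        return OpenAPIRerankEmbeddings(model_name=model_name or "bge-reranker-base")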
View File

@@ -1,6 +1,5 @@
"""Embedding implementations."""
from typing import Any, Dict, List, Optional
import aiohttp

View File

@@ -0,0 +1,131 @@
"""Re-rank embeddings."""
from typing import Any, Dict, List, Optional, cast

import aiohttp
import requests

from dbgpt._private.pydantic import EXTRA_FORBID, BaseModel, ConfigDict, Field
from dbgpt.core import RerankEmbeddings


class CrossEncoderRerankEmbeddings(BaseModel, RerankEmbeddings):
    """CrossEncoder Rerank Embeddings."""

    model_config = ConfigDict(extra=EXTRA_FORBID, protected_namespaces=())

    client: Any  #: :meta private:
    model_name: str = "BAAI/bge-reranker-base"
    """Model name to use."""
    max_length: Optional[int] = None
    """Max length for input sequences. Longer sequences will be truncated. If None, max
    length of the model will be used."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Keyword arguments to pass to the model."""
    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformer."""
        try:
            from sentence_transformers import CrossEncoder
        except ImportError:
            raise ImportError(
                "please `pip install sentence-transformers`",
            )

        kwargs["client"] = CrossEncoder(
            kwargs.get("model_name"),
            max_length=kwargs.get("max_length"),
            **kwargs.get("model_kwargs"),
        )
        super().__init__(**kwargs)

    def predict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates.

        Args:
            query: The query text.
            candidates: The list of candidate texts.

        Returns:
            List[float]: The rank scores of the candidates.
        """
        from sentence_transformers import CrossEncoder

        query_content_pairs = [[query, candidate] for candidate in candidates]
        _model = cast(CrossEncoder, self.client)
        rank_scores = _model.predict(sentences=query_content_pairs)
        return rank_scores.tolist()
class OpenAPIRerankEmbeddings(BaseModel, RerankEmbeddings):
    """OpenAPI Rerank Embeddings."""

    model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=())

    api_url: str = Field(
        default="http://localhost:8100/v1/beta/relevance",
        description="The URL of the embeddings API.",
    )
    api_key: Optional[str] = Field(
        default=None, description="The API key for the embeddings API."
    )
    model_name: str = Field(
        default="bge-reranker-base", description="The name of the model to use."
    )
    timeout: int = Field(
        default=60, description="The timeout for the request in seconds."
    )

    session: Optional[requests.Session] = None

    def __init__(self, **kwargs):
        """Initialize the OpenAPIEmbeddings."""
        try:
            import requests
        except ImportError:
            raise ValueError(
                "The requests python package is not installed. "
                "Please install it with `pip install requests`"
            )

        if "session" not in kwargs:  # noqa: SIM401
            session = requests.Session()
        else:
            session = kwargs["session"]
        api_key = kwargs.get("api_key")
        if api_key:
            session.headers.update({"Authorization": f"Bearer {api_key}"})
        kwargs["session"] = session
        super().__init__(**kwargs)
    def predict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates.

        Args:
            query: The query text.
            candidates: The list of candidate texts.

        Returns:
            List[float]: The rank scores of the candidates.
        """
        if not candidates:
            return []
        data = {"model": self.model_name, "query": query, "documents": candidates}
        response = self.session.post(  # type: ignore
            self.api_url, json=data, timeout=self.timeout
        )
        response.raise_for_status()
        return response.json()["data"]

    async def apredict(self, query: str, candidates: List[str]) -> List[float]:
        """Predict the rank scores of the candidates asynchronously."""
        headers = {"Authorization": f"Bearer {self.api_key}"}
        async with aiohttp.ClientSession(
            headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
        ) as session:
            data = {"model": self.model_name, "query": query, "documents": candidates}
            async with session.post(self.api_url, json=data) as resp:
                resp.raise_for_status()
                response_data = await resp.json()
                if "data" not in response_data:
                    raise RuntimeError(response_data["detail"])
                return response_data["data"]

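A hedged usage sketch for the OpenAPI-based reranker above. It assumes a rerank model is already being served behind the class's default relevance endpoint (http://localhost:8100/v1/beta/relevance); the query and documents are made up:

import asyncio

from dbgpt.rag.embeddings import OpenAPIRerankEmbeddings

# Relies on the class defaults (api_url, model_name); pass api_key if the service needs one.
reranker = OpenAPIRerankEmbeddings()
docs = ["DB-GPT supports deploying rerank models.", "An unrelated paragraph about cooking."]

# Synchronous path (requests.Session under the hood).
print(reranker.predict("How do I deploy a rerank model?", docs))

# Asynchronous path (aiohttp under the hood).
print(asyncio.run(reranker.apredict("How do I deploy a rerank model?", docs)))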
View File

@@ -1,4 +1,5 @@
"""Embedding retriever."""
from functools import reduce
from typing import Any, Dict, List, Optional, cast
@@ -207,7 +208,7 @@ class EmbeddingRetriever(BaseRetriever):
                "rerank_cls": self._rerank.__class__.__name__,
            },
        ):
-           new_candidates_with_score = self._rerank.rank(
+           new_candidates_with_score = await self._rerank.arank(
                new_candidates_with_score, query
            )
            return new_candidates_with_score

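Wiring sketch (not part of the diff): with this change, any Ranker handed to the retriever is awaited through arank. The constructor arguments below (index_store, top_k, rerank) and the aretrieve_with_scores call are assumptions about the surrounding EmbeddingRetriever API, which is not shown in this hunk:

from dbgpt.rag.embeddings import CrossEncoderRerankEmbeddings
from dbgpt.rag.retriever import EmbeddingRetriever
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker


async def retrieve_reranked(index_store, query: str):
    ranker = RerankEmbeddingsRanker(
        rerank_embeddings=CrossEncoderRerankEmbeddings(
            model_name="BAAI/bge-reranker-base", model_kwargs={}
        ),
        topk=5,
    )
    retriever = EmbeddingRetriever(index_store=index_store, top_k=20, rerank=ranker)
    # The retriever now awaits ranker.arank(...) internally, per the hunk above.
    return await retriever.aretrieve_with_scores(query, score_threshold=0.3)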
View File

@@ -1,9 +1,10 @@
"""Rerank module for RAG retriever."""
import asyncio
from abc import ABC, abstractmethod
from typing import Callable, List, Optional
-from dbgpt.core import Chunk
+from dbgpt.core import Chunk, RerankEmbeddings
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
from dbgpt.util.i18n_utils import _
@@ -39,6 +40,24 @@ class Ranker(ABC):
            List[Chunk]
        """

    async def arank(
        self, candidates_with_scores: List[Chunk], query: Optional[str] = None
    ) -> List[Chunk]:
        """Return top k chunks after ranker.

        Rank algorithm implementation return topk documents by candidates
        similarity score

        Args:
            candidates_with_scores: List[Tuple]
            query: Optional[str]

        Return:
            List[Chunk]
        """
        return await asyncio.get_running_loop().run_in_executor(
            None, self.rank, candidates_with_scores, query
        )

    def _filter(self, candidates_with_scores: List) -> List[Chunk]:
        """Filter duplicate candidates documents."""
        candidates_with_scores = sorted(
@@ -52,6 +71,18 @@ class Ranker(ABC):
                visited_docs.add(candidate_chunk.content)
        return new_candidates

    def _rerank_with_scores(
        self, candidates_with_scores: List[Chunk], rank_scores: List[float]
    ) -> List[Chunk]:
        """Rerank candidates with scores."""
        for candidate, score in zip(candidates_with_scores, rank_scores):
            candidate.score = float(score)

        new_candidates_with_scores = sorted(
            candidates_with_scores, key=lambda x: x.score, reverse=True
        )
        return new_candidates_with_scores


@register_resource(
    _("Default Ranker"),
@@ -225,3 +256,59 @@ class CrossEncoderRanker(Ranker):
            candidates_with_scores, key=lambda x: x.score, reverse=True
        )
        return new_candidates_with_scores[: self.topk]

class RerankEmbeddingsRanker(Ranker):
    """Rerank Embeddings Ranker."""

    def __init__(
        self,
        rerank_embeddings: RerankEmbeddings,
        topk: int = 4,
        rank_fn: Optional[RANK_FUNC] = None,
    ):
        """Rerank Embeddings rank algorithm implementation."""
        self._model = rerank_embeddings
        super().__init__(topk, rank_fn)

    def rank(
        self, candidates_with_scores: List[Chunk], query: Optional[str] = None
    ) -> List[Chunk]:
        """Rerank Embeddings rank algorithm implementation.

        Args:
            candidates_with_scores: List[Chunk], candidates with scores
            query: Optional[str], query text
        Returns:
            List[Chunk], reranked candidates
        """
        if not candidates_with_scores or not query:
            return candidates_with_scores

        contents = [candidate.content for candidate in candidates_with_scores]
        rank_scores = self._model.predict(query, contents)
        new_candidates_with_scores = self._rerank_with_scores(
            candidates_with_scores, rank_scores
        )
        return new_candidates_with_scores[: self.topk]

    async def arank(
        self, candidates_with_scores: List[Chunk], query: Optional[str] = None
    ) -> List[Chunk]:
        """Rerank Embeddings rank algorithm implementation.

        Args:
            candidates_with_scores: List[Chunk], candidates with scores
            query: Optional[str], query text
        Returns:
            List[Chunk], reranked candidates
        """
        if not candidates_with_scores or not query:
            return candidates_with_scores

        contents = [candidate.content for candidate in candidates_with_scores]
        rank_scores = await self._model.apredict(query, contents)
        new_candidates_with_scores = self._rerank_with_scores(
            candidates_with_scores, rank_scores
        )
        return new_candidates_with_scores[: self.topk]
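
A small end-to-end sketch of the new ranker over in-memory Chunks (the texts and scores are made up, and a rerank service must be listening on the default endpoint for real scores to come back):

import asyncio

from dbgpt.core import Chunk
from dbgpt.rag.embeddings import OpenAPIRerankEmbeddings
from dbgpt.rag.retriever.rerank import RerankEmbeddingsRanker

chunks = [
    Chunk(content="DB-GPT can deploy a rerank model as a model worker.", score=0.42),
    Chunk(content="A completely unrelated paragraph.", score=0.40),
]

# Uses the class default api_url (http://localhost:8100/v1/beta/relevance).
ranker = RerankEmbeddingsRanker(OpenAPIRerankEmbeddings(), topk=1)

top = asyncio.run(ranker.arank(chunks, query="how to deploy a rerank model"))
print([(c.content, c.score) for c in top])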