feat: support reranker

Signed-off-by: Anhui-tqhuang <tianqiu.huang@enterprisedb.com>
Author: Anhui-tqhuang 2024-01-23 12:19:52 +08:00
Commit: 642b75b7e8 (parent: 087cb0b7b7)
GPG Key ID: 37B92F5DB83657C7 (no known key found for this signature in database)
7 changed files with 1862 additions and 1671 deletions

poetry.lock (generated): diff suppressed because it is too large.


@@ -0,0 +1,71 @@
from typing import List, Tuple

from FlagEmbedding import FlagReranker
from injector import inject, singleton
from llama_index.bridge.pydantic import Field
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.schema import NodeWithScore, QueryBundle

from private_gpt.paths import models_path
from private_gpt.settings.settings import Settings


@singleton
class RerankerComponent(BaseNodePostprocessor):
    """Reranker component.

    - top_n: Top N nodes to return.
    - cut_off: Cut off score for nodes.

    If the number of nodes with score > cut_off is <= top_n, return the top_n
    nodes sorted by score. Otherwise, return all nodes with score > cut_off.
    """

    reranker: FlagReranker = Field(description="Reranker class.")
    top_n: int = Field(description="Top N nodes to return.")
    cut_off: float = Field(description="Cut off score for nodes.")

    @inject
    def __init__(self, settings: Settings) -> None:
        if not settings.reranker.enabled:
            raise ValueError("Reranker component is not enabled.")

        path = models_path / "reranker"
        top_n = settings.reranker.top_n
        cut_off = settings.reranker.cut_off
        reranker = FlagReranker(
            model_name_or_path=path,
        )

        super().__init__(
            top_n=top_n,
            reranker=reranker,
            cut_off=cut_off,
        )

    @classmethod
    def class_name(cls) -> str:
        return "Reranker"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: QueryBundle | None = None,
    ) -> List[NodeWithScore]:
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")

        # Score every (query, node content) pair with the cross-encoder reranker.
        query_str = query_bundle.query_str
        sentence_pairs: List[Tuple[str, str]] = []
        for node in nodes:
            content = node.get_content()
            sentence_pairs.append((query_str, content))

        scores = self.reranker.compute_score(sentence_pairs)
        # compute_score returns a bare float when given a single pair.
        if not isinstance(scores, list):
            scores = [scores]
        for i, node in enumerate(nodes):
            node.score = scores[i]

        # Keep every node above the cut-off score; if fewer than top_n remain,
        # fall back to the top_n highest-scoring nodes.
        res = [node for node in nodes if (node.score or 0.0) > self.cut_off]
        if len(res) > self.top_n:
            return res

        return sorted(nodes, key=lambda x: x.score or 0.0, reverse=True)[
            : self.top_n
        ]
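
For reference, a minimal usage sketch of the component on its own, outside the chat engine. This is not part of the commit: it assumes the reranker model has already been downloaded to models/reranker by the setup script, that reranker.enabled is true in the active settings profile, and that settings() returns the loaded configuration; the example texts and scores are made up.

# Sketch (not part of the commit): rerank a few retrieved nodes directly.
from llama_index.schema import NodeWithScore, QueryBundle, TextNode

from private_gpt.components.reranker.reranker import RerankerComponent
from private_gpt.settings.settings import settings

reranker = RerankerComponent(settings=settings())

nodes = [
    NodeWithScore(node=TextNode(text="Paris is the capital of France."), score=0.20),
    NodeWithScore(node=TextNode(text="Bananas are rich in potassium."), score=0.85),
]

# postprocess_nodes() is the public wrapper around _postprocess_nodes().
reranked = reranker.postprocess_nodes(
    nodes, query_bundle=QueryBundle(query_str="What is the capital of France?")
)
for node in reranked:
    print(round(node.score or 0.0, 3), node.get_content())

With the defaults (cut_off 0.75, top_n 5), the cross-encoder scores overwrite the retrieval scores; if fewer than top_n nodes clear the cut-off, as is likely here, the fallback branch returns the top_n highest-scoring nodes, so both nodes come back ordered by their new scores.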


@@ -18,6 +18,7 @@ from pydantic import BaseModel
from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.reranker.reranker import RerankerComponent
from private_gpt.components.vector_store.vector_store_component import (
VectorStoreComponent,
)
@@ -99,6 +100,8 @@ class ChatService:
            embed_model=embedding_component.embedding_model,
            show_progress=True,
        )
        if settings.reranker.enabled:
            self.reranker_component = RerankerComponent(settings=settings)
        else:
            # Keep the attribute defined so _chat_engine can test it safely.
            self.reranker_component = None

    def _chat_engine(
        self,
@@ -113,16 +116,22 @@
                context_filter=context_filter,
                similarity_top_k=self.settings.rag.similarity_top_k,
            )
            node_postprocessors = [
                MetadataReplacementPostProcessor(target_metadata_key="window"),
                SimilarityPostprocessor(
                    similarity_cutoff=settings.rag.similarity_value
                ),
            ]
            if self.reranker_component:
                node_postprocessors.append(self.reranker_component)

            return ContextChatEngine.from_defaults(
                system_prompt=system_prompt,
                retriever=vector_index_retriever,
                llm=self.llm_component.llm,  # Takes no effect at the moment
                node_postprocessors=node_postprocessors,
            )
        else:
            return SimpleChatEngine.from_defaults(
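
A note on ordering that this wiring relies on: llama_index applies node_postprocessors in list order, so the reranker only rescores nodes that already passed the metadata replacement and the embedding-similarity cut-off. The helper below is an illustrative sketch of that chaining; apply_postprocessors is hypothetical and not the actual ContextChatEngine internals.

# Sketch (assumed, simplified): how a chain of node postprocessors is applied
# in order, mirroring the list built in _chat_engine above.
from typing import List

from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.schema import NodeWithScore, QueryBundle


def apply_postprocessors(
    nodes: List[NodeWithScore],
    node_postprocessors: List[BaseNodePostprocessor],
    query_bundle: QueryBundle,
) -> List[NodeWithScore]:
    # Each postprocessor receives the previous one's output, so list order matters:
    # metadata replacement -> similarity cut-off -> reranker.
    for postprocessor in node_postprocessors:
        nodes = postprocessor.postprocess_nodes(nodes, query_bundle=query_bundle)
    return nodes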


@@ -114,6 +114,25 @@ class NodeStoreSettings(BaseModel):
    database: Literal["simple", "postgres"]


class RerankerSettings(BaseModel):
    enabled: bool = Field(
        False,
        description="Flag indicating if reranker is enabled or not",
    )
    hf_model_name: str = Field(
        "BAAI/bge-reranker-large",
        description="Name of the HuggingFace model to use for reranking",
    )
    top_n: int = Field(
        5,
        description="Top N nodes to return.",
    )
    cut_off: float = Field(
        0.75,
        description="Cut off score for nodes.",
    )


class LlamaCPPSettings(BaseModel):
    llm_hf_repo_id: str
    llm_hf_model_file: str
@@ -391,6 +410,7 @@ class Settings(BaseModel):
    vectorstore: VectorstoreSettings
    nodestore: NodeStoreSettings
    rag: RagSettings
    reranker: RerankerSettings
    qdrant: QdrantSettings | None = None
    postgres: PostgresSettings | None = None


@@ -27,6 +27,17 @@ snapshot_download(
)
print("Embedding model downloaded!")

if settings().reranker.enabled:
    # Download Reranker model
    reranker_path = models_path / "reranker"
    print(f"Downloading reranker {settings().reranker.hf_model_name}")
    snapshot_download(
        repo_id=settings().reranker.hf_model_name,
        cache_dir=models_cache_path,
        local_dir=reranker_path,
    )
    print("Reranker model downloaded!")

# Download LLM and create a symlink to the model file
print(f"Downloading LLM {settings().llamacpp.llm_hf_model_file}")
hf_hub_download(


@@ -57,6 +57,12 @@ llamacpp:
  top_p: 1.0 # Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)
  repeat_penalty: 1.1 # Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)

reranker:
  enabled: true
  hf_model_name: BAAI/bge-reranker-large
  top_n: 5
  cut_off: 0.75

embedding:
  # Should be matching the value above in most cases
  mode: huggingface
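
These YAML keys map one-to-one onto the new RerankerSettings fields. A quick sketch, not part of the commit, for checking the active values at runtime (it assumes the loaded settings profile includes the block above):

# Sketch: confirm the reranker settings were picked up.
from private_gpt.settings.settings import settings

reranker_settings = settings().reranker
print(reranker_settings.enabled)        # True
print(reranker_settings.hf_model_name)  # BAAI/bge-reranker-large
print(reranker_settings.top_n)          # 5
print(reranker_settings.cut_off)        # 0.75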