mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-06-29 16:58:00 +00:00
feat: support reranker
Signed-off-by: Anhui-tqhuang <tianqiu.huang@enterprisedb.com>
This commit is contained in:
parent
087cb0b7b7
commit
642b75b7e8
3404
poetry.lock
generated
3404
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
0
private_gpt/components/reranker/__init__.py
Normal file
0
private_gpt/components/reranker/__init__.py
Normal file
71
private_gpt/components/reranker/reranker.py
Normal file
71
private_gpt/components/reranker/reranker.py
Normal file
@ -0,0 +1,71 @@
|
||||
from typing import List, Tuple

from FlagEmbedding import FlagReranker
from injector import inject, singleton
from llama_index.bridge.pydantic import Field
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.schema import NodeWithScore, QueryBundle

from private_gpt.paths import models_path
from private_gpt.settings.settings import Settings


@singleton
class RerankerComponent(BaseNodePostprocessor):
    """Rerank retrieved nodes with a FlagEmbedding cross-encoder.

    - top_n: Top N nodes to return.
    - cut_off: Cut off score for nodes.

    If the number of nodes with score > cut_off is <= top_n, then return
    the top_n highest-scored nodes. Otherwise, return all nodes with
    score > cut_off.
    """

    reranker: FlagReranker = Field(description="Reranker class.")
    top_n: int = Field(description="Top N nodes to return.")
    cut_off: float = Field(description="Cut off score for nodes.")

    @inject
    def __init__(self, settings: Settings) -> None:
        """Build the component from application settings.

        Raises:
            ValueError: if the reranker is disabled in settings — callers
                are expected to check ``settings.reranker.enabled`` first.
        """
        if settings.reranker.enabled is False:
            raise ValueError("Reranker component is not enabled.")

        # Model weights are expected to have been downloaded into
        # ``models_path / "reranker"`` by the setup script.
        path = models_path / "reranker"
        top_n = settings.reranker.top_n
        cut_off = settings.reranker.cut_off
        reranker = FlagReranker(
            model_name_or_path=path,
        )

        super().__init__(
            top_n=top_n,
            reranker=reranker,
            cut_off=cut_off,
        )

    @classmethod
    def class_name(cls) -> str:
        return "Reranker"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: QueryBundle | None = None,
    ) -> List[NodeWithScore]:
        """Score each node against the query and filter/trim the result.

        Overwrites ``node.score`` with the cross-encoder score.
        """
        # BUG FIX: the original did ``return ValueError(...)``, which
        # silently returned an exception instance instead of raising it.
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")

        # Nothing to rerank; also avoids calling compute_score([]).
        if not nodes:
            return []

        query_str = query_bundle.query_str
        # Use tuples to match the declared List[Tuple[str, str]] type.
        sentence_pairs: List[Tuple[str, str]] = [
            (query_str, node.get_content()) for node in nodes
        ]

        scores = self.reranker.compute_score(sentence_pairs)
        # FlagReranker.compute_score returns a bare float for a single
        # pair; normalize to a list so the zip below is uniform.
        if not isinstance(scores, list):
            scores = [scores]
        for node, score in zip(nodes, scores):
            node.score = score

        # Keep every node above the cut-off when that set is larger than
        # top_n; otherwise fall back to the top_n best-scored nodes.
        res = [node for node in nodes if node.score > self.cut_off]
        if len(res) > self.top_n:
            return res

        return sorted(nodes, key=lambda x: x.score or 0.0, reverse=True)[
            : self.top_n
        ]
|
@ -18,6 +18,7 @@ from pydantic import BaseModel
|
||||
from private_gpt.components.embedding.embedding_component import EmbeddingComponent
|
||||
from private_gpt.components.llm.llm_component import LLMComponent
|
||||
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
|
||||
from private_gpt.components.reranker.reranker import RerankerComponent
|
||||
from private_gpt.components.vector_store.vector_store_component import (
|
||||
VectorStoreComponent,
|
||||
)
|
||||
@ -99,6 +100,8 @@ class ChatService:
|
||||
embed_model=embedding_component.embedding_model,
|
||||
show_progress=True,
|
||||
)
|
||||
if settings.reranker.enabled:
|
||||
self.reranker_component = RerankerComponent(settings=settings)
|
||||
|
||||
def _chat_engine(
|
||||
self,
|
||||
@ -113,16 +116,22 @@ class ChatService:
|
||||
context_filter=context_filter,
|
||||
similarity_top_k=self.settings.rag.similarity_top_k,
|
||||
)
|
||||
|
||||
node_postprocessors = [
|
||||
MetadataReplacementPostProcessor(target_metadata_key="window"),
|
||||
SimilarityPostprocessor(
|
||||
similarity_cutoff=settings.rag.similarity_value
|
||||
),
|
||||
]
|
||||
|
||||
if self.reranker_component:
|
||||
node_postprocessors.append(self.reranker_component)
|
||||
|
||||
return ContextChatEngine.from_defaults(
|
||||
system_prompt=system_prompt,
|
||||
retriever=vector_index_retriever,
|
||||
llm=self.llm_component.llm, # Takes no effect at the moment
|
||||
node_postprocessors=[
|
||||
MetadataReplacementPostProcessor(target_metadata_key="window"),
|
||||
SimilarityPostprocessor(
|
||||
similarity_cutoff=settings.rag.similarity_value
|
||||
),
|
||||
],
|
||||
node_postprocessors=node_postprocessors,
|
||||
)
|
||||
else:
|
||||
return SimpleChatEngine.from_defaults(
|
||||
|
@ -114,6 +114,25 @@ class NodeStoreSettings(BaseModel):
|
||||
database: Literal["simple", "postgres"]
|
||||
|
||||
|
||||
class RerankerSettings(BaseModel):
    """Settings that control the optional reranking step."""

    # Whether the reranker component should be instantiated at all.
    enabled: bool = Field(
        default=False,
        description="Flag indicating if reranker is enabled or not",
    )
    # HuggingFace repo id of the cross-encoder used for reranking.
    hf_model_name: str = Field(
        default="BAAI/bge-reranker-large",
        description="Name of the HuggingFace model to use for reranking",
    )
    # How many nodes to keep after reranking.
    top_n: int = Field(
        default=5,
        description="Top N nodes to return.",
    )
    # Minimum reranker score a node needs to survive the cut.
    cut_off: float = Field(
        default=0.75,
        description="Cut off score for nodes.",
    )
|
||||
|
||||
|
||||
class LlamaCPPSettings(BaseModel):
|
||||
llm_hf_repo_id: str
|
||||
llm_hf_model_file: str
|
||||
@ -391,6 +410,7 @@ class Settings(BaseModel):
|
||||
vectorstore: VectorstoreSettings
|
||||
nodestore: NodeStoreSettings
|
||||
rag: RagSettings
|
||||
reranker: RerankerSettings
|
||||
qdrant: QdrantSettings | None = None
|
||||
postgres: PostgresSettings | None = None
|
||||
|
||||
|
@ -27,6 +27,17 @@ snapshot_download(
|
||||
)
|
||||
print("Embedding model downloaded!")
|
||||
|
||||
if settings().reranker.enabled:
|
||||
# Download Reranker model
|
||||
reranker_path = models_path / "reranker"
|
||||
print(f"Downloading reranker {settings().reranker.hf_model_name}")
|
||||
snapshot_download(
|
||||
repo_id=settings().reranker.hf_model_name,
|
||||
cache_dir=models_cache_path,
|
||||
local_dir=reranker_path,
|
||||
)
|
||||
print("Reranker model downloaded!")
|
||||
|
||||
# Download LLM and create a symlink to the model file
|
||||
print(f"Downloading LLM {settings().llamacpp.llm_hf_model_file}")
|
||||
hf_hub_download(
|
||||
|
@ -57,6 +57,12 @@ llamacpp:
|
||||
top_p: 1.0 # Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)
|
||||
repeat_penalty: 1.1 # Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)
|
||||
|
||||
reranker:
|
||||
enabled: true
|
||||
hf_model_name: BAAI/bge-reranker-large
|
||||
top_n: 5
|
||||
cut_off: 0.75
|
||||
|
||||
embedding:
|
||||
# Should be matching the value above in most cases
|
||||
mode: huggingface
|
||||
|
Loading…
Reference in New Issue
Block a user