feat: support reranker

Signed-off-by: Anhui-tqhuang <tianqiu.huang@enterprisedb.com>

parent 087cb0b7b7
commit 642b75b7e8
poetry.lock (generated, 3404 lines changed)
File diff suppressed because it is too large.
private_gpt/components/reranker/__init__.py (new file, empty)
private_gpt/components/reranker/reranker.py (new file, 71 lines)
@@ -0,0 +1,71 @@
from typing import List, Tuple

from FlagEmbedding import FlagReranker
from injector import inject, singleton
from llama_index.bridge.pydantic import Field
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.schema import NodeWithScore, QueryBundle

from private_gpt.paths import models_path
from private_gpt.settings.settings import Settings


@singleton
class RerankerComponent(BaseNodePostprocessor):
    """Reranker component.

    - top_n: Top N nodes to return.
    - cut_off: Cut-off score for nodes.

    If the number of nodes with score > cut_off is <= top_n, return the
    top_n highest-scoring nodes; otherwise, return all nodes with
    score > cut_off.
    """

    reranker: FlagReranker = Field(description="Reranker class.")
    top_n: int = Field(description="Top N nodes to return.")
    cut_off: float = Field(description="Cut-off score for nodes.")

    @inject
    def __init__(self, settings: Settings) -> None:
        if not settings.reranker.enabled:
            raise ValueError("Reranker component is not enabled.")

        path = models_path / "reranker"
        top_n = settings.reranker.top_n
        cut_off = settings.reranker.cut_off
        reranker = FlagReranker(
            model_name_or_path=path,
        )

        super().__init__(
            top_n=top_n,
            reranker=reranker,
            cut_off=cut_off,
        )

    @classmethod
    def class_name(cls) -> str:
        return "Reranker"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: QueryBundle | None = None,
    ) -> List[NodeWithScore]:
        if query_bundle is None:
            raise ValueError("Query bundle must be provided.")

        # Score every (query, passage) pair with the cross-encoder reranker.
        query_str = query_bundle.query_str
        sentence_pairs: List[Tuple[str, str]] = []
        for node in nodes:
            content = node.get_content()
            sentence_pairs.append((query_str, content))

        scores = self.reranker.compute_score(sentence_pairs)
        if not isinstance(scores, list):  # a single pair yields a bare float
            scores = [scores]
        for i, node in enumerate(nodes):
            node.score = scores[i]

        # Keep all nodes above the cut-off when more than top_n qualify...
        res = [node for node in nodes if node.score > self.cut_off]
        if len(res) > self.top_n:
            return res

        # ...otherwise fall back to the top_n highest-scoring nodes overall.
        return sorted(nodes, key=lambda x: x.score or 0.0, reverse=True)[: self.top_n]
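To make the top_n/cut_off rule above concrete, here is a minimal, dependency-free sketch of the same selection logic; the passages and scores are invented for illustration, and top_n is set to 3 rather than the default 5:

# Toy walk-through of the selection rule in _postprocess_nodes.
scored = [("a", 0.91), ("b", 0.82), ("c", 0.40), ("d", 0.10)]
cut_off, top_n = 0.75, 3

above = [pair for pair in scored if pair[1] > cut_off]
if len(above) > top_n:
    result = above  # more than top_n strong hits: keep every one of them
else:
    # too few pass the cut-off, so fall back to the top_n best overall
    result = sorted(scored, key=lambda pair: pair[1], reverse=True)[:top_n]

print(result)  # [('a', 0.91), ('b', 0.82), ('c', 0.40)]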

private_gpt/server/chat/chat_service.py

@@ -18,6 +18,7 @@ from pydantic import BaseModel
 from private_gpt.components.embedding.embedding_component import EmbeddingComponent
 from private_gpt.components.llm.llm_component import LLMComponent
 from private_gpt.components.node_store.node_store_component import NodeStoreComponent
+from private_gpt.components.reranker.reranker import RerankerComponent
 from private_gpt.components.vector_store.vector_store_component import (
     VectorStoreComponent,
 )
@@ -99,6 +100,10 @@ class ChatService:
             embed_model=embedding_component.embedding_model,
             show_progress=True,
         )
+        if settings.reranker.enabled:
+            self.reranker_component = RerankerComponent(settings=settings)
+        else:
+            self.reranker_component = None

     def _chat_engine(
         self,
@@ -113,16 +118,22 @@ class ChatService:
                 context_filter=context_filter,
                 similarity_top_k=self.settings.rag.similarity_top_k,
             )
-            return ContextChatEngine.from_defaults(
-                system_prompt=system_prompt,
-                retriever=vector_index_retriever,
-                llm=self.llm_component.llm,  # Takes no effect at the moment
-                node_postprocessors=[
-                    MetadataReplacementPostProcessor(target_metadata_key="window"),
-                    SimilarityPostprocessor(
-                        similarity_cutoff=settings.rag.similarity_value
-                    ),
-                ],
+            node_postprocessors = [
+                MetadataReplacementPostProcessor(target_metadata_key="window"),
+                SimilarityPostprocessor(
+                    similarity_cutoff=settings.rag.similarity_value
+                ),
+            ]
+
+            if self.reranker_component is not None:
+                node_postprocessors.append(self.reranker_component)
+
+            return ContextChatEngine.from_defaults(
+                system_prompt=system_prompt,
+                retriever=vector_index_retriever,
+                llm=self.llm_component.llm,  # Takes no effect at the moment
+                node_postprocessors=node_postprocessors,
             )
         else:
             return SimpleChatEngine.from_defaults(
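One point worth noting about this change: LlamaIndex applies node_postprocessors in list order, so the reranker appended here runs last, on nodes that have already passed the metadata replacement and the similarity cut-off. A minimal sketch of that chaining, assuming the legacy llama_index BaseNodePostprocessor API used elsewhere in this commit (the helper name is illustrative, not from the commit):

from typing import List, Sequence

from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.schema import NodeWithScore, QueryBundle


def apply_postprocessors(
    postprocessors: Sequence[BaseNodePostprocessor],
    nodes: List[NodeWithScore],
    query_bundle: QueryBundle,
) -> List[NodeWithScore]:
    # Each postprocessor consumes the previous one's output, so list order
    # defines the pipeline; the reranker appended above therefore runs last.
    for postprocessor in postprocessors:
        nodes = postprocessor.postprocess_nodes(nodes, query_bundle=query_bundle)
    return nodes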

private_gpt/settings/settings.py

@@ -114,6 +114,25 @@ class NodeStoreSettings(BaseModel):
     database: Literal["simple", "postgres"]


+class RerankerSettings(BaseModel):
+    enabled: bool = Field(
+        False,
+        description="Flag indicating if reranker is enabled or not",
+    )
+    hf_model_name: str = Field(
+        "BAAI/bge-reranker-large",
+        description="Name of the HuggingFace model to use for reranking",
+    )
+    top_n: int = Field(
+        5,
+        description="Top N nodes to return.",
+    )
+    cut_off: float = Field(
+        0.75,
+        description="Cut off score for nodes.",
+    )
+
+
 class LlamaCPPSettings(BaseModel):
     llm_hf_repo_id: str
     llm_hf_model_file: str

@@ -391,6 +410,7 @@ class Settings(BaseModel):
     vectorstore: VectorstoreSettings
     nodestore: NodeStoreSettings
     rag: RagSettings
+    reranker: RerankerSettings
     qdrant: QdrantSettings | None = None
     postgres: PostgresSettings | None = None

scripts/setup

@@ -27,6 +27,17 @@ snapshot_download(
 )
 print("Embedding model downloaded!")

+if settings().reranker.enabled:
+    # Download Reranker model
+    reranker_path = models_path / "reranker"
+    print(f"Downloading reranker {settings().reranker.hf_model_name}")
+    snapshot_download(
+        repo_id=settings().reranker.hf_model_name,
+        cache_dir=models_cache_path,
+        local_dir=reranker_path,
+    )
+    print("Reranker model downloaded!")
+
 # Download LLM and create a symlink to the model file
 print(f"Downloading LLM {settings().llamacpp.llm_hf_model_file}")
 hf_hub_download(
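After scripts/setup runs with the reranker enabled, a quick smoke test can confirm the downloaded model actually loads; this is a sketch assuming the commit's models/reranker layout and the FlagEmbedding API imported above:

# Load the locally downloaded reranker and score one query/passage pair.
from FlagEmbedding import FlagReranker

reranker = FlagReranker("models/reranker")  # local_dir created by scripts/setup
score = reranker.compute_score(
    ["what does a reranker do?", "A reranker re-scores retrieved passages."]
)
print(score)  # a single pair yields one float; higher means more relevant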

settings.yaml

@@ -57,6 +57,12 @@ llamacpp:
   top_p: 1.0 # Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)
   repeat_penalty: 1.1 # Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)

+reranker:
+  enabled: true
+  hf_model_name: BAAI/bge-reranker-large
+  top_n: 5
+  cut_off: 0.75
+
 embedding:
   # Should be matching the value above in most cases
   mode: huggingface