mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-08-20 08:23:12 +00:00
fix: reformat
This commit is contained in:
parent
f4c58ceb0b
commit
b652b2ddbc
@ -7,6 +7,7 @@ from FlagEmbedding import FlagReranker
|
|||||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||||
from private_gpt.settings.settings import Settings
|
from private_gpt.settings.settings import Settings
|
||||||
|
|
||||||
|
|
||||||
@singleton
|
@singleton
|
||||||
class RerankerComponent(BaseNodePostprocessor):
|
class RerankerComponent(BaseNodePostprocessor):
|
||||||
"""
|
"""
|
||||||
@ -17,6 +18,7 @@ class RerankerComponent(BaseNodePostprocessor):
|
|||||||
If the number of nodes with score > cut_off is <= top_n, then return top_n nodes.
|
If the number of nodes with score > cut_off is <= top_n, then return top_n nodes.
|
||||||
Otherwise, return all nodes with score > cut_off.
|
Otherwise, return all nodes with score > cut_off.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
reranker: FlagReranker = Field(description="Reranker class.")
|
reranker: FlagReranker = Field(description="Reranker class.")
|
||||||
top_n: int = Field(description="Top N nodes to return.")
|
top_n: int = Field(description="Top N nodes to return.")
|
||||||
cut_off: float = Field(description="Cut off score for nodes.")
|
cut_off: float = Field(description="Cut off score for nodes.")
|
||||||
@ -66,6 +68,4 @@ class RerankerComponent(BaseNodePostprocessor):
|
|||||||
if len(res) > self.top_n:
|
if len(res) > self.top_n:
|
||||||
return res
|
return res
|
||||||
|
|
||||||
return sorted(nodes, key=lambda x: x.score or 0.0, reverse=True)[
|
return sorted(nodes, key=lambda x: x.score or 0.0, reverse=True)[: self.top_n]
|
||||||
: self.top_n
|
|
||||||
]
|
|
||||||
|
@ -121,7 +121,7 @@ class RerankerSettings(BaseModel):
|
|||||||
)
|
)
|
||||||
hf_model_name: str = Field(
|
hf_model_name: str = Field(
|
||||||
"BAAI/bge-reranker-large",
|
"BAAI/bge-reranker-large",
|
||||||
description="Name of the HuggingFace model to use for reranking"
|
description="Name of the HuggingFace model to use for reranking",
|
||||||
)
|
)
|
||||||
top_n: int = Field(
|
top_n: int = Field(
|
||||||
5,
|
5,
|
||||||
|
Loading…
Reference in New Issue
Block a user