fix: reformat

This commit is contained in:
Anhui-tqhuang 2024-02-20 22:31:09 +08:00
parent f4c58ceb0b
commit b652b2ddbc
No known key found for this signature in database
GPG Key ID: 37B92F5DB83657C7
2 changed files with 4 additions and 4 deletions

View File

@ -7,6 +7,7 @@ from FlagEmbedding import FlagReranker
from llama_index.postprocessor.types import BaseNodePostprocessor from llama_index.postprocessor.types import BaseNodePostprocessor
from private_gpt.settings.settings import Settings from private_gpt.settings.settings import Settings
@singleton @singleton
class RerankerComponent(BaseNodePostprocessor): class RerankerComponent(BaseNodePostprocessor):
""" """
@ -17,6 +18,7 @@ class RerankerComponent(BaseNodePostprocessor):
If the number of nodes with score > cut_off is <= top_n, then return top_n nodes. If the number of nodes with score > cut_off is <= top_n, then return top_n nodes.
Otherwise, return all nodes with score > cut_off. Otherwise, return all nodes with score > cut_off.
""" """
reranker: FlagReranker = Field(description="Reranker class.") reranker: FlagReranker = Field(description="Reranker class.")
top_n: int = Field(description="Top N nodes to return.") top_n: int = Field(description="Top N nodes to return.")
cut_off: float = Field(description="Cut off score for nodes.") cut_off: float = Field(description="Cut off score for nodes.")
@ -66,6 +68,4 @@ class RerankerComponent(BaseNodePostprocessor):
if len(res) > self.top_n: if len(res) > self.top_n:
return res return res
return sorted(nodes, key=lambda x: x.score or 0.0, reverse=True)[ return sorted(nodes, key=lambda x: x.score or 0.0, reverse=True)[: self.top_n]
: self.top_n
]

View File

@ -121,7 +121,7 @@ class RerankerSettings(BaseModel):
) )
hf_model_name: str = Field( hf_model_name: str = Field(
"BAAI/bge-reranker-large", "BAAI/bge-reranker-large",
description="Name of the HuggingFace model to use for reranking" description="Name of the HuggingFace model to use for reranking",
) )
top_n: int = Field( top_n: int = Field(
5, 5,