mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-08-01 23:47:54 +00:00
fix: type hionts
This commit is contained in:
parent
dc33bb055a
commit
f60ae69d91
@ -1,4 +1,9 @@
|
||||
from FlagEmbedding import FlagReranker
|
||||
from typing import ( # noqa: UP035, we need to keep the consistence with llamaindex
|
||||
List,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from FlagEmbedding import FlagReranker # type: ignore
|
||||
from injector import inject, singleton
|
||||
from llama_index.bridge.pydantic import Field
|
||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||
@ -29,17 +34,13 @@ class RerankerComponent(BaseNodePostprocessor):
|
||||
raise ValueError("Reranker component is not enabled.")
|
||||
|
||||
path = models_path / "reranker"
|
||||
top_n = settings.reranker.top_n
|
||||
cut_off = settings.reranker.cut_off
|
||||
reranker = FlagReranker(
|
||||
self.top_n = settings.reranker.top_n
|
||||
self.cut_off = settings.reranker.cut_off
|
||||
self.reranker = FlagReranker(
|
||||
model_name_or_path=path,
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
top_n=top_n,
|
||||
reranker=reranker,
|
||||
cut_off=cut_off,
|
||||
)
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
@ -47,24 +48,24 @@ class RerankerComponent(BaseNodePostprocessor):
|
||||
|
||||
def _postprocess_nodes(
|
||||
self,
|
||||
nodes: list[NodeWithScore],
|
||||
nodes: List[NodeWithScore], # noqa: UP006
|
||||
query_bundle: QueryBundle | None = None,
|
||||
) -> list[NodeWithScore]:
|
||||
) -> List[NodeWithScore]: # noqa: UP006
|
||||
if query_bundle is None:
|
||||
return ValueError("Query bundle must be provided.")
|
||||
raise ValueError("Query bundle must be provided.")
|
||||
|
||||
query_str = query_bundle.query_str
|
||||
sentence_pairs: list[tuple[str, str]] = []
|
||||
sentence_pairs: List[Tuple[str, str]] = [] # noqa: UP006
|
||||
for node in nodes:
|
||||
content = node.get_content()
|
||||
sentence_pairs.append([query_str, content])
|
||||
sentence_pairs.append((query_str, content))
|
||||
|
||||
scores = self.reranker.compute_score(sentence_pairs)
|
||||
for i, node in enumerate(nodes):
|
||||
node.score = scores[i]
|
||||
|
||||
# cut off nodes with low scores
|
||||
res = [node for node in nodes if node.score > self.cut_off]
|
||||
res = [node for node in nodes if (node.score or 0.0) > self.cut_off]
|
||||
if len(res) > self.top_n:
|
||||
return res
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import typing
|
||||
from dataclasses import dataclass
|
||||
|
||||
from injector import inject, singleton
|
||||
@ -26,6 +27,9 @@ from private_gpt.open_ai.extensions.context_filter import ContextFilter
|
||||
from private_gpt.server.chunks.chunks_service import Chunk
|
||||
from private_gpt.settings.settings import Settings
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||
|
||||
|
||||
class Completion(BaseModel):
|
||||
response: str
|
||||
@ -117,7 +121,7 @@ class ChatService:
|
||||
similarity_top_k=self.settings.rag.similarity_top_k,
|
||||
)
|
||||
|
||||
node_postprocessors = [
|
||||
node_postprocessors: list[BaseNodePostprocessor] = [
|
||||
MetadataReplacementPostProcessor(target_metadata_key="window"),
|
||||
SimilarityPostprocessor(
|
||||
similarity_cutoff=settings.rag.similarity_value
|
||||
|
Loading…
Reference in New Issue
Block a user