mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-08-05 17:43:51 +00:00
fix: type hionts
This commit is contained in:
parent
dc33bb055a
commit
f60ae69d91
@ -1,4 +1,9 @@
|
|||||||
from FlagEmbedding import FlagReranker
|
from typing import ( # noqa: UP035, we need to keep the consistence with llamaindex
|
||||||
|
List,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
|
|
||||||
|
from FlagEmbedding import FlagReranker # type: ignore
|
||||||
from injector import inject, singleton
|
from injector import inject, singleton
|
||||||
from llama_index.bridge.pydantic import Field
|
from llama_index.bridge.pydantic import Field
|
||||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||||
@ -29,17 +34,13 @@ class RerankerComponent(BaseNodePostprocessor):
|
|||||||
raise ValueError("Reranker component is not enabled.")
|
raise ValueError("Reranker component is not enabled.")
|
||||||
|
|
||||||
path = models_path / "reranker"
|
path = models_path / "reranker"
|
||||||
top_n = settings.reranker.top_n
|
self.top_n = settings.reranker.top_n
|
||||||
cut_off = settings.reranker.cut_off
|
self.cut_off = settings.reranker.cut_off
|
||||||
reranker = FlagReranker(
|
self.reranker = FlagReranker(
|
||||||
model_name_or_path=path,
|
model_name_or_path=path,
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(
|
super().__init__()
|
||||||
top_n=top_n,
|
|
||||||
reranker=reranker,
|
|
||||||
cut_off=cut_off,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def class_name(cls) -> str:
|
def class_name(cls) -> str:
|
||||||
@ -47,24 +48,24 @@ class RerankerComponent(BaseNodePostprocessor):
|
|||||||
|
|
||||||
def _postprocess_nodes(
|
def _postprocess_nodes(
|
||||||
self,
|
self,
|
||||||
nodes: list[NodeWithScore],
|
nodes: List[NodeWithScore], # noqa: UP006
|
||||||
query_bundle: QueryBundle | None = None,
|
query_bundle: QueryBundle | None = None,
|
||||||
) -> list[NodeWithScore]:
|
) -> List[NodeWithScore]: # noqa: UP006
|
||||||
if query_bundle is None:
|
if query_bundle is None:
|
||||||
return ValueError("Query bundle must be provided.")
|
raise ValueError("Query bundle must be provided.")
|
||||||
|
|
||||||
query_str = query_bundle.query_str
|
query_str = query_bundle.query_str
|
||||||
sentence_pairs: list[tuple[str, str]] = []
|
sentence_pairs: List[Tuple[str, str]] = [] # noqa: UP006
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
content = node.get_content()
|
content = node.get_content()
|
||||||
sentence_pairs.append([query_str, content])
|
sentence_pairs.append((query_str, content))
|
||||||
|
|
||||||
scores = self.reranker.compute_score(sentence_pairs)
|
scores = self.reranker.compute_score(sentence_pairs)
|
||||||
for i, node in enumerate(nodes):
|
for i, node in enumerate(nodes):
|
||||||
node.score = scores[i]
|
node.score = scores[i]
|
||||||
|
|
||||||
# cut off nodes with low scores
|
# cut off nodes with low scores
|
||||||
res = [node for node in nodes if node.score > self.cut_off]
|
res = [node for node in nodes if (node.score or 0.0) > self.cut_off]
|
||||||
if len(res) > self.top_n:
|
if len(res) > self.top_n:
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import typing
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from injector import inject, singleton
|
from injector import inject, singleton
|
||||||
@ -26,6 +27,9 @@ from private_gpt.open_ai.extensions.context_filter import ContextFilter
|
|||||||
from private_gpt.server.chunks.chunks_service import Chunk
|
from private_gpt.server.chunks.chunks_service import Chunk
|
||||||
from private_gpt.settings.settings import Settings
|
from private_gpt.settings.settings import Settings
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||||
|
|
||||||
|
|
||||||
class Completion(BaseModel):
|
class Completion(BaseModel):
|
||||||
response: str
|
response: str
|
||||||
@ -117,7 +121,7 @@ class ChatService:
|
|||||||
similarity_top_k=self.settings.rag.similarity_top_k,
|
similarity_top_k=self.settings.rag.similarity_top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
node_postprocessors = [
|
node_postprocessors: list[BaseNodePostprocessor] = [
|
||||||
MetadataReplacementPostProcessor(target_metadata_key="window"),
|
MetadataReplacementPostProcessor(target_metadata_key="window"),
|
||||||
SimilarityPostprocessor(
|
SimilarityPostprocessor(
|
||||||
similarity_cutoff=settings.rag.similarity_value
|
similarity_cutoff=settings.rag.similarity_value
|
||||||
|
Loading…
Reference in New Issue
Block a user