fix: type hionts

This commit is contained in:
Anhui-tqhuang 2024-02-20 23:56:35 +08:00
parent dc33bb055a
commit f60ae69d91
No known key found for this signature in database
GPG Key ID: 37B92F5DB83657C7
2 changed files with 21 additions and 16 deletions

View File

@ -1,4 +1,9 @@
from FlagEmbedding import FlagReranker
from typing import ( # noqa: UP035, we need to keep the consistence with llamaindex
List,
Tuple,
)
from FlagEmbedding import FlagReranker # type: ignore
from injector import inject, singleton
from llama_index.bridge.pydantic import Field
from llama_index.postprocessor.types import BaseNodePostprocessor
@ -29,17 +34,13 @@ class RerankerComponent(BaseNodePostprocessor):
raise ValueError("Reranker component is not enabled.")
path = models_path / "reranker"
top_n = settings.reranker.top_n
cut_off = settings.reranker.cut_off
reranker = FlagReranker(
self.top_n = settings.reranker.top_n
self.cut_off = settings.reranker.cut_off
self.reranker = FlagReranker(
model_name_or_path=path,
)
super().__init__(
top_n=top_n,
reranker=reranker,
cut_off=cut_off,
)
super().__init__()
@classmethod
def class_name(cls) -> str:
@ -47,24 +48,24 @@ class RerankerComponent(BaseNodePostprocessor):
def _postprocess_nodes(
self,
nodes: list[NodeWithScore],
nodes: List[NodeWithScore], # noqa: UP006
query_bundle: QueryBundle | None = None,
) -> list[NodeWithScore]:
) -> List[NodeWithScore]: # noqa: UP006
if query_bundle is None:
return ValueError("Query bundle must be provided.")
raise ValueError("Query bundle must be provided.")
query_str = query_bundle.query_str
sentence_pairs: list[tuple[str, str]] = []
sentence_pairs: List[Tuple[str, str]] = [] # noqa: UP006
for node in nodes:
content = node.get_content()
sentence_pairs.append([query_str, content])
sentence_pairs.append((query_str, content))
scores = self.reranker.compute_score(sentence_pairs)
for i, node in enumerate(nodes):
node.score = scores[i]
# cut off nodes with low scores
res = [node for node in nodes if node.score > self.cut_off]
res = [node for node in nodes if (node.score or 0.0) > self.cut_off]
if len(res) > self.top_n:
return res

View File

@ -1,3 +1,4 @@
import typing
from dataclasses import dataclass
from injector import inject, singleton
@ -26,6 +27,9 @@ from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.chunks.chunks_service import Chunk
from private_gpt.settings.settings import Settings
if typing.TYPE_CHECKING:
from llama_index.postprocessor.types import BaseNodePostprocessor
class Completion(BaseModel):
response: str
@ -117,7 +121,7 @@ class ChatService:
similarity_top_k=self.settings.rag.similarity_top_k,
)
node_postprocessors = [
node_postprocessors: list[BaseNodePostprocessor] = [
MetadataReplacementPostProcessor(target_metadata_key="window"),
SimilarityPostprocessor(
similarity_cutoff=settings.rag.similarity_value