fix: tests

Anhui-tqhuang 2024-02-20 23:15:43 +08:00
parent b652b2ddbc
commit dc33bb055a
No known key found for this signature in database
GPG Key ID: 37B92F5DB83657C7
6 changed files with 18 additions and 15 deletions


@@ -215,7 +215,7 @@ class ChatMLPromptStyle(AbstractPromptStyle):
 def get_prompt_style(
-    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
+    prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None,
 ) -> AbstractPromptStyle:
     """Get the prompt style to use from the given string.


@@ -1,17 +1,17 @@
-from typing import List, Tuple
-from injector import singleton, inject
-from llama_index.schema import NodeWithScore, QueryBundle
-from private_gpt.paths import models_path
-from llama_index.bridge.pydantic import Field
 from FlagEmbedding import FlagReranker
+from injector import inject, singleton
+from llama_index.bridge.pydantic import Field
 from llama_index.postprocessor.types import BaseNodePostprocessor
+from llama_index.schema import NodeWithScore, QueryBundle
+from private_gpt.paths import models_path
 from private_gpt.settings.settings import Settings


 @singleton
 class RerankerComponent(BaseNodePostprocessor):
-    """
-    Reranker component:
+    """Reranker component.
+
     - top_n: Top N nodes to return.
     - cut_off: Cut off score for nodes.
@@ -47,14 +47,14 @@ class RerankerComponent(BaseNodePostprocessor):
     def _postprocess_nodes(
         self,
-        nodes: List[NodeWithScore],
+        nodes: list[NodeWithScore],
         query_bundle: QueryBundle | None = None,
-    ) -> List[NodeWithScore]:
+    ) -> list[NodeWithScore]:
         if query_bundle is None:
             raise ValueError("Query bundle must be provided.")

         query_str = query_bundle.query_str
-        sentence_pairs: List[Tuple[str, str]] = []
+        sentence_pairs: list[tuple[str, str]] = []
         for node in nodes:
             content = node.get_content()
             sentence_pairs.append((query_str, content))
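Below is a minimal sketch, not the committed code, of how the (query, content) pairs built in this hunk are presumably scored and filtered downstream: FlagReranker.compute_score returns one relevance score per pair, the score is attached to each node, and top_n / cut_off (named in the component docstring above) trim the result. The helper name rerank_nodes and the fallback behaviour when the cut-off removes everything are assumptions.

from FlagEmbedding import FlagReranker
from llama_index.schema import NodeWithScore


def rerank_nodes(
    reranker: FlagReranker,
    query_str: str,
    nodes: list[NodeWithScore],
    top_n: int = 2,
    cut_off: float = 0.0,
) -> list[NodeWithScore]:
    # One (query, passage) pair per candidate node, as in the hunk above.
    sentence_pairs: list[tuple[str, str]] = [
        (query_str, node.get_content()) for node in nodes
    ]
    # compute_score returns a list of scores for a list of pairs
    # (and a bare float when a single pair is passed).
    scores = reranker.compute_score(sentence_pairs)
    if isinstance(scores, float):
        scores = [scores]
    for node, score in zip(nodes, scores):
        node.score = score
    ranked = sorted(nodes, key=lambda n: n.score or 0.0, reverse=True)
    kept = [n for n in ranked if (n.score or 0.0) > cut_off]
    # Assumption: fall back to the best top_n nodes if nothing clears cut_off.
    return kept or ranked[:top_n]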


@@ -21,7 +21,6 @@ logger = logging.getLogger(__name__)
 def create_app(root_injector: Injector) -> FastAPI:
-
     # Start the API
     async def bind_injector_to_request(request: Request) -> None:
         request.state.injector = root_injector


@@ -60,7 +60,7 @@ else:
     # Method to be used as a dependency to check if the request is authenticated.
     def authenticated(
-        _simple_authentication: Annotated[bool, Depends(_simple_authentication)]
+        _simple_authentication: Annotated[bool, Depends(_simple_authentication)],
     ) -> bool:
         """Check if the request is authenticated."""
         assert settings().server.auth.enabled


@@ -145,7 +145,6 @@ class PrivateGptUi:
         )

         match mode:
             case "Query Files":
-
                 # Use only the selected file for the query
                 context_filter = None
                 if self._selected_filename is not None:


@@ -14,8 +14,11 @@ qdrant:
 llm:
   mode: mock

 embedding:
   mode: mock

+reranker:
+  enabled: false
+
 ui:
   enabled: false
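
For reference, a hypothetical example of the same block with the component switched on; only the `enabled` key appears in this diff, and `top_n` / `cut_off` are assumed key names mirroring the docstring parameters of RerankerComponent, not keys confirmed by this commit.

reranker:
  enabled: true
  top_n: 2       # assumed key, mirrors the docstring's "Top N nodes to return"
  cut_off: 0.5   # assumed key, mirrors the docstring's "Cut off score for nodes"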