fix: tests

This commit is contained in:
Anhui-tqhuang 2024-02-20 23:15:43 +08:00
parent b652b2ddbc
commit dc33bb055a
No known key found for this signature in database
GPG Key ID: 37B92F5DB83657C7
6 changed files with 18 additions and 15 deletions

View File

@ -215,7 +215,7 @@ class ChatMLPromptStyle(AbstractPromptStyle):
def get_prompt_style(
prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None,
) -> AbstractPromptStyle:
"""Get the prompt style to use from the given string.

View File

@ -1,17 +1,17 @@
from typing import List, Tuple
from injector import singleton, inject
from llama_index.schema import NodeWithScore, QueryBundle
from private_gpt.paths import models_path
from llama_index.bridge.pydantic import Field
from FlagEmbedding import FlagReranker
from injector import inject, singleton
from llama_index.bridge.pydantic import Field
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.schema import NodeWithScore, QueryBundle
from private_gpt.paths import models_path
from private_gpt.settings.settings import Settings
@singleton
class RerankerComponent(BaseNodePostprocessor):
"""
Reranker component:
"""Reranker component.
- top_n: Top N nodes to return.
- cut_off: Cut off score for nodes.
@ -47,14 +47,14 @@ class RerankerComponent(BaseNodePostprocessor):
def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
nodes: list[NodeWithScore],
query_bundle: QueryBundle | None = None,
) -> List[NodeWithScore]:
) -> list[NodeWithScore]:
if query_bundle is None:
return ValueError("Query bundle must be provided.")
query_str = query_bundle.query_str
sentence_pairs: List[Tuple[str, str]] = []
sentence_pairs: list[tuple[str, str]] = []
for node in nodes:
content = node.get_content()
sentence_pairs.append([query_str, content])

View File

@ -21,7 +21,6 @@ logger = logging.getLogger(__name__)
def create_app(root_injector: Injector) -> FastAPI:
# Start the API
async def bind_injector_to_request(request: Request) -> None:
request.state.injector = root_injector

View File

@ -60,7 +60,7 @@ else:
# Method to be used as a dependency to check if the request is authenticated.
def authenticated(
_simple_authentication: Annotated[bool, Depends(_simple_authentication)]
_simple_authentication: Annotated[bool, Depends(_simple_authentication)],
) -> bool:
"""Check if the request is authenticated."""
assert settings().server.auth.enabled

View File

@ -145,7 +145,6 @@ class PrivateGptUi:
)
match mode:
case "Query Files":
# Use only the selected file for the query
context_filter = None
if self._selected_filename is not None:

View File

@ -14,8 +14,13 @@ qdrant:
llm:
mode: mock
<<<<<<< HEAD
embedding:
mode: mock
=======
reranker:
enabled: false
>>>>>>> c096818 (fix: tests)
ui:
enabled: false
enabled: false