mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-08-17 15:06:56 +00:00
fix: tests
This commit is contained in:
parent
b652b2ddbc
commit
dc33bb055a
@ -215,7 +215,7 @@ class ChatMLPromptStyle(AbstractPromptStyle):
|
||||
|
||||
|
||||
def get_prompt_style(
|
||||
prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
|
||||
prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None,
|
||||
) -> AbstractPromptStyle:
|
||||
"""Get the prompt style to use from the given string.
|
||||
|
||||
|
@ -1,17 +1,17 @@
|
||||
from typing import List, Tuple
|
||||
from injector import singleton, inject
|
||||
from llama_index.schema import NodeWithScore, QueryBundle
|
||||
from private_gpt.paths import models_path
|
||||
from llama_index.bridge.pydantic import Field
|
||||
from FlagEmbedding import FlagReranker
|
||||
from injector import inject, singleton
|
||||
from llama_index.bridge.pydantic import Field
|
||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.schema import NodeWithScore, QueryBundle
|
||||
|
||||
from private_gpt.paths import models_path
|
||||
from private_gpt.settings.settings import Settings
|
||||
|
||||
|
||||
@singleton
|
||||
class RerankerComponent(BaseNodePostprocessor):
|
||||
"""
|
||||
Reranker component:
|
||||
"""Reranker component.
|
||||
|
||||
- top_n: Top N nodes to return.
|
||||
- cut_off: Cut off score for nodes.
|
||||
|
||||
@ -47,14 +47,14 @@ class RerankerComponent(BaseNodePostprocessor):
|
||||
|
||||
def _postprocess_nodes(
|
||||
self,
|
||||
nodes: List[NodeWithScore],
|
||||
nodes: list[NodeWithScore],
|
||||
query_bundle: QueryBundle | None = None,
|
||||
) -> List[NodeWithScore]:
|
||||
) -> list[NodeWithScore]:
|
||||
if query_bundle is None:
|
||||
return ValueError("Query bundle must be provided.")
|
||||
|
||||
query_str = query_bundle.query_str
|
||||
sentence_pairs: List[Tuple[str, str]] = []
|
||||
sentence_pairs: list[tuple[str, str]] = []
|
||||
for node in nodes:
|
||||
content = node.get_content()
|
||||
sentence_pairs.append([query_str, content])
|
||||
|
@ -21,7 +21,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_app(root_injector: Injector) -> FastAPI:
|
||||
|
||||
# Start the API
|
||||
async def bind_injector_to_request(request: Request) -> None:
|
||||
request.state.injector = root_injector
|
||||
|
@ -60,7 +60,7 @@ else:
|
||||
|
||||
# Method to be used as a dependency to check if the request is authenticated.
|
||||
def authenticated(
|
||||
_simple_authentication: Annotated[bool, Depends(_simple_authentication)]
|
||||
_simple_authentication: Annotated[bool, Depends(_simple_authentication)],
|
||||
) -> bool:
|
||||
"""Check if the request is authenticated."""
|
||||
assert settings().server.auth.enabled
|
||||
|
@ -145,7 +145,6 @@ class PrivateGptUi:
|
||||
)
|
||||
match mode:
|
||||
case "Query Files":
|
||||
|
||||
# Use only the selected file for the query
|
||||
context_filter = None
|
||||
if self._selected_filename is not None:
|
||||
|
@ -14,8 +14,13 @@ qdrant:
|
||||
llm:
|
||||
mode: mock
|
||||
|
||||
<<<<<<< HEAD
|
||||
embedding:
|
||||
mode: mock
|
||||
=======
|
||||
reranker:
|
||||
enabled: false
|
||||
>>>>>>> c096818 (fix: tests)
|
||||
|
||||
ui:
|
||||
enabled: false
|
||||
enabled: false
|
||||
|
Loading…
Reference in New Issue
Block a user