mirror of
https://github.com/imartinez/privateGPT.git
synced 2025-08-18 15:37:09 +00:00
fix: tests
This commit is contained in:
parent
b652b2ddbc
commit
dc33bb055a
@ -215,7 +215,7 @@ class ChatMLPromptStyle(AbstractPromptStyle):
|
|||||||
|
|
||||||
|
|
||||||
def get_prompt_style(
|
def get_prompt_style(
|
||||||
prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None
|
prompt_style: Literal["default", "llama2", "tag", "mistral", "chatml"] | None,
|
||||||
) -> AbstractPromptStyle:
|
) -> AbstractPromptStyle:
|
||||||
"""Get the prompt style to use from the given string.
|
"""Get the prompt style to use from the given string.
|
||||||
|
|
||||||
|
@ -1,17 +1,17 @@
|
|||||||
from typing import List, Tuple
|
|
||||||
from injector import singleton, inject
|
|
||||||
from llama_index.schema import NodeWithScore, QueryBundle
|
|
||||||
from private_gpt.paths import models_path
|
|
||||||
from llama_index.bridge.pydantic import Field
|
|
||||||
from FlagEmbedding import FlagReranker
|
from FlagEmbedding import FlagReranker
|
||||||
|
from injector import inject, singleton
|
||||||
|
from llama_index.bridge.pydantic import Field
|
||||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||||
|
from llama_index.schema import NodeWithScore, QueryBundle
|
||||||
|
|
||||||
|
from private_gpt.paths import models_path
|
||||||
from private_gpt.settings.settings import Settings
|
from private_gpt.settings.settings import Settings
|
||||||
|
|
||||||
|
|
||||||
@singleton
|
@singleton
|
||||||
class RerankerComponent(BaseNodePostprocessor):
|
class RerankerComponent(BaseNodePostprocessor):
|
||||||
"""
|
"""Reranker component.
|
||||||
Reranker component:
|
|
||||||
- top_n: Top N nodes to return.
|
- top_n: Top N nodes to return.
|
||||||
- cut_off: Cut off score for nodes.
|
- cut_off: Cut off score for nodes.
|
||||||
|
|
||||||
@ -47,14 +47,14 @@ class RerankerComponent(BaseNodePostprocessor):
|
|||||||
|
|
||||||
def _postprocess_nodes(
|
def _postprocess_nodes(
|
||||||
self,
|
self,
|
||||||
nodes: List[NodeWithScore],
|
nodes: list[NodeWithScore],
|
||||||
query_bundle: QueryBundle | None = None,
|
query_bundle: QueryBundle | None = None,
|
||||||
) -> List[NodeWithScore]:
|
) -> list[NodeWithScore]:
|
||||||
if query_bundle is None:
|
if query_bundle is None:
|
||||||
return ValueError("Query bundle must be provided.")
|
return ValueError("Query bundle must be provided.")
|
||||||
|
|
||||||
query_str = query_bundle.query_str
|
query_str = query_bundle.query_str
|
||||||
sentence_pairs: List[Tuple[str, str]] = []
|
sentence_pairs: list[tuple[str, str]] = []
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
content = node.get_content()
|
content = node.get_content()
|
||||||
sentence_pairs.append([query_str, content])
|
sentence_pairs.append([query_str, content])
|
||||||
|
@ -21,7 +21,6 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
def create_app(root_injector: Injector) -> FastAPI:
|
def create_app(root_injector: Injector) -> FastAPI:
|
||||||
|
|
||||||
# Start the API
|
# Start the API
|
||||||
async def bind_injector_to_request(request: Request) -> None:
|
async def bind_injector_to_request(request: Request) -> None:
|
||||||
request.state.injector = root_injector
|
request.state.injector = root_injector
|
||||||
|
@ -60,7 +60,7 @@ else:
|
|||||||
|
|
||||||
# Method to be used as a dependency to check if the request is authenticated.
|
# Method to be used as a dependency to check if the request is authenticated.
|
||||||
def authenticated(
|
def authenticated(
|
||||||
_simple_authentication: Annotated[bool, Depends(_simple_authentication)]
|
_simple_authentication: Annotated[bool, Depends(_simple_authentication)],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if the request is authenticated."""
|
"""Check if the request is authenticated."""
|
||||||
assert settings().server.auth.enabled
|
assert settings().server.auth.enabled
|
||||||
|
@ -145,7 +145,6 @@ class PrivateGptUi:
|
|||||||
)
|
)
|
||||||
match mode:
|
match mode:
|
||||||
case "Query Files":
|
case "Query Files":
|
||||||
|
|
||||||
# Use only the selected file for the query
|
# Use only the selected file for the query
|
||||||
context_filter = None
|
context_filter = None
|
||||||
if self._selected_filename is not None:
|
if self._selected_filename is not None:
|
||||||
|
@ -14,8 +14,13 @@ qdrant:
|
|||||||
llm:
|
llm:
|
||||||
mode: mock
|
mode: mock
|
||||||
|
|
||||||
|
<<<<<<< HEAD
|
||||||
embedding:
|
embedding:
|
||||||
mode: mock
|
mode: mock
|
||||||
|
=======
|
||||||
|
reranker:
|
||||||
|
enabled: false
|
||||||
|
>>>>>>> c096818 (fix: tests)
|
||||||
|
|
||||||
ui:
|
ui:
|
||||||
enabled: false
|
enabled: false
|
Loading…
Reference in New Issue
Block a user