fix: Remove global state (#1216)

* Remove all global settings state

* chore: remove autogenerated class

* chore: cleanup

* chore: merge conflicts
Pablo Orgaz 2023-11-12 22:20:36 +01:00 committed by GitHub
parent f394ca61bb
commit 022bd718e3
24 changed files with 286 additions and 190 deletions


@ -8,4 +8,4 @@ from private_gpt.settings.settings import settings
# Set log_config=None to not use the uvicorn logging configuration, and
# use ours instead. For reference, see below:
# https://github.com/tiangolo/fastapi/discussions/7457#discussioncomment-5141108
uvicorn.run(app, host="0.0.0.0", port=settings.server.port, log_config=None)
uvicorn.run(app, host="0.0.0.0", port=settings().server.port, log_config=None)


@ -3,7 +3,7 @@ from llama_index import MockEmbedding
from llama_index.embeddings.base import BaseEmbedding
from private_gpt.paths import models_cache_path
from private_gpt.settings.settings import settings
from private_gpt.settings.settings import Settings
@singleton
@ -11,7 +11,7 @@ class EmbeddingComponent:
embedding_model: BaseEmbedding
@inject
def __init__(self) -> None:
def __init__(self, settings: Settings) -> None:
match settings.llm.mode:
case "local":
from llama_index.embeddings import HuggingFaceEmbedding


@ -4,7 +4,7 @@ from llama_index.llms.base import LLM
from llama_index.llms.llama_utils import completion_to_prompt, messages_to_prompt
from private_gpt.paths import models_path
from private_gpt.settings.settings import settings
from private_gpt.settings.settings import Settings
@singleton
@ -12,7 +12,7 @@ class LLMComponent:
llm: LLM
@inject
def __init__(self) -> None:
def __init__(self, settings: Settings) -> None:
match settings.llm.mode:
case "local":
from llama_index.llms import LlamaCPP
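Both components now receive `Settings` through their constructors instead of importing a module-level singleton. A minimal sketch of how that pattern resolves, assuming the `injector` package already used by the project (`FakeSettings` and `MyComponent` are illustrative stand-ins, not code from this diff):

from injector import Injector, inject, singleton

class FakeSettings:
    # Stand-in for private_gpt.settings.settings.Settings.
    llm_mode: str = "mock"

@singleton
class MyComponent:
    @inject
    def __init__(self, settings: FakeSettings) -> None:
        # The injector supplies the bound FakeSettings instance, so tests can
        # rebind it without mutating any global state.
        self.mode = settings.llm_mode

injector = Injector(auto_bind=True)
injector.binder.bind(FakeSettings, to=FakeSettings())
assert injector.get(MyComponent).mode == "mock"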


@ -1,9 +1,19 @@
from injector import Injector
from private_gpt.settings.settings import Settings, unsafe_typed_settings
def create_application_injector() -> Injector:
injector = Injector(auto_bind=True)
return injector
_injector = Injector(auto_bind=True)
_injector.binder.bind(Settings, to=unsafe_typed_settings)
return _injector
root_injector: Injector = create_application_injector()
"""
Global injector for the application.
Avoid using this reference; it will make your code harder to test.
Instead, use the `request.state.injector` reference, which is bound to every request.
"""
global_injector: Injector = create_application_injector()
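A usage sketch of the two access paths this sets up: scripts and entry points resolve components from `global_injector`, while request handlers go through `request.state.injector`, which `create_app` in launcher.py binds to every request (the service name in the comment is just an example from this diff):

from private_gpt.di import global_injector
from private_gpt.settings.settings import Settings

# Scripts / entry points: resolve from the application-wide injector.
settings = global_injector.get(Settings)

# FastAPI handlers: prefer the per-request injector bound in launcher.py,
# e.g. service = request.state.injector.get(IngestService)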

private_gpt/launcher.py (new file, 128 additions)

@ -0,0 +1,128 @@
"""FastAPI app creation, logger configuration and main API routes."""
import logging
from typing import Any
from fastapi import Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from injector import Injector
from private_gpt.paths import docs_path
from private_gpt.server.chat.chat_router import chat_router
from private_gpt.server.chunks.chunks_router import chunks_router
from private_gpt.server.completions.completions_router import completions_router
from private_gpt.server.embeddings.embeddings_router import embeddings_router
from private_gpt.server.health.health_router import health_router
from private_gpt.server.ingest.ingest_router import ingest_router
from private_gpt.settings.settings import Settings
logger = logging.getLogger(__name__)
def create_app(root_injector: Injector) -> FastAPI:
# Start the API
with open(docs_path / "description.md") as description_file:
description = description_file.read()
tags_metadata = [
{
"name": "Ingestion",
"description": "High-level APIs covering document ingestion -internally "
"managing document parsing, splitting,"
"metadata extraction, embedding generation and storage- and ingested "
"documents CRUD."
"Each ingested document is identified by an ID that can be used to filter the "
"context"
"used in *Contextual Completions* and *Context Chunks* APIs.",
},
{
"name": "Contextual Completions",
"description": "High-level APIs covering contextual Chat and Completions. They "
"follow OpenAI's format, extending it to "
"allow using the context coming from ingested documents to create the "
"response. Internally"
"manage context retrieval, prompt engineering and the response generation.",
},
{
"name": "Context Chunks",
"description": "Low-level API that given a query return relevant chunks of "
"text coming from the ingested"
"documents.",
},
{
"name": "Embeddings",
"description": "Low-level API to obtain the vector representation of a given "
"text, using an Embeddings model."
"Follows OpenAI's embeddings API format.",
},
{
"name": "Health",
"description": "Simple health API to make sure the server is up and running.",
},
]
async def bind_injector_to_request(request: Request) -> None:
request.state.injector = root_injector
app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
def custom_openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title="PrivateGPT",
description=description,
version="0.1.0",
summary="PrivateGPT is a production-ready AI project that allows you to "
"ask questions to your documents using the power of Large Language "
"Models (LLMs), even in scenarios without Internet connection. "
"100% private, no data leaves your execution environment at any point.",
contact={
"url": "https://github.com/imartinez/privateGPT",
},
license_info={
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0.html",
},
routes=app.routes,
tags=tags_metadata,
)
openapi_schema["info"]["x-logo"] = {
"url": "https://lh3.googleusercontent.com/drive-viewer"
"/AK7aPaD_iNlMoTquOBsw4boh4tIYxyEuhz6EtEs8nzq3yNkNAK00xGj"
"E1KUCmPJSk3TYOjcs6tReG6w_cLu1S7L_gPgT9z52iw=s2560"
}
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi # type: ignore[method-assign]
app.include_router(completions_router)
app.include_router(chat_router)
app.include_router(chunks_router)
app.include_router(ingest_router)
app.include_router(embeddings_router)
app.include_router(health_router)
settings = root_injector.get(Settings)
if settings.server.cors.enabled:
logger.debug("Setting up CORS middleware")
app.add_middleware(
CORSMiddleware,
allow_credentials=settings.server.cors.allow_credentials,
allow_origins=settings.server.cors.allow_origins,
allow_origin_regex=settings.server.cors.allow_origin_regex,
allow_methods=settings.server.cors.allow_methods,
allow_headers=settings.server.cors.allow_headers,
)
if settings.ui.enabled:
logger.debug("Importing the UI module")
from private_gpt.ui.ui import PrivateGptUi
ui = root_injector.get(PrivateGptUi)
ui.mount_in_app(app, settings.ui.path)
return app
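Because `create_app` now receives the injector as an argument, an app can be assembled against a purpose-built container, which is what the test fixtures below rely on. A hedged sketch (the `/health` path is assumed from the health router; everything else comes from this diff):

from fastapi.testclient import TestClient
from injector import Injector

from private_gpt.launcher import create_app
from private_gpt.settings.settings import Settings, unsafe_settings

# Build a throwaway injector instead of touching global_injector.
test_injector = Injector(auto_bind=True)
test_injector.binder.bind(Settings, to=Settings(**unsafe_settings))

client = TestClient(create_app(test_injector))
assert client.get("/health").status_code == 200  # path assumed for illustration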


@ -1,124 +1,11 @@
"""FastAPI app creation, logger configuration and main API routes."""
import logging
from typing import Any
import llama_index
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from private_gpt.paths import docs_path
from private_gpt.server.chat.chat_router import chat_router
from private_gpt.server.chunks.chunks_router import chunks_router
from private_gpt.server.completions.completions_router import completions_router
from private_gpt.server.embeddings.embeddings_router import embeddings_router
from private_gpt.server.health.health_router import health_router
from private_gpt.server.ingest.ingest_router import ingest_router
from private_gpt.settings.settings import settings
logger = logging.getLogger(__name__)
from private_gpt.di import global_injector
from private_gpt.launcher import create_app
# Add LlamaIndex simple observability
llama_index.set_global_handler("simple")
# Start the API
with open(docs_path / "description.md") as description_file:
description = description_file.read()
tags_metadata = [
{
"name": "Ingestion",
"description": "High-level APIs covering document ingestion -internally "
"managing document parsing, splitting,"
"metadata extraction, embedding generation and storage- and ingested "
"documents CRUD."
"Each ingested document is identified by an ID that can be used to filter the "
"context"
"used in *Contextual Completions* and *Context Chunks* APIs.",
},
{
"name": "Contextual Completions",
"description": "High-level APIs covering contextual Chat and Completions. They "
"follow OpenAI's format, extending it to "
"allow using the context coming from ingested documents to create the "
"response. Internally"
"manage context retrieval, prompt engineering and the response generation.",
},
{
"name": "Context Chunks",
"description": "Low-level API that given a query return relevant chunks of "
"text coming from the ingested"
"documents.",
},
{
"name": "Embeddings",
"description": "Low-level API to obtain the vector representation of a given "
"text, using an Embeddings model."
"Follows OpenAI's embeddings API format.",
},
{
"name": "Health",
"description": "Simple health API to make sure the server is up and running.",
},
]
app = FastAPI()
def custom_openapi() -> dict[str, Any]:
if app.openapi_schema:
return app.openapi_schema
openapi_schema = get_openapi(
title="PrivateGPT",
description=description,
version="0.1.0",
summary="PrivateGPT is a production-ready AI project that allows you to "
"ask questions to your documents using the power of Large Language "
"Models (LLMs), even in scenarios without Internet connection. "
"100% private, no data leaves your execution environment at any point.",
contact={
"url": "https://github.com/imartinez/privateGPT",
},
license_info={
"name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0.html",
},
routes=app.routes,
tags=tags_metadata,
)
openapi_schema["info"]["x-logo"] = {
"url": "https://lh3.googleusercontent.com/drive-viewer"
"/AK7aPaD_iNlMoTquOBsw4boh4tIYxyEuhz6EtEs8nzq3yNkNAK00xGj"
"E1KUCmPJSk3TYOjcs6tReG6w_cLu1S7L_gPgT9z52iw=s2560"
}
app.openapi_schema = openapi_schema
return app.openapi_schema
app.openapi = custom_openapi # type: ignore[method-assign]
app.include_router(completions_router)
app.include_router(chat_router)
app.include_router(chunks_router)
app.include_router(ingest_router)
app.include_router(embeddings_router)
app.include_router(health_router)
if settings.server.cors.enabled:
logger.debug("Setting up CORS middleware")
app.add_middleware(
CORSMiddleware,
allow_credentials=settings.server.cors.allow_credentials,
allow_origins=settings.server.cors.allow_origins,
allow_origin_regex=settings.server.cors.allow_origin_regex,
allow_methods=settings.server.cors.allow_methods,
allow_headers=settings.server.cors.allow_headers,
)
if settings.ui.enabled:
logger.debug("Importing the UI module")
from private_gpt.ui.ui import PrivateGptUi
PrivateGptUi().mount_in_app(app)
app = create_app(global_injector)


@ -13,4 +13,6 @@ def _absolute_or_from_project_root(path: str) -> Path:
models_path: Path = PROJECT_ROOT_PATH / "models"
models_cache_path: Path = models_path / "cache"
docs_path: Path = PROJECT_ROOT_PATH / "docs"
local_data_path: Path = _absolute_or_from_project_root(settings.data.local_data_folder)
local_data_path: Path = _absolute_or_from_project_root(
settings().data.local_data_folder
)


@ -1,9 +1,8 @@
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Request
from llama_index.llms import ChatMessage, MessageRole
from pydantic import BaseModel
from starlette.responses import StreamingResponse
from private_gpt.di import root_injector
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.open_ai.openai_models import (
OpenAICompletion,
@ -52,7 +51,9 @@ class ChatBody(BaseModel):
responses={200: {"model": OpenAICompletion}},
tags=["Contextual Completions"],
)
def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
def chat_completion(
request: Request, body: ChatBody
) -> OpenAICompletion | StreamingResponse:
"""Given a list of messages comprising a conversation, return a response.
If `use_context` is set to `true`, the model will use context coming
@ -72,7 +73,7 @@ def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
"finish_reason":null}]}
```
"""
service = root_injector.get(ChatService)
service = request.state.injector.get(ChatService)
all_messages = [
ChatMessage(content=m.content, role=MessageRole(m.role)) for m in body.messages
]


@ -1,9 +1,8 @@
from typing import Literal
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel, Field
from private_gpt.di import root_injector
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
from private_gpt.server.utils.auth import authenticated
@ -25,7 +24,7 @@ class ChunksResponse(BaseModel):
@chunks_router.post("/chunks", tags=["Context Chunks"])
def chunks_retrieval(body: ChunksBody) -> ChunksResponse:
def chunks_retrieval(request: Request, body: ChunksBody) -> ChunksResponse:
"""Given a `text`, returns the most relevant chunks from the ingested documents.
The returned information can be used to generate prompts that can be
@ -45,7 +44,7 @@ def chunks_retrieval(body: ChunksBody) -> ChunksResponse:
`/ingest/list` endpoint. If you want all ingested documents to be used,
remove `context_filter` altogether.
"""
service = root_injector.get(ChunksService)
service = request.state.injector.get(ChunksService)
results = service.retrieve_relevant(
body.text, body.context_filter, body.limit, body.prev_next_chunks
)


@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel
from starlette.responses import StreamingResponse
@ -41,7 +41,9 @@ class CompletionsBody(BaseModel):
responses={200: {"model": OpenAICompletion}},
tags=["Contextual Completions"],
)
def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResponse:
def prompt_completion(
request: Request, body: CompletionsBody
) -> OpenAICompletion | StreamingResponse:
"""We recommend most users use our Chat completions API.
Given a prompt, the model will return one predicted completion. If `use_context`
@ -70,4 +72,4 @@ def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResp
include_sources=body.include_sources,
context_filter=body.context_filter,
)
return chat_completion(chat_body)
return chat_completion(request, chat_body)


@ -1,9 +1,8 @@
from typing import Literal
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel
from private_gpt.di import root_injector
from private_gpt.server.embeddings.embeddings_service import (
Embedding,
EmbeddingsService,
@ -24,13 +23,13 @@ class EmbeddingsResponse(BaseModel):
@embeddings_router.post("/embeddings", tags=["Embeddings"])
def embeddings_generation(body: EmbeddingsBody) -> EmbeddingsResponse:
def embeddings_generation(request: Request, body: EmbeddingsBody) -> EmbeddingsResponse:
"""Get a vector representation of a given input.
That vector representation can be easily consumed
by machine learning models and algorithms.
"""
service = root_injector.get(EmbeddingsService)
service = request.state.injector.get(EmbeddingsService)
input_texts = body.input if isinstance(body.input, list) else [body.input]
embeddings = service.texts_embeddings(input_texts)
return EmbeddingsResponse(object="list", model="private-gpt", data=embeddings)


@ -1,9 +1,8 @@
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from pydantic import BaseModel
from private_gpt.di import root_injector
from private_gpt.server.ingest.ingest_service import IngestedDoc, IngestService
from private_gpt.server.utils.auth import authenticated
@ -17,7 +16,7 @@ class IngestResponse(BaseModel):
@ingest_router.post("/ingest", tags=["Ingestion"])
def ingest(file: UploadFile) -> IngestResponse:
def ingest(request: Request, file: UploadFile) -> IngestResponse:
"""Ingests and processes a file, storing its chunks to be used as context.
The context obtained from files is later used in
@ -33,7 +32,7 @@ def ingest(file: UploadFile) -> IngestResponse:
can be used to filter the context used to create responses in
`/chat/completions`, `/completions`, and `/chunks` APIs.
"""
service = root_injector.get(IngestService)
service = request.state.injector.get(IngestService)
if file.filename is None:
raise HTTPException(400, "No file name provided")
ingested_documents = service.ingest(file.filename, file.file.read())
@ -41,23 +40,23 @@ def ingest(file: UploadFile) -> IngestResponse:
@ingest_router.get("/ingest/list", tags=["Ingestion"])
def list_ingested() -> IngestResponse:
def list_ingested(request: Request) -> IngestResponse:
"""Lists already ingested Documents including their Document ID and metadata.
Those IDs can be used to filter the context used to create responses
in `/chat/completions`, `/completions`, and `/chunks` APIs.
"""
service = root_injector.get(IngestService)
service = request.state.injector.get(IngestService)
ingested_documents = service.list_ingested()
return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
@ingest_router.delete("/ingest/{doc_id}", tags=["Ingestion"])
def delete_ingested(doc_id: str) -> None:
def delete_ingested(request: Request, doc_id: str) -> None:
"""Delete the specified ingested Document.
The `doc_id` can be obtained from the `GET /ingest/list` endpoint.
The document will be effectively deleted from your storage context.
"""
service = root_injector.get(IngestService)
service = request.state.injector.get(IngestService)
service.delete(doc_id)


@ -38,13 +38,13 @@ logger = logging.getLogger(__name__)
def _simple_authentication(authorization: Annotated[str, Header()] = "") -> bool:
"""Check if the request is authenticated."""
if not secrets.compare_digest(authorization, settings.server.auth.secret):
if not secrets.compare_digest(authorization, settings().server.auth.secret):
# If the "Authorization" header is not the expected one, raise an exception.
raise NOT_AUTHENTICATED
return True
if not settings.server.auth.enabled:
if not settings().server.auth.enabled:
logger.debug(
"Defining a dummy authentication mechanism for fastapi, always authenticating requests"
)
@ -62,7 +62,7 @@ else:
_simple_authentication: Annotated[bool, Depends(_simple_authentication)]
) -> bool:
"""Check if the request is authenticated."""
assert settings.server.auth.enabled
assert settings().server.auth.enabled
if not _simple_authentication:
raise NOT_AUTHENTICATED
return True


@ -2,7 +2,7 @@ from typing import Literal
from pydantic import BaseModel, Field
from private_gpt.settings.settings_loader import load_active_profiles
from private_gpt.settings.settings_loader import load_active_settings
class CorsSettings(BaseModel):
@ -114,4 +114,29 @@ class Settings(BaseModel):
openai: OpenAISettings
settings = Settings(**load_active_profiles())
"""
This is visible just for DI or testing purposes.
Use dependency injection or `settings()` method instead.
"""
unsafe_settings = load_active_settings()
"""
This is visible just for DI or testing purposes.
Use dependency injection or `settings()` method instead.
"""
unsafe_typed_settings = Settings(**unsafe_settings)
def settings() -> Settings:
"""Get the current loaded settings from the DI container.
This method exists to keep compatibility with the existing code
that requires global access to the settings.
For regular components use dependency injection instead.
"""
from private_gpt.di import global_injector
return global_injector.get(Settings)
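So code that cannot take constructor injection keeps working through the lazy `settings()` accessor, while injectable components ask for `Settings` directly. A brief illustration (the component name is made up):

from injector import inject, singleton

from private_gpt.settings.settings import Settings, settings

# Legacy/global access, e.g. in scripts: resolved lazily from the DI container.
port = settings().server.port

# Preferred for components: let the injector provide the instance.
@singleton
class SomeComponent:
    @inject
    def __init__(self, settings: Settings) -> None:
        self._port = settings.server.port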


@ -2,6 +2,7 @@ import functools
import logging
import os
import sys
from collections.abc import Iterable
from pathlib import Path
from typing import Any
@ -28,7 +29,11 @@ active_profiles: list[str] = unique_list(
)
def load_profile(profile: str) -> dict[str, Any]:
def merge_settings(settings: Iterable[dict[str, Any]]) -> dict[str, Any]:
return functools.reduce(deep_update, settings, {})
def load_settings_from_profile(profile: str) -> dict[str, Any]:
if profile == "default":
profile_file_name = "settings.yaml"
else:
@ -42,9 +47,11 @@ def load_profile(profile: str) -> dict[str, Any]:
return config
def load_active_profiles() -> dict[str, Any]:
def load_active_settings() -> dict[str, Any]:
"""Load active profiles and merge them."""
logger.info("Starting application with profiles=%s", active_profiles)
loaded_profiles = [load_profile(profile) for profile in active_profiles]
merged: dict[str, Any] = functools.reduce(deep_update, loaded_profiles, {})
loaded_profiles = [
load_settings_from_profile(profile) for profile in active_profiles
]
merged: dict[str, Any] = merge_settings(loaded_profiles)
return merged
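The net effect of `merge_settings` is a left-to-right deep merge: later profiles override or extend earlier ones key by key instead of replacing whole sections. A small illustration with made-up values (assuming `deep_update` merges nested dicts recursively):

from private_gpt.settings.settings_loader import merge_settings

base = {"server": {"port": 8001, "cors": {"enabled": False}}}
override = {"server": {"cors": {"enabled": True}}, "ui": {"enabled": True}}

merged = merge_settings([base, override])
# Expected result under that assumption:
# {"server": {"port": 8001, "cors": {"enabled": True}}, "ui": {"enabled": True}}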


@ -8,10 +8,11 @@ from typing import Any, TextIO
import gradio as gr # type: ignore
from fastapi import FastAPI
from gradio.themes.utils.colors import slate # type: ignore
from injector import inject, singleton
from llama_index.llms import ChatMessage, ChatResponse, MessageRole
from pydantic import BaseModel
from private_gpt.di import root_injector
from private_gpt.di import global_injector
from private_gpt.server.chat.chat_service import ChatService, CompletionGen
from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
from private_gpt.server.ingest.ingest_service import IngestService
@ -48,11 +49,18 @@ class Source(BaseModel):
return curated_sources
@singleton
class PrivateGptUi:
def __init__(self) -> None:
self._ingest_service = root_injector.get(IngestService)
self._chat_service = root_injector.get(ChatService)
self._chunks_service = root_injector.get(ChunksService)
@inject
def __init__(
self,
ingest_service: IngestService,
chat_service: ChatService,
chunks_service: ChunksService,
) -> None:
self._ingest_service = ingest_service
self._chat_service = chat_service
self._chunks_service = chunks_service
# Cache the UI blocks
self._ui_block = None
@ -198,7 +206,7 @@ class PrivateGptUi:
_ = gr.ChatInterface(
self._chat,
chatbot=gr.Chatbot(
label=f"LLM: {settings.llm.mode}",
label=f"LLM: {settings().llm.mode}",
show_copy_button=True,
render=False,
avatar_images=(
@ -217,16 +225,15 @@ class PrivateGptUi:
self._ui_block = self._build_ui_blocks()
return self._ui_block
def mount_in_app(self, app: FastAPI) -> None:
def mount_in_app(self, app: FastAPI, path: str) -> None:
blocks = self.get_ui_blocks()
blocks.queue()
base_path = settings.ui.path
logger.info("Mounting the gradio UI, at path=%s", base_path)
gr.mount_gradio_app(app, blocks, path=base_path)
logger.info("Mounting the gradio UI, at path=%s", path)
gr.mount_gradio_app(app, blocks, path=path)
if __name__ == "__main__":
ui = PrivateGptUi()
ui = global_injector.get(PrivateGptUi)
_blocks = ui.get_ui_blocks()
_blocks.queue()
_blocks.launch(debug=False, show_api=False)


@ -2,13 +2,13 @@ import argparse
import logging
from pathlib import Path
from private_gpt.di import root_injector
from private_gpt.di import global_injector
from private_gpt.server.ingest.ingest_service import IngestService
from private_gpt.server.ingest.ingest_watcher import IngestWatcher
logger = logging.getLogger(__name__)
ingest_service = root_injector.get(IngestService)
ingest_service = global_injector.get(IngestService)
parser = argparse.ArgumentParser(prog="ingest_folder.py")
parser.add_argument("folder", help="Folder to ingest")


@ -9,9 +9,9 @@ from private_gpt.settings.settings import settings
os.makedirs(models_path, exist_ok=True)
embedding_path = models_path / "embedding"
print(f"Downloading embedding {settings.local.embedding_hf_model_name}")
print(f"Downloading embedding {settings().local.embedding_hf_model_name}")
snapshot_download(
repo_id=settings.local.embedding_hf_model_name,
repo_id=settings().local.embedding_hf_model_name,
cache_dir=models_cache_path,
local_dir=embedding_path,
)
@ -20,8 +20,8 @@ print("Downloading models for local execution...")
# Download LLM and create a symlink to the model file
hf_hub_download(
repo_id=settings.local.llm_hf_repo_id,
filename=settings.local.llm_hf_model_file,
repo_id=settings().local.llm_hf_repo_id,
filename=settings().local.llm_hf_model_file,
cache_dir=models_cache_path,
local_dir=models_path,
)


@ -5,8 +5,12 @@ server:
# Dummy secrets used for tests
secret: "foo bar; dummy secret"
data:
local_data_folder: local_data/tests
llm:
mode: mock
mode: mock
ui:
enabled: false


@ -1,15 +1,14 @@
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from private_gpt.main import app
from private_gpt.launcher import create_app
from tests.fixtures.mock_injector import MockInjector
@pytest.fixture()
def current_test_app() -> FastAPI:
return app
def test_client(request: pytest.FixtureRequest, injector: MockInjector) -> TestClient:
if request is not None and hasattr(request, "param"):
injector.bind_settings(request.param or {})
@pytest.fixture()
def test_client() -> TestClient:
return TestClient(app)
app_under_test = create_app(injector.test_injector)
return TestClient(app_under_test)
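With this fixture, a test opts into settings overrides by parametrizing `test_client` indirectly; the dict travels through `request.param` into `bind_settings` before the app is built. A hedged example (the override keys and the `/health` path are illustrative):

import pytest
from fastapi.testclient import TestClient

@pytest.mark.parametrize(
    "test_client", [{"server": {"env_name": "overriden"}}], indirect=True
)
def test_app_with_overridden_settings(test_client: TestClient) -> None:
    # The app under test was created with the merged, overridden Settings.
    assert test_client.get("/health").status_code == 200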


@ -1,10 +1,13 @@
from collections.abc import Callable
from typing import Any
from unittest.mock import MagicMock
import pytest
from injector import Provider, ScopeDecorator, singleton
from private_gpt.di import create_application_injector
from private_gpt.settings.settings import Settings, unsafe_settings
from private_gpt.settings.settings_loader import merge_settings
from private_gpt.utils.typing import T
@ -24,6 +27,12 @@ class MockInjector:
self.test_injector.binder.bind(interface, to=mock, scope=scope)
return mock # type: ignore
def bind_settings(self, settings: dict[str, Any]) -> Settings:
merged = merge_settings([unsafe_settings, settings])
new_settings = Settings(**merged)
self.test_injector.binder.bind(Settings, new_settings)
return new_settings
def get(self, interface: type[T]) -> T:
return self.test_injector.get(interface)


@ -8,7 +8,7 @@ NOTE: We are not testing the switch based on the config in
from typing import Annotated
import pytest
from fastapi import Depends, FastAPI
from fastapi import Depends
from fastapi.testclient import TestClient
from private_gpt.server.utils.auth import (
@ -29,15 +29,16 @@ def _copy_simple_authenticated(
@pytest.fixture(autouse=True)
def _patch_authenticated_dependency(current_test_app: FastAPI):
def _patch_authenticated_dependency(test_client: TestClient):
# Patch the server to use simple authentication
current_test_app.dependency_overrides[authenticated] = _copy_simple_authenticated
test_client.app.dependency_overrides[authenticated] = _copy_simple_authenticated
# Call the actual test
yield
# Remove the patch for other tests
current_test_app.dependency_overrides = {}
test_client.app.dependency_overrides = {}
def test_default_auth_working_when_enabled_401(test_client: TestClient) -> None:
@ -50,6 +51,6 @@ def test_default_auth_working_when_enabled_200(test_client: TestClient) -> None:
assert response_fail.status_code == 401
response_success = test_client.get(
"/v1/ingest/list", headers={"Authorization": settings.server.auth.secret}
"/v1/ingest/list", headers={"Authorization": settings().server.auth.secret}
)
assert response_success.status_code == 200


@ -1,5 +1,12 @@
from private_gpt.settings.settings import settings
from private_gpt.settings.settings import Settings, settings
from tests.fixtures.mock_injector import MockInjector
def test_settings_are_loaded_and_merged() -> None:
assert settings.server.env_name == "test"
assert settings().server.env_name == "test"
def test_settings_can_be_overriden(injector: MockInjector) -> None:
injector.bind_settings({"server": {"env_name": "overriden"}})
mocked_settings = injector.get(Settings)
assert mocked_settings.server.env_name == "overriden"

tests/ui/test_ui.py (new file, 10 additions)

@ -0,0 +1,10 @@
import pytest
from fastapi.testclient import TestClient
@pytest.mark.parametrize(
"test_client", [{"ui": {"enabled": True, "path": "/ui"}}], indirect=True
)
def test_ui_starts_in_the_given_endpoint(test_client: TestClient) -> None:
response = test_client.get("/ui")
assert response.status_code == 200