From 022bd718e3dfc197027b1e24fb97e5525b186db4 Mon Sep 17 00:00:00 2001
From: Pablo Orgaz <pabloogc@gmail.com>
Date: Sun, 12 Nov 2023 22:20:36 +0100
Subject: [PATCH] fix: Remove global state (#1216)

* Remove all global settings state

* chore: remove autogenerated class

* chore: cleanup

* chore: merge conflicts
---
 private_gpt/__main__.py                       |   2 +-
 .../embedding/embedding_component.py          |   4 +-
 private_gpt/components/llm/llm_component.py   |   4 +-
 private_gpt/di.py                             |  16 ++-
 private_gpt/launcher.py                       | 128 ++++++++++++++++++
 private_gpt/main.py                           | 119 +---------------
 private_gpt/paths.py                          |   4 +-
 private_gpt/server/chat/chat_router.py        |   9 +-
 private_gpt/server/chunks/chunks_router.py    |   7 +-
 .../server/completions/completions_router.py  |   8 +-
 .../server/embeddings/embeddings_router.py    |   7 +-
 private_gpt/server/ingest/ingest_router.py    |  15 +-
 private_gpt/server/utils/auth.py              |   6 +-
 private_gpt/settings/settings.py              |  29 +++-
 private_gpt/settings/settings_loader.py       |  15 +-
 private_gpt/ui/ui.py                          |  29 ++--
 scripts/ingest_folder.py                      |   4 +-
 scripts/setup                                 |   8 +-
 settings-test.yaml                            |   6 +-
 tests/fixtures/fast_api_test_client.py        |  15 +-
 tests/fixtures/mock_injector.py               |   9 ++
 tests/server/utils/test_simple_auth.py        |  11 +-
 tests/settings/test_settings.py               |  11 +-
 tests/ui/test_ui.py                           |  10 ++
 24 files changed, 286 insertions(+), 190 deletions(-)
 create mode 100644 private_gpt/launcher.py
 create mode 100644 tests/ui/test_ui.py

diff --git a/private_gpt/__main__.py b/private_gpt/__main__.py
index 6bf2f156..18b42fd7 100644
--- a/private_gpt/__main__.py
+++ b/private_gpt/__main__.py
@@ -8,4 +8,4 @@ from private_gpt.settings.settings import settings
 # Set log_config=None to do not use the uvicorn logging configuration, and
 # use ours instead. For reference, see below:
 # https://github.com/tiangolo/fastapi/discussions/7457#discussioncomment-5141108
-uvicorn.run(app, host="0.0.0.0", port=settings.server.port, log_config=None)
+uvicorn.run(app, host="0.0.0.0", port=settings().server.port, log_config=None)
diff --git a/private_gpt/components/embedding/embedding_component.py b/private_gpt/components/embedding/embedding_component.py
index f71be0a6..53fc984e 100644
--- a/private_gpt/components/embedding/embedding_component.py
+++ b/private_gpt/components/embedding/embedding_component.py
@@ -3,7 +3,7 @@ from llama_index import MockEmbedding
 from llama_index.embeddings.base import BaseEmbedding
 
 from private_gpt.paths import models_cache_path
-from private_gpt.settings.settings import settings
+from private_gpt.settings.settings import Settings
 
 
 @singleton
@@ -11,7 +11,7 @@ class EmbeddingComponent:
     embedding_model: BaseEmbedding
 
     @inject
-    def __init__(self) -> None:
+    def __init__(self, settings: Settings) -> None:
         match settings.llm.mode:
             case "local":
                 from llama_index.embeddings import HuggingFaceEmbedding
diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py
index cad6ed67..4f46f151 100644
--- a/private_gpt/components/llm/llm_component.py
+++ b/private_gpt/components/llm/llm_component.py
@@ -4,7 +4,7 @@ from llama_index.llms.base import LLM
 from llama_index.llms.llama_utils import completion_to_prompt, messages_to_prompt
 
 from private_gpt.paths import models_path
-from private_gpt.settings.settings import settings
+from private_gpt.settings.settings import Settings
 
 
 @singleton
@@ -12,7 +12,7 @@ class LLMComponent:
     llm: LLM
 
     @inject
-    def __init__(self) -> None:
+    def __init__(self, settings: Settings) -> None:
         match settings.llm.mode:
             case "local":
                 from llama_index.llms import LlamaCPP
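For context, the pattern the two components above adopt: dependencies are declared in the constructor and resolved by the `injector` container instead of being read from a module-level global. A minimal, self-contained sketch of that pattern — `FakeSettings` and `FakeComponent` are illustrative stand-ins, not project classes:

```python
# Sketch of the constructor-injection pattern adopted by the components
# above: the container resolves Settings and passes it in; nothing global
# is read. FakeSettings/FakeComponent are hypothetical stand-ins.
from dataclasses import dataclass

from injector import Injector, inject, singleton


@dataclass
class FakeSettings:  # stand-in for private_gpt.settings.settings.Settings
    llm_mode: str = "mock"


@singleton
class FakeComponent:
    @inject
    def __init__(self, settings: FakeSettings) -> None:
        self.mode = settings.llm_mode  # injected, not imported


injector = Injector(auto_bind=True)
injector.binder.bind(FakeSettings, to=FakeSettings(llm_mode="local"))
assert injector.get(FakeComponent).mode == "local"
```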
diff --git a/private_gpt/di.py b/private_gpt/di.py
index 115c8892..05021b01 100644
--- a/private_gpt/di.py
+++ b/private_gpt/di.py
@@ -1,9 +1,19 @@
 from injector import Injector
 
+from private_gpt.settings.settings import Settings, unsafe_typed_settings
+
 
 def create_application_injector() -> Injector:
-    injector = Injector(auto_bind=True)
-    return injector
+    _injector = Injector(auto_bind=True)
+    _injector.binder.bind(Settings, to=unsafe_typed_settings)
+    return _injector
 
 
-root_injector: Injector = create_application_injector()
+"""
+Global injector for the application.
+
+Avoid using this reference, it will make your code harder to test.
+
+Instead, use the `request.state.injector` reference, which is bound to every request
+"""
+global_injector: Injector = create_application_injector()
diff --git a/private_gpt/launcher.py b/private_gpt/launcher.py
new file mode 100644
index 00000000..e65f0edc
--- /dev/null
+++ b/private_gpt/launcher.py
@@ -0,0 +1,128 @@
+"""FastAPI app creation, logger configuration and main API routes."""
+import logging
+from typing import Any
+
+from fastapi import Depends, FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.openapi.utils import get_openapi
+from injector import Injector
+
+from private_gpt.paths import docs_path
+from private_gpt.server.chat.chat_router import chat_router
+from private_gpt.server.chunks.chunks_router import chunks_router
+from private_gpt.server.completions.completions_router import completions_router
+from private_gpt.server.embeddings.embeddings_router import embeddings_router
+from private_gpt.server.health.health_router import health_router
+from private_gpt.server.ingest.ingest_router import ingest_router
+from private_gpt.settings.settings import Settings
+
+logger = logging.getLogger(__name__)
+
+
+def create_app(root_injector: Injector) -> FastAPI:
+
+    # Start the API
+    with open(docs_path / "description.md") as description_file:
+        description = description_file.read()
+
+    tags_metadata = [
+        {
+            "name": "Ingestion",
+            "description": "High-level APIs covering document ingestion -internally "
+            "managing document parsing, splitting,"
+            "metadata extraction, embedding generation and storage- and ingested "
+            "documents CRUD."
+            "Each ingested document is identified by an ID that can be used to filter the "
+            "context"
+            "used in *Contextual Completions* and *Context Chunks* APIs.",
+        },
+        {
+            "name": "Contextual Completions",
+            "description": "High-level APIs covering contextual Chat and Completions. They "
+            "follow OpenAI's format, extending it to "
+            "allow using the context coming from ingested documents to create the "
+            "response. Internally"
+            "manage context retrieval, prompt engineering and the response generation.",
+        },
+        {
+            "name": "Context Chunks",
+            "description": "Low-level API that given a query return relevant chunks of "
+            "text coming from the ingested"
+            "documents.",
+        },
+        {
+            "name": "Embeddings",
+            "description": "Low-level API to obtain the vector representation of a given "
+            "text, using an Embeddings model."
+            "Follows OpenAI's embeddings API format.",
+        },
+        {
+            "name": "Health",
+            "description": "Simple health API to make sure the server is up and running.",
+        },
+    ]
+
+    async def bind_injector_to_request(request: Request) -> None:
+        request.state.injector = root_injector
+
+    app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
+
+    def custom_openapi() -> dict[str, Any]:
+        if app.openapi_schema:
+            return app.openapi_schema
+        openapi_schema = get_openapi(
+            title="PrivateGPT",
+            description=description,
+            version="0.1.0",
+            summary="PrivateGPT is a production-ready AI project that allows you to "
+            "ask questions to your documents using the power of Large Language "
+            "Models (LLMs), even in scenarios without Internet connection. "
+            "100% private, no data leaves your execution environment at any point.",
+            contact={
+                "url": "https://github.com/imartinez/privateGPT",
+            },
+            license_info={
+                "name": "Apache 2.0",
+                "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
+            },
+            routes=app.routes,
+            tags=tags_metadata,
+        )
+        openapi_schema["info"]["x-logo"] = {
+            "url": "https://lh3.googleusercontent.com/drive-viewer"
+            "/AK7aPaD_iNlMoTquOBsw4boh4tIYxyEuhz6EtEs8nzq3yNkNAK00xGj"
+            "E1KUCmPJSk3TYOjcs6tReG6w_cLu1S7L_gPgT9z52iw=s2560"
+        }
+
+        app.openapi_schema = openapi_schema
+        return app.openapi_schema
+
+    app.openapi = custom_openapi  # type: ignore[method-assign]
+
+    app.include_router(completions_router)
+    app.include_router(chat_router)
+    app.include_router(chunks_router)
+    app.include_router(ingest_router)
+    app.include_router(embeddings_router)
+    app.include_router(health_router)
+
+    settings = root_injector.get(Settings)
+    if settings.server.cors.enabled:
+        logger.debug("Setting up CORS middleware")
+        app.add_middleware(
+            CORSMiddleware,
+            allow_credentials=settings.server.cors.allow_credentials,
+            allow_origins=settings.server.cors.allow_origins,
+            allow_origin_regex=settings.server.cors.allow_origin_regex,
+            allow_methods=settings.server.cors.allow_methods,
+            allow_headers=settings.server.cors.allow_headers,
+        )
+
+    if settings.ui.enabled:
+        logger.debug("Importing the UI module")
+        from private_gpt.ui.ui import PrivateGptUi
+
+        ui = root_injector.get(PrivateGptUi)
+        ui.mount_in_app(app, settings.ui.path)
+
+    return app
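The key trick in `create_app` is the app-wide dependency that stashes the injector on `request.state`, making it reachable from any endpoint without module-level globals. A reduced sketch of that same mechanism — the `Greeter` service and `/greet` route are illustrative, not part of the project:

```python
# Sketch of the request-scoped injector binding used by create_app above:
# an app-wide FastAPI dependency stores the injector on request.state, so
# endpoints resolve services per request instead of importing a global.
from fastapi import Depends, FastAPI, Request
from injector import Injector


class Greeter:  # hypothetical service, auto-bound by the injector
    def greet(self) -> str:
        return "hello"


injector = Injector(auto_bind=True)


async def bind_injector_to_request(request: Request) -> None:
    request.state.injector = injector


app = FastAPI(dependencies=[Depends(bind_injector_to_request)])


@app.get("/greet")
def greet(request: Request) -> dict[str, str]:
    service = request.state.injector.get(Greeter)  # resolved at call time
    return {"message": service.greet()}
```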
+ "Follows OpenAI's embeddings API format.", + }, + { + "name": "Health", + "description": "Simple health API to make sure the server is up and running.", + }, + ] + + async def bind_injector_to_request(request: Request) -> None: + request.state.injector = root_injector + + app = FastAPI(dependencies=[Depends(bind_injector_to_request)]) + + def custom_openapi() -> dict[str, Any]: + if app.openapi_schema: + return app.openapi_schema + openapi_schema = get_openapi( + title="PrivateGPT", + description=description, + version="0.1.0", + summary="PrivateGPT is a production-ready AI project that allows you to " + "ask questions to your documents using the power of Large Language " + "Models (LLMs), even in scenarios without Internet connection. " + "100% private, no data leaves your execution environment at any point.", + contact={ + "url": "https://github.com/imartinez/privateGPT", + }, + license_info={ + "name": "Apache 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0.html", + }, + routes=app.routes, + tags=tags_metadata, + ) + openapi_schema["info"]["x-logo"] = { + "url": "https://lh3.googleusercontent.com/drive-viewer" + "/AK7aPaD_iNlMoTquOBsw4boh4tIYxyEuhz6EtEs8nzq3yNkNAK00xGj" + "E1KUCmPJSk3TYOjcs6tReG6w_cLu1S7L_gPgT9z52iw=s2560" + } + + app.openapi_schema = openapi_schema + return app.openapi_schema + + app.openapi = custom_openapi # type: ignore[method-assign] + + app.include_router(completions_router) + app.include_router(chat_router) + app.include_router(chunks_router) + app.include_router(ingest_router) + app.include_router(embeddings_router) + app.include_router(health_router) + + settings = root_injector.get(Settings) + if settings.server.cors.enabled: + logger.debug("Setting up CORS middleware") + app.add_middleware( + CORSMiddleware, + allow_credentials=settings.server.cors.allow_credentials, + allow_origins=settings.server.cors.allow_origins, + allow_origin_regex=settings.server.cors.allow_origin_regex, + allow_methods=settings.server.cors.allow_methods, + allow_headers=settings.server.cors.allow_headers, + ) + + if settings.ui.enabled: + logger.debug("Importing the UI module") + from private_gpt.ui.ui import PrivateGptUi + + ui = root_injector.get(PrivateGptUi) + ui.mount_in_app(app, settings.ui.path) + + return app diff --git a/private_gpt/main.py b/private_gpt/main.py index 519f205d..d249fa6c 100644 --- a/private_gpt/main.py +++ b/private_gpt/main.py @@ -1,124 +1,11 @@ """FastAPI app creation, logger configuration and main API routes.""" -import logging -from typing import Any import llama_index -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from fastapi.openapi.utils import get_openapi -from private_gpt.paths import docs_path -from private_gpt.server.chat.chat_router import chat_router -from private_gpt.server.chunks.chunks_router import chunks_router -from private_gpt.server.completions.completions_router import completions_router -from private_gpt.server.embeddings.embeddings_router import embeddings_router -from private_gpt.server.health.health_router import health_router -from private_gpt.server.ingest.ingest_router import ingest_router -from private_gpt.settings.settings import settings - -logger = logging.getLogger(__name__) +from private_gpt.di import global_injector +from private_gpt.launcher import create_app # Add LlamaIndex simple observability llama_index.set_global_handler("simple") -# Start the API -with open(docs_path / "description.md") as description_file: - description = description_file.read() - -tags_metadata = 
diff --git a/private_gpt/paths.py b/private_gpt/paths.py
index 88310519..59db3a49 100644
--- a/private_gpt/paths.py
+++ b/private_gpt/paths.py
@@ -13,4 +13,6 @@ def _absolute_or_from_project_root(path: str) -> Path:
 models_path: Path = PROJECT_ROOT_PATH / "models"
 models_cache_path: Path = models_path / "cache"
 docs_path: Path = PROJECT_ROOT_PATH / "docs"
-local_data_path: Path = _absolute_or_from_project_root(settings.data.local_data_folder)
+local_data_path: Path = _absolute_or_from_project_root(
+    settings().data.local_data_folder
+)
diff --git a/private_gpt/server/chat/chat_router.py b/private_gpt/server/chat/chat_router.py
index bd7034b4..79558f7f 100644
--- a/private_gpt/server/chat/chat_router.py
+++ b/private_gpt/server/chat/chat_router.py
@@ -1,9 +1,8 @@
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from llama_index.llms import ChatMessage, MessageRole
 from pydantic import BaseModel
 from starlette.responses import StreamingResponse
 
-from private_gpt.di import root_injector
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.open_ai.openai_models import (
     OpenAICompletion,
@@ -52,7 +51,9 @@ class ChatBody(BaseModel):
     responses={200: {"model": OpenAICompletion}},
     tags=["Contextual Completions"],
 )
-def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
+def chat_completion(
+    request: Request, body: ChatBody
+) -> OpenAICompletion | StreamingResponse:
     """Given a list of messages comprising a conversation, return a response.
 
     If `use_context` is set to `true`, the model will use context coming
@@ -72,7 +73,7 @@ def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
         "finish_reason":null}]}
     ```
     """
-    service = root_injector.get(ChatService)
+    service = request.state.injector.get(ChatService)
     all_messages = [
         ChatMessage(content=m.content, role=MessageRole(m.role)) for m in body.messages
     ]
diff --git a/private_gpt/server/chunks/chunks_router.py b/private_gpt/server/chunks/chunks_router.py
index d965d984..4da377ac 100644
--- a/private_gpt/server/chunks/chunks_router.py
+++ b/private_gpt/server/chunks/chunks_router.py
@@ -1,9 +1,8 @@
 from typing import Literal
 
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from pydantic import BaseModel, Field
 
-from private_gpt.di import root_injector
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
 from private_gpt.server.utils.auth import authenticated
@@ -25,7 +24,7 @@ class ChunksResponse(BaseModel):
 
 
 @chunks_router.post("/chunks", tags=["Context Chunks"])
-def chunks_retrieval(body: ChunksBody) -> ChunksResponse:
+def chunks_retrieval(request: Request, body: ChunksBody) -> ChunksResponse:
     """Given a `text`, returns the most relevant chunks from the ingested documents.
 
     The returned information can be used to generate prompts that can be
@@ -45,7 +44,7 @@ def chunks_retrieval(body: ChunksBody) -> ChunksResponse:
     `/ingest/list` endpoint. If you want all ingested documents to be used,
     remove `context_filter` altogether.
     """
-    service = root_injector.get(ChunksService)
+    service = request.state.injector.get(ChunksService)
     results = service.retrieve_relevant(
         body.text, body.context_filter, body.limit, body.prev_next_chunks
     )
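Each endpoint above gains a `request: Request` parameter so it can reach `request.state.injector`. FastAPI treats a bare `Request` parameter specially: it is filled from the ASGI scope and never appears in the request body or the generated OpenAPI schema, so the public API is unchanged. A small self-contained sketch of that behavior — the `/echo` route and `EchoBody` are illustrative:

```python
# A bare Request parameter is framework-provided; only EchoBody shows up
# in the OpenAPI schema, mirroring how chat_completion keeps its API shape.
from fastapi import FastAPI, Request
from pydantic import BaseModel

app = FastAPI()


class EchoBody(BaseModel):
    text: str


@app.post("/echo")
def echo(request: Request, body: EchoBody) -> dict[str, str]:
    # body is parsed from JSON as before; request comes from the ASGI scope
    client = request.client.host if request.client else "unknown"
    return {"client": client, "text": body.text}
```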
""" - service = root_injector.get(ChunksService) + service = request.state.injector.get(ChunksService) results = service.retrieve_relevant( body.text, body.context_filter, body.limit, body.prev_next_chunks ) diff --git a/private_gpt/server/completions/completions_router.py b/private_gpt/server/completions/completions_router.py index 4840047f..887923b4 100644 --- a/private_gpt/server/completions/completions_router.py +++ b/private_gpt/server/completions/completions_router.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Request from pydantic import BaseModel from starlette.responses import StreamingResponse @@ -41,7 +41,9 @@ class CompletionsBody(BaseModel): responses={200: {"model": OpenAICompletion}}, tags=["Contextual Completions"], ) -def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResponse: +def prompt_completion( + request: Request, body: CompletionsBody +) -> OpenAICompletion | StreamingResponse: """We recommend most users use our Chat completions API. Given a prompt, the model will return one predicted completion. If `use_context` @@ -70,4 +72,4 @@ def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResp include_sources=body.include_sources, context_filter=body.context_filter, ) - return chat_completion(chat_body) + return chat_completion(request, chat_body) diff --git a/private_gpt/server/embeddings/embeddings_router.py b/private_gpt/server/embeddings/embeddings_router.py index f5236c6a..f698392d 100644 --- a/private_gpt/server/embeddings/embeddings_router.py +++ b/private_gpt/server/embeddings/embeddings_router.py @@ -1,9 +1,8 @@ from typing import Literal -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Request from pydantic import BaseModel -from private_gpt.di import root_injector from private_gpt.server.embeddings.embeddings_service import ( Embedding, EmbeddingsService, @@ -24,13 +23,13 @@ class EmbeddingsResponse(BaseModel): @embeddings_router.post("/embeddings", tags=["Embeddings"]) -def embeddings_generation(body: EmbeddingsBody) -> EmbeddingsResponse: +def embeddings_generation(request: Request, body: EmbeddingsBody) -> EmbeddingsResponse: """Get a vector representation of a given input. That vector representation can be easily consumed by machine learning models and algorithms. 
""" - service = root_injector.get(EmbeddingsService) + service = request.state.injector.get(EmbeddingsService) input_texts = body.input if isinstance(body.input, list) else [body.input] embeddings = service.texts_embeddings(input_texts) return EmbeddingsResponse(object="list", model="private-gpt", data=embeddings) diff --git a/private_gpt/server/ingest/ingest_router.py b/private_gpt/server/ingest/ingest_router.py index d682de7b..c06a4d41 100644 --- a/private_gpt/server/ingest/ingest_router.py +++ b/private_gpt/server/ingest/ingest_router.py @@ -1,9 +1,8 @@ from typing import Literal -from fastapi import APIRouter, Depends, HTTPException, UploadFile +from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile from pydantic import BaseModel -from private_gpt.di import root_injector from private_gpt.server.ingest.ingest_service import IngestedDoc, IngestService from private_gpt.server.utils.auth import authenticated @@ -17,7 +16,7 @@ class IngestResponse(BaseModel): @ingest_router.post("/ingest", tags=["Ingestion"]) -def ingest(file: UploadFile) -> IngestResponse: +def ingest(request: Request, file: UploadFile) -> IngestResponse: """Ingests and processes a file, storing its chunks to be used as context. The context obtained from files is later used in @@ -33,7 +32,7 @@ def ingest(file: UploadFile) -> IngestResponse: can be used to filter the context used to create responses in `/chat/completions`, `/completions`, and `/chunks` APIs. """ - service = root_injector.get(IngestService) + service = request.state.injector.get(IngestService) if file.filename is None: raise HTTPException(400, "No file name provided") ingested_documents = service.ingest(file.filename, file.file.read()) @@ -41,23 +40,23 @@ def ingest(file: UploadFile) -> IngestResponse: @ingest_router.get("/ingest/list", tags=["Ingestion"]) -def list_ingested() -> IngestResponse: +def list_ingested(request: Request) -> IngestResponse: """Lists already ingested Documents including their Document ID and metadata. Those IDs can be used to filter the context used to create responses in `/chat/completions`, `/completions`, and `/chunks` APIs. """ - service = root_injector.get(IngestService) + service = request.state.injector.get(IngestService) ingested_documents = service.list_ingested() return IngestResponse(object="list", model="private-gpt", data=ingested_documents) @ingest_router.delete("/ingest/{doc_id}", tags=["Ingestion"]) -def delete_ingested(doc_id: str) -> None: +def delete_ingested(request: Request, doc_id: str) -> None: """Delete the specified ingested Document. The `doc_id` can be obtained from the `GET /ingest/list` endpoint. The document will be effectively deleted from your storage context. """ - service = root_injector.get(IngestService) + service = request.state.injector.get(IngestService) service.delete(doc_id) diff --git a/private_gpt/server/utils/auth.py b/private_gpt/server/utils/auth.py index 371e794d..4fd57a7f 100644 --- a/private_gpt/server/utils/auth.py +++ b/private_gpt/server/utils/auth.py @@ -38,13 +38,13 @@ logger = logging.getLogger(__name__) def _simple_authentication(authorization: Annotated[str, Header()] = "") -> bool: """Check if the request is authenticated.""" - if not secrets.compare_digest(authorization, settings.server.auth.secret): + if not secrets.compare_digest(authorization, settings().server.auth.secret): # If the "Authorization" header is not the expected one, raise an exception. 
diff --git a/private_gpt/settings/settings.py b/private_gpt/settings/settings.py
index 9529cac1..30443119 100644
--- a/private_gpt/settings/settings.py
+++ b/private_gpt/settings/settings.py
@@ -2,7 +2,7 @@ from typing import Literal
 
 from pydantic import BaseModel, Field
 
-from private_gpt.settings.settings_loader import load_active_profiles
+from private_gpt.settings.settings_loader import load_active_settings
 
 
 class CorsSettings(BaseModel):
@@ -114,4 +114,29 @@ class Settings(BaseModel):
     openai: OpenAISettings
 
 
-settings = Settings(**load_active_profiles())
+"""
+This is visible just for DI or testing purposes.
+
+Use dependency injection or `settings()` method instead.
+"""
+unsafe_settings = load_active_settings()
+
+"""
+This is visible just for DI or testing purposes.
+
+Use dependency injection or `settings()` method instead.
+"""
+unsafe_typed_settings = Settings(**unsafe_settings)
+
+
+def settings() -> Settings:
+    """Get the current loaded settings from the DI container.
+
+    This method exists to keep compatibility with the existing code,
+    that require global access to the settings.
+
+    For regular components use dependency injection instead.
+    """
+    from private_gpt.di import global_injector
+
+    return global_injector.get(Settings)
diff --git a/private_gpt/settings/settings_loader.py b/private_gpt/settings/settings_loader.py
index 99c2ca40..b4052db2 100644
--- a/private_gpt/settings/settings_loader.py
+++ b/private_gpt/settings/settings_loader.py
@@ -2,6 +2,7 @@ import functools
 import logging
 import os
 import sys
+from collections.abc import Iterable
 from pathlib import Path
 from typing import Any
 
@@ -28,7 +29,11 @@ active_profiles: list[str] = unique_list(
 )
 
 
-def load_profile(profile: str) -> dict[str, Any]:
+def merge_settings(settings: Iterable[dict[str, Any]]) -> dict[str, Any]:
+    return functools.reduce(deep_update, settings, {})
+
+
+def load_settings_from_profile(profile: str) -> dict[str, Any]:
     if profile == "default":
         profile_file_name = "settings.yaml"
     else:
@@ -42,9 +47,11 @@ def load_profile(profile: str) -> dict[str, Any]:
         return config
 
 
-def load_active_profiles() -> dict[str, Any]:
+def load_active_settings() -> dict[str, Any]:
     """Load active profiles and merge them."""
     logger.info("Starting application with profiles=%s", active_profiles)
-    loaded_profiles = [load_profile(profile) for profile in active_profiles]
-    merged: dict[str, Any] = functools.reduce(deep_update, loaded_profiles, {})
+    loaded_profiles = [
+        load_settings_from_profile(profile) for profile in active_profiles
+    ]
+    merged: dict[str, Any] = merge_settings(loaded_profiles)
     return merged
diff --git a/private_gpt/ui/ui.py b/private_gpt/ui/ui.py
index f4a1431a..dea99f50 100644
--- a/private_gpt/ui/ui.py
+++ b/private_gpt/ui/ui.py
@@ -8,10 +8,11 @@ from typing import Any, TextIO
 import gradio as gr  # type: ignore
 from fastapi import FastAPI
 from gradio.themes.utils.colors import slate  # type: ignore
+from injector import inject, singleton
 from llama_index.llms import ChatMessage, ChatResponse, MessageRole
 from pydantic import BaseModel
 
-from private_gpt.di import root_injector
+from private_gpt.di import global_injector
 from private_gpt.server.chat.chat_service import ChatService, CompletionGen
 from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
 from private_gpt.server.ingest.ingest_service import IngestService
@@ -48,11 +49,18 @@ class Source(BaseModel):
         return curated_sources
 
 
+@singleton
 class PrivateGptUi:
-    def __init__(self) -> None:
-        self._ingest_service = root_injector.get(IngestService)
-        self._chat_service = root_injector.get(ChatService)
-        self._chunks_service = root_injector.get(ChunksService)
+    @inject
+    def __init__(
+        self,
+        ingest_service: IngestService,
+        chat_service: ChatService,
+        chunks_service: ChunksService,
+    ) -> None:
+        self._ingest_service = ingest_service
+        self._chat_service = chat_service
+        self._chunks_service = chunks_service
 
         # Cache the UI blocks
         self._ui_block = None
@@ -198,7 +206,7 @@ class PrivateGptUi:
         _ = gr.ChatInterface(
             self._chat,
             chatbot=gr.Chatbot(
-                label=f"LLM: {settings.llm.mode}",
+                label=f"LLM: {settings().llm.mode}",
                 show_copy_button=True,
                 render=False,
                 avatar_images=(
@@ -217,16 +225,15 @@ class PrivateGptUi:
             self._ui_block = self._build_ui_blocks()
         return self._ui_block
 
-    def mount_in_app(self, app: FastAPI) -> None:
+    def mount_in_app(self, app: FastAPI, path: str) -> None:
         blocks = self.get_ui_blocks()
         blocks.queue()
-        base_path = settings.ui.path
-        logger.info("Mounting the gradio UI, at path=%s", base_path)
-        gr.mount_gradio_app(app, blocks, path=base_path)
+        logger.info("Mounting the gradio UI, at path=%s", path)
+        gr.mount_gradio_app(app, blocks, path=path)
 
 
 if __name__ == "__main__":
-    ui = PrivateGptUi()
+    ui = global_injector.get(PrivateGptUi)
     _blocks = ui.get_ui_blocks()
     _blocks.queue()
     _blocks.launch(debug=False, show_api=False)
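The extracted `merge_settings` helper just folds `deep_update` over the active profiles, so later profiles override earlier ones key by key while unrelated keys survive. A self-contained illustration of those semantics — the simplified, non-mutating `deep_update` here is a stand-in for the helper the loader actually imports:

```python
# Illustration of the profile-merge semantics behind merge_settings():
# later profiles deep-override earlier ones, key by key.
import functools
from collections.abc import Iterable
from typing import Any


def deep_update(base: dict[str, Any], new: dict[str, Any]) -> dict[str, Any]:
    # Simplified stand-in: recursive, copy-on-write dict merge.
    merged = dict(base)
    for key, value in new.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_update(merged[key], value)
        else:
            merged[key] = value
    return merged


def merge_settings(settings: Iterable[dict[str, Any]]) -> dict[str, Any]:
    return functools.reduce(deep_update, settings, {})


defaults = {"server": {"port": 8001, "env_name": "prod"}, "llm": {"mode": "local"}}
test_profile = {"server": {"env_name": "test"}, "llm": {"mode": "mock"}}
merged = merge_settings([defaults, test_profile])
# Overridden keys take the later value; untouched keys (port) survive.
assert merged == {"server": {"port": 8001, "env_name": "test"}, "llm": {"mode": "mock"}}
```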
diff --git a/scripts/ingest_folder.py b/scripts/ingest_folder.py
index f2bc24c5..cdb164ef 100644
--- a/scripts/ingest_folder.py
+++ b/scripts/ingest_folder.py
@@ -2,13 +2,13 @@ import argparse
 import logging
 from pathlib import Path
 
-from private_gpt.di import root_injector
+from private_gpt.di import global_injector
 from private_gpt.server.ingest.ingest_service import IngestService
 from private_gpt.server.ingest.ingest_watcher import IngestWatcher
 
 logger = logging.getLogger(__name__)
 
-ingest_service = root_injector.get(IngestService)
+ingest_service = global_injector.get(IngestService)
 
 parser = argparse.ArgumentParser(prog="ingest_folder.py")
 parser.add_argument("folder", help="Folder to ingest")
diff --git a/scripts/setup b/scripts/setup
index fc56d139..377bbe0b 100755
--- a/scripts/setup
+++ b/scripts/setup
@@ -9,9 +9,9 @@ from private_gpt.settings.settings import settings
 os.makedirs(models_path, exist_ok=True)
 embedding_path = models_path / "embedding"
 
-print(f"Downloading embedding {settings.local.embedding_hf_model_name}")
+print(f"Downloading embedding {settings().local.embedding_hf_model_name}")
 snapshot_download(
-    repo_id=settings.local.embedding_hf_model_name,
+    repo_id=settings().local.embedding_hf_model_name,
     cache_dir=models_cache_path,
     local_dir=embedding_path,
 )
@@ -20,8 +20,8 @@ print("Downloading models for local execution...")
 
 # Download LLM and create a symlink to the model file
 hf_hub_download(
-    repo_id=settings.local.llm_hf_repo_id,
-    filename=settings.local.llm_hf_model_file,
+    repo_id=settings().local.llm_hf_repo_id,
+    filename=settings().local.llm_hf_model_file,
     cache_dir=models_cache_path,
     local_dir=models_path,
 )
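Scripts run outside the request cycle, so there is no `request.state` to reach for; they act as their own composition root and resolve services from `global_injector` once at startup, as `ingest_folder.py` does above. A hedged sketch of the same pattern — it assumes `IngestedDoc` exposes a `doc_id` field, as the router docstrings suggest:

```python
# Sketch: a standalone script resolving a service from the shared injector,
# mirroring scripts/ingest_folder.py. Assumes IngestedDoc has a doc_id field.
from private_gpt.di import global_injector
from private_gpt.server.ingest.ingest_service import IngestService

ingest_service = global_injector.get(IngestService)

# From here on the script uses the resolved service like any other object.
for doc in ingest_service.list_ingested():
    print(doc.doc_id)
```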
diff --git a/settings-test.yaml b/settings-test.yaml
index 965a0efe..5f7a190b 100644
--- a/settings-test.yaml
+++ b/settings-test.yaml
@@ -5,8 +5,12 @@ server:
     # Dummy secrets used for tests
     secret: "foo bar; dummy secret"
 
+
 data:
   local_data_folder: local_data/tests
 
 llm:
-  mode: mock
\ No newline at end of file
+  mode: mock
+
+ui:
+  enabled: false
\ No newline at end of file
diff --git a/tests/fixtures/fast_api_test_client.py b/tests/fixtures/fast_api_test_client.py
index b91dfec0..77d6037c 100644
--- a/tests/fixtures/fast_api_test_client.py
+++ b/tests/fixtures/fast_api_test_client.py
@@ -1,15 +1,14 @@
 import pytest
-from fastapi import FastAPI
 from fastapi.testclient import TestClient
 
-from private_gpt.main import app
+from private_gpt.launcher import create_app
+from tests.fixtures.mock_injector import MockInjector
 
 
 @pytest.fixture()
-def current_test_app() -> FastAPI:
-    return app
+def test_client(request: pytest.FixtureRequest, injector: MockInjector) -> TestClient:
+    if request is not None and hasattr(request, "param"):
+        injector.bind_settings(request.param or {})
 
-
-@pytest.fixture()
-def test_client() -> TestClient:
-    return TestClient(app)
+    app_under_test = create_app(injector.test_injector)
+    return TestClient(app_under_test)
diff --git a/tests/fixtures/mock_injector.py b/tests/fixtures/mock_injector.py
index 5a74358f..5769b33d 100644
--- a/tests/fixtures/mock_injector.py
+++ b/tests/fixtures/mock_injector.py
@@ -1,10 +1,13 @@
 from collections.abc import Callable
+from typing import Any
 from unittest.mock import MagicMock
 
 import pytest
 from injector import Provider, ScopeDecorator, singleton
 
 from private_gpt.di import create_application_injector
+from private_gpt.settings.settings import Settings, unsafe_settings
+from private_gpt.settings.settings_loader import merge_settings
 from private_gpt.utils.typing import T
 
 
@@ -24,6 +27,12 @@ class MockInjector:
         self.test_injector.binder.bind(interface, to=mock, scope=scope)
         return mock  # type: ignore
 
+    def bind_settings(self, settings: dict[str, Any]) -> Settings:
+        merged = merge_settings([unsafe_settings, settings])
+        new_settings = Settings(**merged)
+        self.test_injector.binder.bind(Settings, new_settings)
+        return new_settings
+
     def get(self, interface: type[T]) -> T:
         return self.test_injector.get(interface)
diff --git a/tests/server/utils/test_simple_auth.py b/tests/server/utils/test_simple_auth.py
index 6c304a57..0ef3614c 100644
--- a/tests/server/utils/test_simple_auth.py
+++ b/tests/server/utils/test_simple_auth.py
@@ -8,7 +8,7 @@ NOTE: We are not testing the switch based on the config in
 from typing import Annotated
 
 import pytest
-from fastapi import Depends, FastAPI
+from fastapi import Depends
 from fastapi.testclient import TestClient
 
 from private_gpt.server.utils.auth import (
@@ -29,15 +29,16 @@ def _copy_simple_authenticated(
 
 
 @pytest.fixture(autouse=True)
-def _patch_authenticated_dependency(current_test_app: FastAPI):
+def _patch_authenticated_dependency(test_client: TestClient):
     # Patch the server to use simple authentication
-    current_test_app.dependency_overrides[authenticated] = _copy_simple_authenticated
+
+    test_client.app.dependency_overrides[authenticated] = _copy_simple_authenticated
 
     # Call the actual test
     yield
 
     # Remove the patch for other tests
-    current_test_app.dependency_overrides = {}
+    test_client.app.dependency_overrides = {}
 
 
 def test_default_auth_working_when_enabled_401(test_client: TestClient) -> None:
@@ -50,6 +51,6 @@ def test_default_auth_working_when_enabled_200(test_client: TestClient) -> None:
     assert response_fail.status_code == 401
 
     response_success = test_client.get(
-        "/v1/ingest/list", headers={"Authorization": settings.server.auth.secret}
+        "/v1/ingest/list", headers={"Authorization": settings().server.auth.secret}
     )
     assert response_success.status_code == 200
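Putting the new fixtures together: with `indirect=True`, a plain dict travels through `request.param` into the `test_client` fixture, which binds it as that test's `Settings` before building an isolated app with `create_app`. A sketch of a typical use — the `/health` route is part of the server, but the override values shown are illustrative:

```python
# Sketch: per-test settings override via indirect parametrization.
# The dict is merged over the loaded profiles by bind_settings(), so
# only the keys named here change for this test's app instance.
import pytest
from fastapi.testclient import TestClient


@pytest.mark.parametrize(
    "test_client", [{"server": {"auth": {"enabled": False}}}], indirect=True
)
def test_health_with_auth_disabled(test_client: TestClient) -> None:
    response = test_client.get("/health")
    assert response.status_code == 200
```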
diff --git a/tests/settings/test_settings.py b/tests/settings/test_settings.py
index 3178967a..f10bdd81 100644
--- a/tests/settings/test_settings.py
+++ b/tests/settings/test_settings.py
@@ -1,5 +1,12 @@
-from private_gpt.settings.settings import settings
+from private_gpt.settings.settings import Settings, settings
+from tests.fixtures.mock_injector import MockInjector
 
 
 def test_settings_are_loaded_and_merged() -> None:
-    assert settings.server.env_name == "test"
+    assert settings().server.env_name == "test"
+
+
+def test_settings_can_be_overriden(injector: MockInjector) -> None:
+    injector.bind_settings({"server": {"env_name": "overriden"}})
+    mocked_settings = injector.get(Settings)
+    assert mocked_settings.server.env_name == "overriden"
diff --git a/tests/ui/test_ui.py b/tests/ui/test_ui.py
new file mode 100644
index 00000000..4f4361b0
--- /dev/null
+++ b/tests/ui/test_ui.py
@@ -0,0 +1,10 @@
+import pytest
+from fastapi.testclient import TestClient
+
+
+@pytest.mark.parametrize(
+    "test_client", [{"ui": {"enabled": True, "path": "/ui"}}], indirect=True
+)
+def test_ui_starts_in_the_given_endpoint(test_client: TestClient) -> None:
+    response = test_client.get("/ui")
+    assert response.status_code == 200
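The same override machinery also works without building an app, by driving `MockInjector` directly, mirroring `test_settings_can_be_overriden` above. A sketch, under the assumption that the `injector` fixture yields a `MockInjector` as in the other tests:

```python
# Sketch: overriding settings at the injector level, no HTTP involved.
from private_gpt.settings.settings import Settings
from tests.fixtures.mock_injector import MockInjector


def test_override_llm_mode(injector: MockInjector) -> None:
    # bind_settings() deep-merges the override into the loaded profiles,
    # then binds the resulting typed Settings in the test injector.
    injector.bind_settings({"llm": {"mode": "mock"}})
    assert injector.get(Settings).llm.mode == "mock"
```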