Mirror of https://github.com/imartinez/privateGPT.git (synced 2025-04-27 11:21:34 +00:00)
fix: Remove global state (#1216)
* Remove all global settings state
* chore: remove autogenerated class
* chore: cleanup
* chore: merge conflicts
Parent: f394ca61bb
Commit: 022bd718e3
@@ -8,4 +8,4 @@ from private_gpt.settings.settings import settings
 # Set log_config=None to do not use the uvicorn logging configuration, and
 # use ours instead. For reference, see below:
 # https://github.com/tiangolo/fastapi/discussions/7457#discussioncomment-5141108
-uvicorn.run(app, host="0.0.0.0", port=settings.server.port, log_config=None)
+uvicorn.run(app, host="0.0.0.0", port=settings().server.port, log_config=None)
@@ -3,7 +3,7 @@ from llama_index import MockEmbedding
 from llama_index.embeddings.base import BaseEmbedding

 from private_gpt.paths import models_cache_path
-from private_gpt.settings.settings import settings
+from private_gpt.settings.settings import Settings


 @singleton
@@ -11,7 +11,7 @@ class EmbeddingComponent:
     embedding_model: BaseEmbedding

     @inject
-    def __init__(self) -> None:
+    def __init__(self, settings: Settings) -> None:
         match settings.llm.mode:
             case "local":
                 from llama_index.embeddings import HuggingFaceEmbedding
@@ -4,7 +4,7 @@ from llama_index.llms.base import LLM
 from llama_index.llms.llama_utils import completion_to_prompt, messages_to_prompt

 from private_gpt.paths import models_path
-from private_gpt.settings.settings import settings
+from private_gpt.settings.settings import Settings


 @singleton
@@ -12,7 +12,7 @@ class LLMComponent:
     llm: LLM

     @inject
-    def __init__(self) -> None:
+    def __init__(self, settings: Settings) -> None:
         match settings.llm.mode:
             case "local":
                 from llama_index.llms import LlamaCPP
@@ -1,9 +1,19 @@
 from injector import Injector

+from private_gpt.settings.settings import Settings, unsafe_typed_settings
+

 def create_application_injector() -> Injector:
-    injector = Injector(auto_bind=True)
-    return injector
+    _injector = Injector(auto_bind=True)
+    _injector.binder.bind(Settings, to=unsafe_typed_settings)
+    return _injector


-root_injector: Injector = create_application_injector()
+"""
+Global injector for the application.
+
+Avoid using this reference, it will make your code harder to test.
+
+Instead, use the `request.state.injector` reference, which is bound to every request
+"""
+global_injector: Injector = create_application_injector()
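For illustration only, a minimal, self-contained sketch of the dependency-injection pattern this commit adopts: components receive the typed settings object through their constructor and the injector binds a concrete instance once, mirroring the EmbeddingComponent/LLMComponent and di.py changes above. All names below (DemoSettings, DemoComponent) are illustrative, not repo code.

from injector import Injector, inject, singleton


class DemoSettings:  # illustrative stand-in for private_gpt's typed Settings model
    def __init__(self, llm_mode: str) -> None:
        self.llm_mode = llm_mode


@singleton
class DemoComponent:  # mirrors LLMComponent / EmbeddingComponent after this change
    @inject
    def __init__(self, settings: DemoSettings) -> None:
        # settings arrives through the constructor instead of a module-level global
        self.mode = settings.llm_mode


injector = Injector(auto_bind=True)
injector.binder.bind(DemoSettings, to=DemoSettings(llm_mode="mock"))
assert injector.get(DemoComponent).mode == "mock"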
private_gpt/launcher.py (new file, 128 lines)
@@ -0,0 +1,128 @@
+"""FastAPI app creation, logger configuration and main API routes."""
+import logging
+from typing import Any
+
+from fastapi import Depends, FastAPI, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.openapi.utils import get_openapi
+from injector import Injector
+
+from private_gpt.paths import docs_path
+from private_gpt.server.chat.chat_router import chat_router
+from private_gpt.server.chunks.chunks_router import chunks_router
+from private_gpt.server.completions.completions_router import completions_router
+from private_gpt.server.embeddings.embeddings_router import embeddings_router
+from private_gpt.server.health.health_router import health_router
+from private_gpt.server.ingest.ingest_router import ingest_router
+from private_gpt.settings.settings import Settings
+
+logger = logging.getLogger(__name__)
+
+
+def create_app(root_injector: Injector) -> FastAPI:
+
+    # Start the API
+    with open(docs_path / "description.md") as description_file:
+        description = description_file.read()
+
+    tags_metadata = [
+        {
+            "name": "Ingestion",
+            "description": "High-level APIs covering document ingestion -internally "
+            "managing document parsing, splitting,"
+            "metadata extraction, embedding generation and storage- and ingested "
+            "documents CRUD."
+            "Each ingested document is identified by an ID that can be used to filter the "
+            "context"
+            "used in *Contextual Completions* and *Context Chunks* APIs.",
+        },
+        {
+            "name": "Contextual Completions",
+            "description": "High-level APIs covering contextual Chat and Completions. They "
+            "follow OpenAI's format, extending it to "
+            "allow using the context coming from ingested documents to create the "
+            "response. Internally"
+            "manage context retrieval, prompt engineering and the response generation.",
+        },
+        {
+            "name": "Context Chunks",
+            "description": "Low-level API that given a query return relevant chunks of "
+            "text coming from the ingested"
+            "documents.",
+        },
+        {
+            "name": "Embeddings",
+            "description": "Low-level API to obtain the vector representation of a given "
+            "text, using an Embeddings model."
+            "Follows OpenAI's embeddings API format.",
+        },
+        {
+            "name": "Health",
+            "description": "Simple health API to make sure the server is up and running.",
+        },
+    ]
+
+    async def bind_injector_to_request(request: Request) -> None:
+        request.state.injector = root_injector
+
+    app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
+
+    def custom_openapi() -> dict[str, Any]:
+        if app.openapi_schema:
+            return app.openapi_schema
+        openapi_schema = get_openapi(
+            title="PrivateGPT",
+            description=description,
+            version="0.1.0",
+            summary="PrivateGPT is a production-ready AI project that allows you to "
+            "ask questions to your documents using the power of Large Language "
+            "Models (LLMs), even in scenarios without Internet connection. "
+            "100% private, no data leaves your execution environment at any point.",
+            contact={
+                "url": "https://github.com/imartinez/privateGPT",
+            },
+            license_info={
+                "name": "Apache 2.0",
+                "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
+            },
+            routes=app.routes,
+            tags=tags_metadata,
+        )
+        openapi_schema["info"]["x-logo"] = {
+            "url": "https://lh3.googleusercontent.com/drive-viewer"
+            "/AK7aPaD_iNlMoTquOBsw4boh4tIYxyEuhz6EtEs8nzq3yNkNAK00xGj"
+            "E1KUCmPJSk3TYOjcs6tReG6w_cLu1S7L_gPgT9z52iw=s2560"
+        }
+
+        app.openapi_schema = openapi_schema
+        return app.openapi_schema
+
+    app.openapi = custom_openapi  # type: ignore[method-assign]
+
+    app.include_router(completions_router)
+    app.include_router(chat_router)
+    app.include_router(chunks_router)
+    app.include_router(ingest_router)
+    app.include_router(embeddings_router)
+    app.include_router(health_router)
+
+    settings = root_injector.get(Settings)
+    if settings.server.cors.enabled:
+        logger.debug("Setting up CORS middleware")
+        app.add_middleware(
+            CORSMiddleware,
+            allow_credentials=settings.server.cors.allow_credentials,
+            allow_origins=settings.server.cors.allow_origins,
+            allow_origin_regex=settings.server.cors.allow_origin_regex,
+            allow_methods=settings.server.cors.allow_methods,
+            allow_headers=settings.server.cors.allow_headers,
+        )
+
+    if settings.ui.enabled:
+        logger.debug("Importing the UI module")
+        from private_gpt.ui.ui import PrivateGptUi
+
+        ui = root_injector.get(PrivateGptUi)
+        ui.mount_in_app(app, settings.ui.path)
+
+    return app
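As a rough, self-contained sketch of why create_app takes the injector as a parameter (all names below are illustrative, not from the repo): the same factory can be driven by the global injector in production or by a test injector in fixtures, and request handlers resolve services from request.state.injector instead of a module-level singleton.

from fastapi import Depends, FastAPI, Request
from fastapi.testclient import TestClient
from injector import Injector


class Greeter:  # illustrative service, not part of the repo
    def greet(self) -> str:
        return "hello"


def build_app(root_injector: Injector) -> FastAPI:
    async def bind_injector_to_request(request: Request) -> None:
        request.state.injector = root_injector

    app = FastAPI(dependencies=[Depends(bind_injector_to_request)])

    @app.get("/greet")
    def greet(request: Request) -> str:
        # handlers pull services from the per-request injector, like the routers below
        return request.state.injector.get(Greeter).greet()

    return app


client = TestClient(build_app(Injector(auto_bind=True)))
assert client.get("/greet").json() == "hello"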
@@ -1,124 +1,11 @@
 """FastAPI app creation, logger configuration and main API routes."""
-import logging
-from typing import Any
-
 import llama_index
-from fastapi import FastAPI
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.openapi.utils import get_openapi

-from private_gpt.paths import docs_path
-from private_gpt.server.chat.chat_router import chat_router
-from private_gpt.server.chunks.chunks_router import chunks_router
-from private_gpt.server.completions.completions_router import completions_router
-from private_gpt.server.embeddings.embeddings_router import embeddings_router
-from private_gpt.server.health.health_router import health_router
-from private_gpt.server.ingest.ingest_router import ingest_router
-from private_gpt.settings.settings import settings
-
-logger = logging.getLogger(__name__)
+from private_gpt.di import global_injector
+from private_gpt.launcher import create_app

 # Add LlamaIndex simple observability
 llama_index.set_global_handler("simple")

-# Start the API
-with open(docs_path / "description.md") as description_file:
-    description = description_file.read()
-
-tags_metadata = [
-    {
-        "name": "Ingestion",
-        "description": "High-level APIs covering document ingestion -internally "
-        "managing document parsing, splitting,"
-        "metadata extraction, embedding generation and storage- and ingested "
-        "documents CRUD."
-        "Each ingested document is identified by an ID that can be used to filter the "
-        "context"
-        "used in *Contextual Completions* and *Context Chunks* APIs.",
-    },
-    {
-        "name": "Contextual Completions",
-        "description": "High-level APIs covering contextual Chat and Completions. They "
-        "follow OpenAI's format, extending it to "
-        "allow using the context coming from ingested documents to create the "
-        "response. Internally"
-        "manage context retrieval, prompt engineering and the response generation.",
-    },
-    {
-        "name": "Context Chunks",
-        "description": "Low-level API that given a query return relevant chunks of "
-        "text coming from the ingested"
-        "documents.",
-    },
-    {
-        "name": "Embeddings",
-        "description": "Low-level API to obtain the vector representation of a given "
-        "text, using an Embeddings model."
-        "Follows OpenAI's embeddings API format.",
-    },
-    {
-        "name": "Health",
-        "description": "Simple health API to make sure the server is up and running.",
-    },
-]
-
-app = FastAPI()
-
-
-def custom_openapi() -> dict[str, Any]:
-    if app.openapi_schema:
-        return app.openapi_schema
-    openapi_schema = get_openapi(
-        title="PrivateGPT",
-        description=description,
-        version="0.1.0",
-        summary="PrivateGPT is a production-ready AI project that allows you to "
-        "ask questions to your documents using the power of Large Language "
-        "Models (LLMs), even in scenarios without Internet connection. "
-        "100% private, no data leaves your execution environment at any point.",
-        contact={
-            "url": "https://github.com/imartinez/privateGPT",
-        },
-        license_info={
-            "name": "Apache 2.0",
-            "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
-        },
-        routes=app.routes,
-        tags=tags_metadata,
-    )
-    openapi_schema["info"]["x-logo"] = {
-        "url": "https://lh3.googleusercontent.com/drive-viewer"
-        "/AK7aPaD_iNlMoTquOBsw4boh4tIYxyEuhz6EtEs8nzq3yNkNAK00xGj"
-        "E1KUCmPJSk3TYOjcs6tReG6w_cLu1S7L_gPgT9z52iw=s2560"
-    }
-
-    app.openapi_schema = openapi_schema
-    return app.openapi_schema
-
-
-app.openapi = custom_openapi  # type: ignore[method-assign]
-
-app.include_router(completions_router)
-app.include_router(chat_router)
-app.include_router(chunks_router)
-app.include_router(ingest_router)
-app.include_router(embeddings_router)
-app.include_router(health_router)
-
-if settings.server.cors.enabled:
-    logger.debug("Setting up CORS middleware")
-    app.add_middleware(
-        CORSMiddleware,
-        allow_credentials=settings.server.cors.allow_credentials,
-        allow_origins=settings.server.cors.allow_origins,
-        allow_origin_regex=settings.server.cors.allow_origin_regex,
-        allow_methods=settings.server.cors.allow_methods,
-        allow_headers=settings.server.cors.allow_headers,
-    )
-
-
-if settings.ui.enabled:
-    logger.debug("Importing the UI module")
-    from private_gpt.ui.ui import PrivateGptUi
-
-    PrivateGptUi().mount_in_app(app)
+app = create_app(global_injector)
@@ -13,4 +13,6 @@ def _absolute_or_from_project_root(path: str) -> Path:
 models_path: Path = PROJECT_ROOT_PATH / "models"
 models_cache_path: Path = models_path / "cache"
 docs_path: Path = PROJECT_ROOT_PATH / "docs"
-local_data_path: Path = _absolute_or_from_project_root(settings.data.local_data_folder)
+local_data_path: Path = _absolute_or_from_project_root(
+    settings().data.local_data_folder
+)
@@ -1,9 +1,8 @@
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from llama_index.llms import ChatMessage, MessageRole
 from pydantic import BaseModel
 from starlette.responses import StreamingResponse

-from private_gpt.di import root_injector
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.open_ai.openai_models import (
     OpenAICompletion,
@@ -52,7 +51,9 @@ class ChatBody(BaseModel):
     responses={200: {"model": OpenAICompletion}},
     tags=["Contextual Completions"],
 )
-def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
+def chat_completion(
+    request: Request, body: ChatBody
+) -> OpenAICompletion | StreamingResponse:
     """Given a list of messages comprising a conversation, return a response.

     If `use_context` is set to `true`, the model will use context coming
@@ -72,7 +73,7 @@ def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
         "finish_reason":null}]}
     ```
     """
-    service = root_injector.get(ChatService)
+    service = request.state.injector.get(ChatService)
     all_messages = [
         ChatMessage(content=m.content, role=MessageRole(m.role)) for m in body.messages
     ]
@@ -1,9 +1,8 @@
 from typing import Literal

-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from pydantic import BaseModel, Field

-from private_gpt.di import root_injector
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
 from private_gpt.server.utils.auth import authenticated
@@ -25,7 +24,7 @@ class ChunksResponse(BaseModel):


 @chunks_router.post("/chunks", tags=["Context Chunks"])
-def chunks_retrieval(body: ChunksBody) -> ChunksResponse:
+def chunks_retrieval(request: Request, body: ChunksBody) -> ChunksResponse:
     """Given a `text`, returns the most relevant chunks from the ingested documents.

     The returned information can be used to generate prompts that can be
@@ -45,7 +44,7 @@ def chunks_retrieval(body: ChunksBody) -> ChunksResponse:
     `/ingest/list` endpoint. If you want all ingested documents to be used,
     remove `context_filter` altogether.
     """
-    service = root_injector.get(ChunksService)
+    service = request.state.injector.get(ChunksService)
     results = service.retrieve_relevant(
         body.text, body.context_filter, body.limit, body.prev_next_chunks
     )
@@ -1,4 +1,4 @@
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from pydantic import BaseModel
 from starlette.responses import StreamingResponse

@@ -41,7 +41,9 @@ class CompletionsBody(BaseModel):
     responses={200: {"model": OpenAICompletion}},
     tags=["Contextual Completions"],
 )
-def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResponse:
+def prompt_completion(
+    request: Request, body: CompletionsBody
+) -> OpenAICompletion | StreamingResponse:
     """We recommend most users use our Chat completions API.

     Given a prompt, the model will return one predicted completion. If `use_context`
@@ -70,4 +72,4 @@ def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResp
         include_sources=body.include_sources,
         context_filter=body.context_filter,
     )
-    return chat_completion(chat_body)
+    return chat_completion(request, chat_body)
@@ -1,9 +1,8 @@
 from typing import Literal

-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from pydantic import BaseModel

-from private_gpt.di import root_injector
 from private_gpt.server.embeddings.embeddings_service import (
     Embedding,
     EmbeddingsService,
@@ -24,13 +23,13 @@ class EmbeddingsResponse(BaseModel):


 @embeddings_router.post("/embeddings", tags=["Embeddings"])
-def embeddings_generation(body: EmbeddingsBody) -> EmbeddingsResponse:
+def embeddings_generation(request: Request, body: EmbeddingsBody) -> EmbeddingsResponse:
     """Get a vector representation of a given input.

     That vector representation can be easily consumed
     by machine learning models and algorithms.
     """
-    service = root_injector.get(EmbeddingsService)
+    service = request.state.injector.get(EmbeddingsService)
     input_texts = body.input if isinstance(body.input, list) else [body.input]
     embeddings = service.texts_embeddings(input_texts)
     return EmbeddingsResponse(object="list", model="private-gpt", data=embeddings)
@@ -1,9 +1,8 @@
 from typing import Literal

-from fastapi import APIRouter, Depends, HTTPException, UploadFile
+from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
 from pydantic import BaseModel

-from private_gpt.di import root_injector
 from private_gpt.server.ingest.ingest_service import IngestedDoc, IngestService
 from private_gpt.server.utils.auth import authenticated

@@ -17,7 +16,7 @@ class IngestResponse(BaseModel):


 @ingest_router.post("/ingest", tags=["Ingestion"])
-def ingest(file: UploadFile) -> IngestResponse:
+def ingest(request: Request, file: UploadFile) -> IngestResponse:
     """Ingests and processes a file, storing its chunks to be used as context.

     The context obtained from files is later used in
@@ -33,7 +32,7 @@ def ingest(file: UploadFile) -> IngestResponse:
     can be used to filter the context used to create responses in
     `/chat/completions`, `/completions`, and `/chunks` APIs.
     """
-    service = root_injector.get(IngestService)
+    service = request.state.injector.get(IngestService)
     if file.filename is None:
         raise HTTPException(400, "No file name provided")
     ingested_documents = service.ingest(file.filename, file.file.read())
@@ -41,23 +40,23 @@ def ingest(file: UploadFile) -> IngestResponse:


 @ingest_router.get("/ingest/list", tags=["Ingestion"])
-def list_ingested() -> IngestResponse:
+def list_ingested(request: Request) -> IngestResponse:
     """Lists already ingested Documents including their Document ID and metadata.

     Those IDs can be used to filter the context used to create responses
     in `/chat/completions`, `/completions`, and `/chunks` APIs.
     """
-    service = root_injector.get(IngestService)
+    service = request.state.injector.get(IngestService)
     ingested_documents = service.list_ingested()
     return IngestResponse(object="list", model="private-gpt", data=ingested_documents)


 @ingest_router.delete("/ingest/{doc_id}", tags=["Ingestion"])
-def delete_ingested(doc_id: str) -> None:
+def delete_ingested(request: Request, doc_id: str) -> None:
     """Delete the specified ingested Document.

     The `doc_id` can be obtained from the `GET /ingest/list` endpoint.
     The document will be effectively deleted from your storage context.
     """
-    service = root_injector.get(IngestService)
+    service = request.state.injector.get(IngestService)
     service.delete(doc_id)
@@ -38,13 +38,13 @@ logger = logging.getLogger(__name__)

 def _simple_authentication(authorization: Annotated[str, Header()] = "") -> bool:
     """Check if the request is authenticated."""
-    if not secrets.compare_digest(authorization, settings.server.auth.secret):
+    if not secrets.compare_digest(authorization, settings().server.auth.secret):
         # If the "Authorization" header is not the expected one, raise an exception.
         raise NOT_AUTHENTICATED
     return True


-if not settings.server.auth.enabled:
+if not settings().server.auth.enabled:
     logger.debug(
         "Defining a dummy authentication mechanism for fastapi, always authenticating requests"
     )
@@ -62,7 +62,7 @@ else:
         _simple_authentication: Annotated[bool, Depends(_simple_authentication)]
     ) -> bool:
         """Check if the request is authenticated."""
-        assert settings.server.auth.enabled
+        assert settings().server.auth.enabled
         if not _simple_authentication:
             raise NOT_AUTHENTICATED
         return True
@@ -2,7 +2,7 @@ from typing import Literal

 from pydantic import BaseModel, Field

-from private_gpt.settings.settings_loader import load_active_profiles
+from private_gpt.settings.settings_loader import load_active_settings


 class CorsSettings(BaseModel):
@@ -114,4 +114,29 @@ class Settings(BaseModel):
     openai: OpenAISettings


-settings = Settings(**load_active_profiles())
+"""
+This is visible just for DI or testing purposes.
+
+Use dependency injection or `settings()` method instead.
+"""
+unsafe_settings = load_active_settings()
+
+"""
+This is visible just for DI or testing purposes.
+
+Use dependency injection or `settings()` method instead.
+"""
+unsafe_typed_settings = Settings(**unsafe_settings)
+
+
+def settings() -> Settings:
+    """Get the current loaded settings from the DI container.
+
+    This method exists to keep compatibility with the existing code,
+    that require global access to the settings.
+
+    For regular components use dependency injection instead.
+    """
+    from private_gpt.di import global_injector
+
+    return global_injector.get(Settings)
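A compatibility note on the settings() helper introduced above: legacy call sites change the attribute access settings.x into the call settings().x, while the instance itself now lives in the DI container. A minimal, self-contained sketch of the idea (names below are illustrative, not repo code):

from injector import Injector


class DemoSettings:  # illustrative stand-in for the typed Settings model
    env_name = "test"


demo_injector = Injector(auto_bind=True)
demo_injector.binder.bind(DemoSettings, to=DemoSettings())


def settings() -> DemoSettings:
    # compatibility shim: resolves the bound instance from the DI container
    return demo_injector.get(DemoSettings)


assert settings().env_name == "test"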
@@ -2,6 +2,7 @@ import functools
 import logging
 import os
 import sys
+from collections.abc import Iterable
 from pathlib import Path
 from typing import Any

@@ -28,7 +29,11 @@ active_profiles: list[str] = unique_list(
 )


-def load_profile(profile: str) -> dict[str, Any]:
+def merge_settings(settings: Iterable[dict[str, Any]]) -> dict[str, Any]:
+    return functools.reduce(deep_update, settings, {})
+
+
+def load_settings_from_profile(profile: str) -> dict[str, Any]:
     if profile == "default":
         profile_file_name = "settings.yaml"
     else:
@@ -42,9 +47,11 @@ def load_profile(profile: str) -> dict[str, Any]:
     return config


-def load_active_profiles() -> dict[str, Any]:
+def load_active_settings() -> dict[str, Any]:
     """Load active profiles and merge them."""
     logger.info("Starting application with profiles=%s", active_profiles)
-    loaded_profiles = [load_profile(profile) for profile in active_profiles]
-    merged: dict[str, Any] = functools.reduce(deep_update, loaded_profiles, {})
+    loaded_profiles = [
+        load_settings_from_profile(profile) for profile in active_profiles
+    ]
+    merged: dict[str, Any] = merge_settings(loaded_profiles)
     return merged
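The new merge_settings helper is a plain reduce over a deep-merge function, so later dictionaries (for example test overrides) win key by key while untouched keys survive. A hedged, self-contained sketch with a simplified deep_update stand-in (not the project's actual helper):

import functools
from typing import Any


def deep_update(base: dict[str, Any], update: dict[str, Any]) -> dict[str, Any]:
    # simplified stand-in for the project's deep_update helper (mutates base for brevity)
    for key, value in update.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            base[key] = deep_update(base[key], value)
        else:
            base[key] = value
    return base


def merge_settings(settings: list[dict[str, Any]]) -> dict[str, Any]:
    # same shape as the new settings_loader.merge_settings: later dicts win
    return functools.reduce(deep_update, settings, {})


base = {"server": {"env_name": "test", "port": 8001}}
override = {"server": {"env_name": "overriden"}}
assert merge_settings([base, override]) == {"server": {"env_name": "overriden", "port": 8001}}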
@@ -8,10 +8,11 @@ from typing import Any, TextIO
 import gradio as gr  # type: ignore
 from fastapi import FastAPI
 from gradio.themes.utils.colors import slate  # type: ignore
+from injector import inject, singleton
 from llama_index.llms import ChatMessage, ChatResponse, MessageRole
 from pydantic import BaseModel

-from private_gpt.di import root_injector
+from private_gpt.di import global_injector
 from private_gpt.server.chat.chat_service import ChatService, CompletionGen
 from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
 from private_gpt.server.ingest.ingest_service import IngestService
@@ -48,11 +49,18 @@ class Source(BaseModel):
         return curated_sources


+@singleton
 class PrivateGptUi:
-    def __init__(self) -> None:
-        self._ingest_service = root_injector.get(IngestService)
-        self._chat_service = root_injector.get(ChatService)
-        self._chunks_service = root_injector.get(ChunksService)
+    @inject
+    def __init__(
+        self,
+        ingest_service: IngestService,
+        chat_service: ChatService,
+        chunks_service: ChunksService,
+    ) -> None:
+        self._ingest_service = ingest_service
+        self._chat_service = chat_service
+        self._chunks_service = chunks_service

         # Cache the UI blocks
         self._ui_block = None
@@ -198,7 +206,7 @@ class PrivateGptUi:
                     _ = gr.ChatInterface(
                         self._chat,
                         chatbot=gr.Chatbot(
-                            label=f"LLM: {settings.llm.mode}",
+                            label=f"LLM: {settings().llm.mode}",
                             show_copy_button=True,
                             render=False,
                             avatar_images=(
@@ -217,16 +225,15 @@ class PrivateGptUi:
         self._ui_block = self._build_ui_blocks()
         return self._ui_block

-    def mount_in_app(self, app: FastAPI) -> None:
+    def mount_in_app(self, app: FastAPI, path: str) -> None:
         blocks = self.get_ui_blocks()
         blocks.queue()
-        base_path = settings.ui.path
-        logger.info("Mounting the gradio UI, at path=%s", base_path)
-        gr.mount_gradio_app(app, blocks, path=base_path)
+        logger.info("Mounting the gradio UI, at path=%s", path)
+        gr.mount_gradio_app(app, blocks, path=path)


 if __name__ == "__main__":
-    ui = PrivateGptUi()
+    ui = global_injector.get(PrivateGptUi)
     _blocks = ui.get_ui_blocks()
     _blocks.queue()
     _blocks.launch(debug=False, show_api=False)
@@ -2,13 +2,13 @@ import argparse
 import logging
 from pathlib import Path

-from private_gpt.di import root_injector
+from private_gpt.di import global_injector
 from private_gpt.server.ingest.ingest_service import IngestService
 from private_gpt.server.ingest.ingest_watcher import IngestWatcher

 logger = logging.getLogger(__name__)

-ingest_service = root_injector.get(IngestService)
+ingest_service = global_injector.get(IngestService)

 parser = argparse.ArgumentParser(prog="ingest_folder.py")
 parser.add_argument("folder", help="Folder to ingest")
@@ -9,9 +9,9 @@ from private_gpt.settings.settings import settings
 os.makedirs(models_path, exist_ok=True)
 embedding_path = models_path / "embedding"

-print(f"Downloading embedding {settings.local.embedding_hf_model_name}")
+print(f"Downloading embedding {settings().local.embedding_hf_model_name}")
 snapshot_download(
-    repo_id=settings.local.embedding_hf_model_name,
+    repo_id=settings().local.embedding_hf_model_name,
     cache_dir=models_cache_path,
     local_dir=embedding_path,
 )
@@ -20,8 +20,8 @@ print("Downloading models for local execution...")

 # Download LLM and create a symlink to the model file
 hf_hub_download(
-    repo_id=settings.local.llm_hf_repo_id,
-    filename=settings.local.llm_hf_model_file,
+    repo_id=settings().local.llm_hf_repo_id,
+    filename=settings().local.llm_hf_model_file,
     cache_dir=models_cache_path,
     local_dir=models_path,
 )
@@ -5,8 +5,12 @@ server:
   # Dummy secrets used for tests
   secret: "foo bar; dummy secret"

+
+data:
+  local_data_folder: local_data/tests
+
 llm:
-  mode: mock
+  mode: mock

 ui:
   enabled: false
tests/fixtures/fast_api_test_client.py (15 lines changed, vendored)
@@ -1,15 +1,14 @@
 import pytest
-from fastapi import FastAPI
 from fastapi.testclient import TestClient

-from private_gpt.main import app
+from private_gpt.launcher import create_app
+from tests.fixtures.mock_injector import MockInjector


 @pytest.fixture()
-def current_test_app() -> FastAPI:
-    return app
-
-
-@pytest.fixture()
-def test_client() -> TestClient:
-    return TestClient(app)
+def test_client(request: pytest.FixtureRequest, injector: MockInjector) -> TestClient:
+    if request is not None and hasattr(request, "param"):
+        injector.bind_settings(request.param or {})
+
+    app_under_test = create_app(injector.test_injector)
+    return TestClient(app_under_test)
tests/fixtures/mock_injector.py (9 lines changed, vendored)
@@ -1,10 +1,13 @@
 from collections.abc import Callable
+from typing import Any
 from unittest.mock import MagicMock

 import pytest
 from injector import Provider, ScopeDecorator, singleton

 from private_gpt.di import create_application_injector
+from private_gpt.settings.settings import Settings, unsafe_settings
+from private_gpt.settings.settings_loader import merge_settings
 from private_gpt.utils.typing import T

@@ -24,6 +27,12 @@ class MockInjector:
         self.test_injector.binder.bind(interface, to=mock, scope=scope)
         return mock  # type: ignore

+    def bind_settings(self, settings: dict[str, Any]) -> Settings:
+        merged = merge_settings([unsafe_settings, settings])
+        new_settings = Settings(**merged)
+        self.test_injector.binder.bind(Settings, new_settings)
+        return new_settings
+
     def get(self, interface: type[T]) -> T:
         return self.test_injector.get(interface)
@@ -8,7 +8,7 @@ NOTE: We are not testing the switch based on the config in
 from typing import Annotated

 import pytest
-from fastapi import Depends, FastAPI
+from fastapi import Depends
 from fastapi.testclient import TestClient

 from private_gpt.server.utils.auth import (
@@ -29,15 +29,16 @@ def _copy_simple_authenticated(


 @pytest.fixture(autouse=True)
-def _patch_authenticated_dependency(current_test_app: FastAPI):
+def _patch_authenticated_dependency(test_client: TestClient):
     # Patch the server to use simple authentication
-    current_test_app.dependency_overrides[authenticated] = _copy_simple_authenticated
+
+    test_client.app.dependency_overrides[authenticated] = _copy_simple_authenticated

     # Call the actual test
     yield

     # Remove the patch for other tests
-    current_test_app.dependency_overrides = {}
+    test_client.app.dependency_overrides = {}


 def test_default_auth_working_when_enabled_401(test_client: TestClient) -> None:
@@ -50,6 +51,6 @@ def test_default_auth_working_when_enabled_200(test_client: TestClient) -> None:
     assert response_fail.status_code == 401

     response_success = test_client.get(
-        "/v1/ingest/list", headers={"Authorization": settings.server.auth.secret}
+        "/v1/ingest/list", headers={"Authorization": settings().server.auth.secret}
     )
     assert response_success.status_code == 200
@@ -1,5 +1,12 @@
-from private_gpt.settings.settings import settings
+from private_gpt.settings.settings import Settings, settings
+from tests.fixtures.mock_injector import MockInjector


 def test_settings_are_loaded_and_merged() -> None:
-    assert settings.server.env_name == "test"
+    assert settings().server.env_name == "test"
+
+
+def test_settings_can_be_overriden(injector: MockInjector) -> None:
+    injector.bind_settings({"server": {"env_name": "overriden"}})
+    mocked_settings = injector.get(Settings)
+    assert mocked_settings.server.env_name == "overriden"
tests/ui/test_ui.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+import pytest
+from fastapi.testclient import TestClient
+
+
+@pytest.mark.parametrize(
+    "test_client", [{"ui": {"enabled": True, "path": "/ui"}}], indirect=True
+)
+def test_ui_starts_in_the_given_endpoint(test_client: TestClient) -> None:
+    response = test_client.get("/ui")
+    assert response.status_code == 200