fix: Remove global state (#1216)

* Remove all global settings state

* chore: remove autogenerated class

* chore: cleanup

* chore: merge conflicts
Author: Pablo Orgaz
Committed: 2023-11-12 22:20:36 +01:00 (via GitHub)
Parent: f394ca61bb
Commit: 022bd718e3
24 changed files with 286 additions and 190 deletions

View File

@@ -1,9 +1,8 @@
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from llama_index.llms import ChatMessage, MessageRole
 from pydantic import BaseModel
 from starlette.responses import StreamingResponse
 
-from private_gpt.di import root_injector
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.open_ai.openai_models import (
     OpenAICompletion,
@@ -52,7 +51,9 @@ class ChatBody(BaseModel):
     responses={200: {"model": OpenAICompletion}},
     tags=["Contextual Completions"],
 )
-def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
+def chat_completion(
+    request: Request, body: ChatBody
+) -> OpenAICompletion | StreamingResponse:
     """Given a list of messages comprising a conversation, return a response.
 
     If `use_context` is set to `true`, the model will use context coming
@@ -72,7 +73,7 @@ def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
     "finish_reason":null}]}
     ```
     """
-    service = root_injector.get(ChatService)
+    service = request.state.injector.get(ChatService)
     all_messages = [
         ChatMessage(content=m.content, role=MessageRole(m.role)) for m in body.messages
     ]
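The handler can only read `request.state.injector` if something placed it there first. That wiring is not part of this diff; the following is a minimal sketch of one way to do it, assuming an `injector.Injector` instance and an app factory (`create_app` and `bind_injector_to_request` are illustrative names, not taken from this commit):

```python
from fastapi import Depends, FastAPI, Request
from injector import Injector


def create_app(root_injector: Injector) -> FastAPI:
    # Declared as an app-level dependency, this runs before every route
    # handler, so each handler can resolve services from
    # request.state.injector instead of importing a module-level global.
    async def bind_injector_to_request(request: Request) -> None:
        request.state.injector = root_injector

    return FastAPI(dependencies=[Depends(bind_injector_to_request)])
```

With that in place, a test can build the app around a different injector instead of monkeypatching a global.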

View File

@@ -1,9 +1,8 @@
 from typing import Literal
 
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from pydantic import BaseModel, Field
 
-from private_gpt.di import root_injector
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
 from private_gpt.server.utils.auth import authenticated
@@ -25,7 +24,7 @@ class ChunksResponse(BaseModel):
 
 
 @chunks_router.post("/chunks", tags=["Context Chunks"])
-def chunks_retrieval(body: ChunksBody) -> ChunksResponse:
+def chunks_retrieval(request: Request, body: ChunksBody) -> ChunksResponse:
     """Given a `text`, returns the most relevant chunks from the ingested documents.
 
     The returned information can be used to generate prompts that can be
@@ -45,7 +44,7 @@ def chunks_retrieval(body: ChunksBody) -> ChunksResponse:
     `/ingest/list` endpoint. If you want all ingested documents to be used,
     remove `context_filter` altogether.
     """
-    service = root_injector.get(ChunksService)
+    service = request.state.injector.get(ChunksService)
     results = service.retrieve_relevant(
         body.text, body.context_filter, body.limit, body.prev_next_chunks
     )

View File

@@ -1,4 +1,4 @@
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from pydantic import BaseModel
 from starlette.responses import StreamingResponse
 
@@ -41,7 +41,9 @@ class CompletionsBody(BaseModel):
     responses={200: {"model": OpenAICompletion}},
     tags=["Contextual Completions"],
 )
-def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResponse:
+def prompt_completion(
+    request: Request, body: CompletionsBody
+) -> OpenAICompletion | StreamingResponse:
     """We recommend most users use our Chat completions API.
 
     Given a prompt, the model will return one predicted completion. If `use_context`
@@ -70,4 +72,4 @@ def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResponse:
         include_sources=body.include_sources,
         context_filter=body.context_filter,
     )
-    return chat_completion(chat_body)
+    return chat_completion(request, chat_body)
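Worth noting: `prompt_completion` delegates to `chat_completion` with a plain function call, not an HTTP round trip, so the `Request` has to be threaded through explicitly. Outside FastAPI's dependency injection there is no implicit access to `request.state`.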

View File

@@ -1,9 +1,8 @@
 from typing import Literal
 
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from pydantic import BaseModel
 
-from private_gpt.di import root_injector
 from private_gpt.server.embeddings.embeddings_service import (
     Embedding,
     EmbeddingsService,
@@ -24,13 +23,13 @@ class EmbeddingsResponse(BaseModel):
 
 
 @embeddings_router.post("/embeddings", tags=["Embeddings"])
-def embeddings_generation(body: EmbeddingsBody) -> EmbeddingsResponse:
+def embeddings_generation(request: Request, body: EmbeddingsBody) -> EmbeddingsResponse:
     """Get a vector representation of a given input.
 
     That vector representation can be easily consumed
     by machine learning models and algorithms.
     """
-    service = root_injector.get(EmbeddingsService)
+    service = request.state.injector.get(EmbeddingsService)
     input_texts = body.input if isinstance(body.input, list) else [body.input]
     embeddings = service.texts_embeddings(input_texts)
     return EmbeddingsResponse(object="list", model="private-gpt", data=embeddings)

View File

@@ -1,9 +1,8 @@
 from typing import Literal
 
-from fastapi import APIRouter, Depends, HTTPException, UploadFile
+from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
 from pydantic import BaseModel
 
-from private_gpt.di import root_injector
 from private_gpt.server.ingest.ingest_service import IngestedDoc, IngestService
 from private_gpt.server.utils.auth import authenticated
 
@@ -17,7 +16,7 @@ class IngestResponse(BaseModel):
 
 
 @ingest_router.post("/ingest", tags=["Ingestion"])
-def ingest(file: UploadFile) -> IngestResponse:
+def ingest(request: Request, file: UploadFile) -> IngestResponse:
     """Ingests and processes a file, storing its chunks to be used as context.
 
     The context obtained from files is later used in
@@ -33,7 +32,7 @@ def ingest(file: UploadFile) -> IngestResponse:
     can be used to filter the context used to create responses in
     `/chat/completions`, `/completions`, and `/chunks` APIs.
     """
-    service = root_injector.get(IngestService)
+    service = request.state.injector.get(IngestService)
    if file.filename is None:
        raise HTTPException(400, "No file name provided")
    ingested_documents = service.ingest(file.filename, file.file.read())
@@ -41,23 +40,23 @@ def ingest(file: UploadFile) -> IngestResponse:
 
 
 @ingest_router.get("/ingest/list", tags=["Ingestion"])
-def list_ingested() -> IngestResponse:
+def list_ingested(request: Request) -> IngestResponse:
     """Lists already ingested Documents including their Document ID and metadata.
 
     Those IDs can be used to filter the context used to create responses
     in `/chat/completions`, `/completions`, and `/chunks` APIs.
     """
-    service = root_injector.get(IngestService)
+    service = request.state.injector.get(IngestService)
     ingested_documents = service.list_ingested()
     return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
 
 
 @ingest_router.delete("/ingest/{doc_id}", tags=["Ingestion"])
-def delete_ingested(doc_id: str) -> None:
+def delete_ingested(request: Request, doc_id: str) -> None:
     """Delete the specified ingested Document.
 
     The `doc_id` can be obtained from the `GET /ingest/list` endpoint.
     The document will be effectively deleted from your storage context.
     """
-    service = root_injector.get(IngestService)
+    service = request.state.injector.get(IngestService)
     service.delete(doc_id)
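This is the payoff of dropping the global: a test can bind a fake service into its own injector and exercise a route without real models or storage. A sketch, under the assumption that a `create_app(injector)` factory like the one sketched earlier exists (the fake class, import path of `create_app`, route path, and auth setup are all illustrative, not taken from this commit):

```python
from fastapi.testclient import TestClient
from injector import Injector

from private_gpt.launcher import create_app  # module path assumed
from private_gpt.server.ingest.ingest_service import IngestService


class FakeIngestService:
    """Stands in for the real service; no models or storage required."""

    def list_ingested(self) -> list:
        return []


def test_list_ingested_uses_request_scoped_injector() -> None:
    # Bind the service interface to a prebuilt fake instance.
    test_injector = Injector(
        [lambda binder: binder.bind(IngestService, to=FakeIngestService())]
    )
    client = TestClient(create_app(test_injector))
    # Route path and auth config assumed; adjust to the app's router setup.
    response = client.get("/ingest/list")
    assert response.status_code == 200
```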

View File

@@ -38,13 +38,13 @@ logger = logging.getLogger(__name__)
 
 def _simple_authentication(authorization: Annotated[str, Header()] = "") -> bool:
     """Check if the request is authenticated."""
-    if not secrets.compare_digest(authorization, settings.server.auth.secret):
+    if not secrets.compare_digest(authorization, settings().server.auth.secret):
         # If the "Authorization" header is not the expected one, raise an exception.
         raise NOT_AUTHENTICATED
     return True
 
 
-if not settings.server.auth.enabled:
+if not settings().server.auth.enabled:
     logger.debug(
         "Defining a dummy authentication mechanism for fastapi, always authenticating requests"
     )
@@ -62,7 +62,7 @@ else:
         _simple_authentication: Annotated[bool, Depends(_simple_authentication)]
     ) -> bool:
         """Check if the request is authenticated."""
-        assert settings.server.auth.enabled
+        assert settings().server.auth.enabled
         if not _simple_authentication:
             raise NOT_AUTHENTICATED
         return True
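The other half of the change: `settings` stops being an eagerly created module-level object and becomes a `settings()` call. Only the call sites appear in this diff; below is a sketch of what such an accessor might look like, with a stand-in model (the real `Settings` class and its loading logic live elsewhere in the project):

```python
from injector import Binder, Injector, singleton
from pydantic import BaseModel


class Settings(BaseModel):
    """Stand-in for the project's real Settings model."""


def _bind_settings(binder: Binder) -> None:
    # Load once and register in the DI container as a singleton, instead
    # of constructing a module-level object at import time.
    binder.bind(Settings, to=Settings(), scope=singleton)


global_injector = Injector([_bind_settings])


def settings() -> Settings:
    """Resolve the currently bound Settings lazily, at call time."""
    return global_injector.get(Settings)
```

Call sites such as `settings().server.auth.enabled` then resolve whatever binding is active at the moment they execute, rather than sharing one object frozen at import time.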