Mirror of https://github.com/imartinez/privateGPT.git (synced 2025-09-22 11:37:18 +00:00)
fix: Remove global state (#1216)
* Remove all global settings state
* chore: remove autogenerated class
* chore: cleanup
* chore: merge conflicts
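The commit replaces the import-time `root_injector` global with an injector reached through `request.state.injector` inside every route handler. How the injector ends up on the request is not part of this excerpt; the following is a minimal sketch of one way to wire it up, assuming the `injector` package and a FastAPI HTTP middleware (the `create_app` name and the middleware approach are illustrative, not taken from this commit).

```python
from fastapi import FastAPI, Request
from injector import Injector


def create_app() -> FastAPI:
    app = FastAPI()
    injector = Injector()  # app-scoped container, built here instead of at import time

    @app.middleware("http")
    async def bind_injector_to_request(request: Request, call_next):
        # Expose the app's injector on the per-request state so route handlers
        # can call request.state.injector.get(SomeService).
        request.state.injector = injector
        return await call_next(request)

    return app
```

With something along these lines in place, handlers resolve services per request (`request.state.injector.get(ChatService)`), and tests can build an app around a different injector without mutating module-level state.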
@@ -1,9 +1,8 @@
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from llama_index.llms import ChatMessage, MessageRole
 from pydantic import BaseModel
 from starlette.responses import StreamingResponse
 
-from private_gpt.di import root_injector
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.open_ai.openai_models import (
     OpenAICompletion,
@@ -52,7 +51,9 @@ class ChatBody(BaseModel):
     responses={200: {"model": OpenAICompletion}},
     tags=["Contextual Completions"],
 )
-def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
+def chat_completion(
+    request: Request, body: ChatBody
+) -> OpenAICompletion | StreamingResponse:
     """Given a list of messages comprising a conversation, return a response.
 
     If `use_context` is set to `true`, the model will use context coming
@@ -72,7 +73,7 @@ def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
     "finish_reason":null}]}
     ```
     """
-    service = root_injector.get(ChatService)
+    service = request.state.injector.get(ChatService)
     all_messages = [
         ChatMessage(content=m.content, role=MessageRole(m.role)) for m in body.messages
     ]
@@ -1,9 +1,8 @@
 from typing import Literal
 
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from pydantic import BaseModel, Field
 
-from private_gpt.di import root_injector
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
 from private_gpt.server.utils.auth import authenticated
@@ -25,7 +24,7 @@ class ChunksResponse(BaseModel):
 
 
 @chunks_router.post("/chunks", tags=["Context Chunks"])
-def chunks_retrieval(body: ChunksBody) -> ChunksResponse:
+def chunks_retrieval(request: Request, body: ChunksBody) -> ChunksResponse:
     """Given a `text`, returns the most relevant chunks from the ingested documents.
 
     The returned information can be used to generate prompts that can be
@@ -45,7 +44,7 @@ def chunks_retrieval(body: ChunksBody) -> ChunksResponse:
     `/ingest/list` endpoint. If you want all ingested documents to be used,
     remove `context_filter` altogether.
     """
-    service = root_injector.get(ChunksService)
+    service = request.state.injector.get(ChunksService)
     results = service.retrieve_relevant(
         body.text, body.context_filter, body.limit, body.prev_next_chunks
     )
@@ -1,4 +1,4 @@
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from pydantic import BaseModel
 from starlette.responses import StreamingResponse
 
@@ -41,7 +41,9 @@ class CompletionsBody(BaseModel):
     responses={200: {"model": OpenAICompletion}},
     tags=["Contextual Completions"],
 )
-def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResponse:
+def prompt_completion(
+    request: Request, body: CompletionsBody
+) -> OpenAICompletion | StreamingResponse:
     """We recommend most users use our Chat completions API.
 
     Given a prompt, the model will return one predicted completion. If `use_context`
@@ -70,4 +72,4 @@ def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResp
         include_sources=body.include_sources,
         context_filter=body.context_filter,
     )
-    return chat_completion(chat_body)
+    return chat_completion(request, chat_body)
@@ -1,9 +1,8 @@
 from typing import Literal
 
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, Request
 from pydantic import BaseModel
 
-from private_gpt.di import root_injector
 from private_gpt.server.embeddings.embeddings_service import (
     Embedding,
     EmbeddingsService,
@@ -24,13 +23,13 @@ class EmbeddingsResponse(BaseModel):
 
 
 @embeddings_router.post("/embeddings", tags=["Embeddings"])
-def embeddings_generation(body: EmbeddingsBody) -> EmbeddingsResponse:
+def embeddings_generation(request: Request, body: EmbeddingsBody) -> EmbeddingsResponse:
     """Get a vector representation of a given input.
 
     That vector representation can be easily consumed
     by machine learning models and algorithms.
     """
-    service = root_injector.get(EmbeddingsService)
+    service = request.state.injector.get(EmbeddingsService)
     input_texts = body.input if isinstance(body.input, list) else [body.input]
     embeddings = service.texts_embeddings(input_texts)
     return EmbeddingsResponse(object="list", model="private-gpt", data=embeddings)
@@ -1,9 +1,8 @@
 from typing import Literal
 
-from fastapi import APIRouter, Depends, HTTPException, UploadFile
+from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
 from pydantic import BaseModel
 
-from private_gpt.di import root_injector
 from private_gpt.server.ingest.ingest_service import IngestedDoc, IngestService
 from private_gpt.server.utils.auth import authenticated
 
@@ -17,7 +16,7 @@ class IngestResponse(BaseModel):
 
 
 @ingest_router.post("/ingest", tags=["Ingestion"])
-def ingest(file: UploadFile) -> IngestResponse:
+def ingest(request: Request, file: UploadFile) -> IngestResponse:
     """Ingests and processes a file, storing its chunks to be used as context.
 
     The context obtained from files is later used in
@@ -33,7 +32,7 @@ def ingest(file: UploadFile) -> IngestResponse:
     can be used to filter the context used to create responses in
     `/chat/completions`, `/completions`, and `/chunks` APIs.
     """
-    service = root_injector.get(IngestService)
+    service = request.state.injector.get(IngestService)
     if file.filename is None:
         raise HTTPException(400, "No file name provided")
     ingested_documents = service.ingest(file.filename, file.file.read())
@@ -41,23 +40,23 @@ def ingest(file: UploadFile) -> IngestResponse:
 
 
 @ingest_router.get("/ingest/list", tags=["Ingestion"])
-def list_ingested() -> IngestResponse:
+def list_ingested(request: Request) -> IngestResponse:
     """Lists already ingested Documents including their Document ID and metadata.
 
     Those IDs can be used to filter the context used to create responses
     in `/chat/completions`, `/completions`, and `/chunks` APIs.
     """
-    service = root_injector.get(IngestService)
+    service = request.state.injector.get(IngestService)
     ingested_documents = service.list_ingested()
     return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
 
 
 @ingest_router.delete("/ingest/{doc_id}", tags=["Ingestion"])
-def delete_ingested(doc_id: str) -> None:
+def delete_ingested(request: Request, doc_id: str) -> None:
     """Delete the specified ingested Document.
 
     The `doc_id` can be obtained from the `GET /ingest/list` endpoint.
     The document will be effectively deleted from your storage context.
     """
-    service = root_injector.get(IngestService)
+    service = request.state.injector.get(IngestService)
     service.delete(doc_id)
@@ -38,13 +38,13 @@ logger = logging.getLogger(__name__)
 
 def _simple_authentication(authorization: Annotated[str, Header()] = "") -> bool:
     """Check if the request is authenticated."""
-    if not secrets.compare_digest(authorization, settings.server.auth.secret):
+    if not secrets.compare_digest(authorization, settings().server.auth.secret):
         # If the "Authorization" header is not the expected one, raise an exception.
         raise NOT_AUTHENTICATED
     return True
 
 
-if not settings.server.auth.enabled:
+if not settings().server.auth.enabled:
     logger.debug(
         "Defining a dummy authentication mechanism for fastapi, always authenticating requests"
     )
@@ -62,7 +62,7 @@ else:
         _simple_authentication: Annotated[bool, Depends(_simple_authentication)]
     ) -> bool:
         """Check if the request is authenticated."""
-        assert settings.server.auth.enabled
+        assert settings().server.auth.enabled
         if not _simple_authentication:
             raise NOT_AUTHENTICATED
         return True
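A related change in the same commit: the auth module now calls `settings()` instead of reading a module-level `settings` object, so configuration is obtained from an accessor at call time rather than captured as a global at import time. The loading mechanism itself is not part of this diff; below is a rough sketch of what such an accessor might look like. The `lru_cache` approach and the model classes are assumptions for illustration; only the `server.auth.enabled` / `server.auth.secret` fields come from the diff.

```python
from functools import lru_cache

from pydantic import BaseModel


class AuthSettings(BaseModel):
    enabled: bool = False
    secret: str = ""


class ServerSettings(BaseModel):
    auth: AuthSettings = AuthSettings()


class Settings(BaseModel):
    server: ServerSettings = ServerSettings()


@lru_cache
def settings() -> Settings:
    # Build (or load) the configuration on first use; every caller goes
    # through this function, e.g. settings().server.auth.enabled.
    return Settings()
```

The point of the indirection is the same as for the injector: nothing is frozen at import time, so tests or alternative entry points can swap the configuration without patching globals.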