Added home.py file for /chat route

Saurab-Shrestha 2024-01-25 17:37:00 +05:45
parent bcc3f03e25
commit 55565fd3a7
20 changed files with 295 additions and 179 deletions

Celery.pdf Normal file
View File

View File

@ -1,3 +1,4 @@
from pathlib import Path
PROJECT_ROOT_PATH: Path = Path(__file__).parents[1]
UPLOAD_DIR = r"F:\LLM\privateGPT\private_gpt\uploads"
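The upload directory is pinned to a machine-specific Windows path. A portable alternative (a sketch only — deriving the directory from PROJECT_ROOT_PATH is an assumption, not what this commit does) could look like:

from pathlib import Path

PROJECT_ROOT_PATH: Path = Path(__file__).parents[1]
# Assumption: keep uploads inside the project tree instead of on a fixed drive letter.
UPLOAD_DIR: Path = PROJECT_ROOT_PATH / "private_gpt" / "uploads"
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)  # make sure the folder exists at import time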

private_gpt/home.py Normal file
View File

@ -0,0 +1,218 @@
"""This file should be imported only and only if you want to run the UI locally."""
import itertools
import logging
from collections.abc import Iterable
from pathlib import Path
from typing import Any, List
from fastapi import APIRouter, Body, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from gradio.themes.utils.colors import slate # type: ignore
from injector import inject, singleton
from llama_index.llms import ChatMessage, ChatResponse, MessageRole
from pydantic import BaseModel
from private_gpt.constants import PROJECT_ROOT_PATH
from private_gpt.di import global_injector
from private_gpt.server.chat.chat_service import ChatService, CompletionGen
from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
from private_gpt.server.ingest.ingest_service import IngestService
from private_gpt.settings.settings import settings
from private_gpt.ui.images import logo_svg
from private_gpt.ui.common import Source
logger = logging.getLogger(__name__)
THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH)
# Should be "private_gpt/ui/avatar-bot.ico"
AVATAR_BOT = THIS_DIRECTORY_RELATIVE / "avatar-bot.ico"
UI_TAB_TITLE = "My Private GPT"
SOURCES_SEPARATOR = "\n\n Sources: \n"
MODES = ["Query Docs", "Search in Docs", "LLM Chat"]
home_router = APIRouter(prefix="/v1")
@singleton
class Home:
@inject
def __init__(
self,
ingest_service: IngestService,
chat_service: ChatService,
chunks_service: ChunksService,
) -> None:
self._ingest_service = ingest_service
self._chat_service = chat_service
self._chunks_service = chunks_service
# Initialize system prompt based on default mode
self.mode = MODES[0]
self._system_prompt = self._get_default_system_prompt(self.mode)
def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
full_response: str = ""
stream = completion_gen.response
for delta in stream:
if isinstance(delta, str):
full_response += str(delta)
elif isinstance(delta, ChatResponse):
full_response += delta.delta or ""
yield full_response
if completion_gen.sources:
full_response += SOURCES_SEPARATOR
cur_sources = Source.curate_sources(completion_gen.sources)
sources_text = "\n\n\n".join(
f'<a href="{source.page_link}" target="_blank" rel="noopener noreferrer">{index}. {source.file} (page {source.page})</a>'
for index, source in enumerate(cur_sources, start=1)
)
full_response += sources_text
print(full_response)
yield full_response
def build_history() -> list[ChatMessage]:
history_messages: list[ChatMessage] = list(
itertools.chain(
*[
[
ChatMessage(
content=interaction[0], role=MessageRole.USER),
ChatMessage(
# Strip the Sources information from the history content
content=interaction[1].split(
SOURCES_SEPARATOR)[0],
role=MessageRole.ASSISTANT,
),
]
for interaction in history
]
)
)
# max 20 messages to try to avoid context overflow
return history_messages[:20]
new_message = ChatMessage(content=message, role=MessageRole.USER)
all_messages = [*build_history(), new_message]
# If a system prompt is set, add it as a system message
if self._system_prompt:
all_messages.insert(
0,
ChatMessage(
content=self._system_prompt,
role=MessageRole.SYSTEM,
),
)
match mode:
case "Query Docs":
query_stream = self._chat_service.stream_chat(
messages=all_messages,
use_context=True,
)
yield from yield_deltas(query_stream)
case "LLM Chat":
llm_stream = self._chat_service.stream_chat(
messages=all_messages,
use_context=False,
)
yield from yield_deltas(llm_stream)
case "Search in Docs":
response = self._chunks_service.retrieve_relevant(
text=message, limit=4, prev_next_chunks=0
)
sources = Source.curate_sources(response)
yield "\n\n\n".join(
f"{index}. **{source.file} (page {source.page})**\n"
f" (link: [{source.page_link}]({source.page_link}))\n{source.text}"
for index, source in enumerate(sources, start=1)
)
# On initialization and on mode change, this function sets the system prompt
# to the default prompt based on the mode (and user settings).
@staticmethod
def _get_default_system_prompt(mode: str) -> str:
p = ""
match mode:
# For query chat mode, obtain default system prompt from settings
case "Query Docs":
p = settings().ui.default_query_system_prompt
# For chat mode, obtain default system prompt from settings
case "LLM Chat":
p = settings().ui.default_chat_system_prompt
# For any other mode, clear the system prompt
case _:
p = ""
return p
def _set_system_prompt(self, system_prompt_input: str) -> None:
logger.info(f"Setting system prompt to: {system_prompt_input}")
self._system_prompt = system_prompt_input
def _set_current_mode(self, mode: str) -> Any:
self.mode = mode
self._set_system_prompt(self._get_default_system_prompt(mode))
def _list_ingested_files(self) -> list[list[str]]:
files = set()
for ingested_document in self._ingest_service.list_ingested():
if ingested_document.doc_metadata is None:
# Skipping documents without metadata
continue
file_name = ingested_document.doc_metadata.get(
"file_name", "[FILE NAME MISSING]"
)
files.add(file_name)
return [[row] for row in files]
def _upload_file(self, files: list[str]) -> None:
logger.debug("Loading count=%s files", len(files))
paths = [Path(file) for file in files]
self._ingest_service.bulk_ingest(
[(str(path.name), path) for path in paths])
DEFAULT_MODE = MODES[0]
@home_router.post("/chat")
async def chat_endpoint(request: Request, message: str = Body(...), mode: str = Body(DEFAULT_MODE)):
home_instance = request.state.injector.get(Home)
history = []
print("The message is: ", message)
print("The mode is: ", mode)
responses = home_instance._chat(message, history, mode)
return StreamingResponse(content=responses, media_type='text/event-stream')
# text = (
# "To run the Celery worker based on the provided context, you can follow these steps: "
# "1. First, make sure you have Celery installed in your project."
# "2. Create a Celery instance and configure it with your settings."
# "3. Define Celery tasks that will be executed by the worker."
# "4. Start the Celery worker using the configured Celery instance."
# "5. Your Celery worker is now running and ready to process tasks."
# )
# import time
# async def generate_stream():
# for i in range(len(text)):
# yield text[:i+1] # Sending part of the text in each iteration
# time.sleep(0.1) # Simulating some processing time
# Return the responses as a StreamingResponse
# return StreamingResponse(content=responses, media_type="application/json")
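The new endpoint streams cumulative snapshots of the answer (each chunk is the full text generated so far). A quick client sketch using the requests library — the base URL and port 8001 are assumptions taken from the settings file changed later in this commit:

import requests

# Assumed dev server address; adjust to your deployment.
url = "http://localhost:8001/v1/chat"
payload = {"message": "How do I run the Celery worker?", "mode": "Query Docs"}

with requests.post(url, json=payload, stream=True) as response:
    response.raise_for_status()
    # Each streamed chunk is the full response text so far, not just the new delta.
    for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)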

View File

@ -14,7 +14,7 @@ from private_gpt.server.ingest.ingest_router import ingest_router
from private_gpt.users.api.v1.api import api_router
from private_gpt.settings.settings import Settings
from private_gpt.home import home_router
logger = logging.getLogger(__name__)
@ -25,28 +25,25 @@ def create_app(root_injector: Injector) -> FastAPI:
request.state.injector = root_injector
app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
# app.include_router(completions_router)
# app.include_router(chat_router)
# app.include_router(chunks_router)
# app.include_router(ingest_router)
# app.include_router(embeddings_router)
# app.include_router(health_router)
app.include_router(completions_router)
app.include_router(chat_router)
app.include_router(chunks_router)
app.include_router(ingest_router)
app.include_router(embeddings_router)
app.include_router(health_router)
app.include_router(api_router)
app.include_router(home_router)
settings = root_injector.get(Settings)
if settings.server.cors.enabled:
logger.debug("Setting up CORS middleware")
app.add_middleware(
CORSMiddleware,
allow_credentials=settings.server.cors.allow_credentials,
allow_origins=settings.server.cors.allow_origins,
allow_origin_regex=settings.server.cors.allow_origin_regex,
allow_methods=settings.server.cors.allow_methods,
allow_headers=settings.server.cors.allow_headers,
)
# if settings.server.cors.enabled:
logger.debug("Setting up CORS middleware")
app.add_middleware(
CORSMiddleware,
allow_credentials=True,
allow_origins=["http://localhost:5173", "http://localhost:8001"],
allow_methods=["DELETE", "GET", "POST", "PUT", "OPTIONS"],
allow_headers=["*"],
)
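To sanity-check that the hardcoded origins actually admit the dev frontend, a browser-style preflight can be simulated with requests (illustration only; the origin and port below are the values hardcoded above):

import requests

preflight = requests.options(
    "http://localhost:8001/v1/chat",
    headers={
        "Origin": "http://localhost:5173",
        "Access-Control-Request-Method": "POST",
        "Access-Control-Request-Headers": "content-type",
    },
)
print(preflight.status_code)
print(preflight.headers.get("access-control-allow-origin"))  # expect http://localhost:5173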
# if settings.ui.enabled:
# logger.debug("Importing the UI module")

View File

@ -106,3 +106,4 @@ def chat_completion(
return to_openai_response(
completion.response, completion.sources if body.include_sources else None
)

View File

@ -1,15 +1,17 @@
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, File
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from private_gpt.server.ingest.ingest_service import IngestService
from private_gpt.server.ingest.model import IngestedDoc
from private_gpt.server.utils.auth import authenticated
from private_gpt.constants import UPLOAD_DIR
from pathlib import Path
ingest_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
class IngestTextBody(BaseModel):
file_name: str = Field(examples=["Avatar: The Last Airbender"])
text: str = Field(
@ -38,7 +40,7 @@ def ingest(request: Request, file: UploadFile) -> IngestResponse:
@ingest_router.post("/ingest/file", tags=["Ingestion"])
def ingest_file(request: Request, file: UploadFile) -> IngestResponse:
def ingest_file(request: Request, file: UploadFile = File(...)) -> IngestResponse:
"""Ingests and processes a file, storing its chunks to be used as context.
The context obtained from files is later used in
@ -54,12 +56,22 @@ def ingest_file(request: Request, file: UploadFile) -> IngestResponse:
can be used to filter the context used to create responses in
`/chat/completions`, `/completions`, and `/chunks` APIs.
"""
# try:
service = request.state.injector.get(IngestService)
if file.filename is None:
raise HTTPException(400, "No file name provided")
ingested_documents = service.ingest_bin_data(file.filename, file.file)
upload_path = Path(f"{UPLOAD_DIR}/{file.filename}")
try:
with open(upload_path, "wb") as f:
f.write(file.file.read())
with open(upload_path, "rb") as f:
ingested_documents = service.ingest_bin_data(file.filename, f)
except Exception as e:
raise HTTPException(500, f"There was an error uploading the file(s): {e}") from e
finally:
file.file.close()
return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
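For reference, a sketch of uploading a document through this route with requests. The file name is a placeholder, and the commented auth header stands in for whatever scheme the authenticated dependency enforces (not shown in this diff):

import requests

files = {"file": ("Celery.pdf", open("Celery.pdf", "rb"), "application/pdf")}
response = requests.post(
    "http://localhost:8001/v1/ingest/file",
    files=files,
    # headers={"Authorization": "Bearer <token>"},  # placeholder, if auth is enabled
)
response.raise_for_status()
print(response.json()["data"])  # list of ingested documents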
@ingest_router.post("/ingest/text", tags=["Ingestion"])
def ingest_text(request: Request, body: IngestTextBody) -> IngestResponse:
@ -102,3 +114,14 @@ def delete_ingested(request: Request, doc_id: str) -> None:
"""
service = request.state.injector.get(IngestService)
service.delete(doc_id)
@ingest_router.delete("/ingest", tags=["Ingestion"])
def delete_ingested(request: Request, doc_id: str) -> None:
"""Delete the specified ingested Document.
The `doc_id` can be obtained from the `GET /ingest/list` endpoint.
The document will be effectively deleted from your storage context.
"""
service = request.state.injector.get(IngestService)
service.delete(doc_id)
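The added delete route takes doc_id as a query parameter; a minimal call sketch (the doc_id value is a placeholder obtained from GET /v1/ingest/list):

import requests

doc_id = "replace-with-a-real-doc-id"  # placeholder
response = requests.delete(
    "http://localhost:8001/v1/ingest",
    params={"doc_id": doc_id},
    # headers={"Authorization": "Bearer <token>"},  # placeholder, if auth is enabled
)
response.raise_for_status()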

View File

@ -130,3 +130,4 @@ class IngestService:
"Deleting the ingested document=%s in the doc and index store", doc_id
)
self.ingest_component.delete(doc_id)

View File

@ -1,8 +1,9 @@
from typing import Any, Literal
import os
from llama_index import Document
from pydantic import BaseModel, Field
from private_gpt.constants import UPLOAD_DIR
from pathlib import Path
class IngestedDoc(BaseModel):
object: Literal["ingest.document"]
@ -30,3 +31,4 @@ class IngestedDoc(BaseModel):
doc_id=document.doc_id,
doc_metadata=IngestedDoc.curate_metadata(document.metadata),
)

View File

@ -33,6 +33,8 @@ SOURCES_SEPARATOR = "\n\n Sources: \n"
MODES = ["Query Docs", "Search in Docs", "LLM Chat"]
# generate
@singleton
class PrivateAdminGptUi:
@inject
@ -67,11 +69,16 @@ class PrivateAdminGptUi:
if completion_gen.sources:
full_response += SOURCES_SEPARATOR
cur_sources = Source.curate_sources(completion_gen.sources)
# sources_text = "\n\n\n".join(
# f"{index}. {source.file} (page {source.page}) (page_link {source.page_link})"
# for index, source in enumerate(cur_sources, start=1)
# )
sources_text = "\n\n\n".join(
f"{index}. {source.file} (page {source.page})"
f'<a href="#" target="_blank" rel="noopener noreferrer">{index}. {source.file} (page {source.page})</a>'
for index, source in enumerate(cur_sources, start=1)
)
full_response += sources_text
print(full_response)
yield full_response
def build_history() -> list[ChatMessage]:
@ -125,11 +132,10 @@ class PrivateAdminGptUi:
)
sources = Source.curate_sources(response)
yield "\n\n\n".join(
f"{index}. **{source.file} "
f"(page {source.page})**\n "
f"{source.text}"
f"{index}. **{source.file} (page {source.page})**\n"
f" (link: [{source.page_link}]({source.page_link}))\n{source.text}"
for index, source in enumerate(sources, start=1)
)

View File

@ -1,10 +1,12 @@
from pydantic import BaseModel
from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
from private_gpt.constants import UPLOAD_DIR
from pathlib import Path
class Source(BaseModel):
file: str
page: str
text: str
page_link: str
class Config:
frozen = True
@ -18,8 +20,9 @@ class Source(BaseModel):
file_name = doc_metadata.get("file_name", "-") if doc_metadata else "-"
page_label = doc_metadata.get("page_label", "-") if doc_metadata else "-"
page_link = str(Path(f"{UPLOAD_DIR}/{file_name}#page={page_label}"))
source = Source(file=file_name, page=page_label, text=chunk.text)
source = Source(file=file_name, page=page_label, text=chunk.text, page_link=page_link)
curated_sources.add(source)
return curated_sources
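To make the new page_link field concrete, a small illustration of how it is built (file name and page number are hypothetical; the printed separator depends on the operating system):

from pathlib import Path

UPLOAD_DIR = r"F:\LLM\privateGPT\private_gpt\uploads"  # value set in constants.py by this commit
file_name, page_label = "Celery.pdf", "3"              # hypothetical chunk metadata
page_link = str(Path(f"{UPLOAD_DIR}/{file_name}#page={page_label}"))
print(page_link)  # on Windows: F:\LLM\privateGPT\private_gpt\uploads\Celery.pdf#page=3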

View File

@ -69,7 +69,7 @@ class PrivateGptUi:
full_response += SOURCES_SEPARATOR
cur_sources = Source.curate_sources(completion_gen.sources)
sources_text = "\n\n\n".join(
f"{index}. {source.file} (page {source.page})"
f"{index}. {source.file} (page {source.page}) (page_link {source.page_link})"
for index, source in enumerate(cur_sources, start=1)
)
full_response += sources_text
@ -130,6 +130,7 @@ class PrivateGptUi:
yield "\n\n\n".join(
f"{index}. **{source.file} "
f"(page {source.page})**\n "
f"(link {source.page_link})**\n "
f"{source.text}"
for index, source in enumerate(sources, start=1)
)

View File

@ -1,141 +0,0 @@
"""This file should be imported only and only if you want to run the UI locally."""
import itertools
import logging
from collections.abc import Iterable
from pathlib import Path
from typing import Any
import gradio as gr # type: ignore
from fastapi import FastAPI
from gradio.themes.utils.colors import slate # type: ignore
from injector import inject, singleton
from llama_index.llms import ChatMessage, ChatResponse, MessageRole
from pydantic import BaseModel
from private_gpt.constants import PROJECT_ROOT_PATH
from private_gpt.di import global_injector
from private_gpt.server.chat.chat_service import ChatService, CompletionGen
from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
from private_gpt.server.ingest.ingest_service import IngestService
from private_gpt.settings.settings import settings
from private_gpt.ui.images import logo_svg
logger = logging.getLogger(__name__)
THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH)
# Should be "private_gpt/ui/avatar-bot.ico"
AVATAR_BOT = THIS_DIRECTORY_RELATIVE / "avatar-bot.ico"
UI_TAB_TITLE = "My Private GPT"
SOURCES_SEPARATOR = "\n\n Sources: \n"
MODES = ["Query Docs", "Search in Docs", "LLM Chat"]
from private_gpt.ui.common import PrivateGpt
@singleton
class UsersUI(PrivateGpt):
def __init__(
self,
ingest_service: IngestService,
chat_service: ChatService,
chunks_service: ChunksService,
) -> None:
super().__init__(ingest_service, chat_service, chunks_service)
def _build_ui_blocks(self) -> gr.Blocks:
logger.debug("Creating the UI blocks")
with gr.Blocks(
title=UI_TAB_TITLE,
theme=gr.themes.Soft(primary_hue=slate),
css=".logo { "
"display:flex;"
"background-color: #C7BAFF;"
"height: 80px;"
"border-radius: 8px;"
"align-content: center;"
"justify-content: center;"
"align-items: center;"
"}"
".logo img { height: 25% }"
".contain { display: flex !important; flex-direction: column !important; }"
"#component-0, #component-3, #component-10, #component-8 { height: 100% !important; }"
"#chatbot { flex-grow: 1 !important; overflow: auto !important;}"
"#col { height: calc(100vh - 112px - 16px) !important; }",
) as users:
# with gr.Row():
# gr.HTML(f"<div class='logo'/><img src={logo_svg} alt=PrivateGPT></div")
with gr.Row(equal_height=False):
with gr.Column(scale=3):
mode = gr.Radio(
MODES,
label="Mode",
value="Query Docs",
)
ingested_dataset = gr.List(
self._list_ingested_files,
headers=["File name"],
label="Ingested Files",
interactive=False,
render=False, # Rendered under the button
)
ingested_dataset.change(
self._list_ingested_files,
outputs=ingested_dataset,
)
ingested_dataset.render()
system_prompt_input = gr.Textbox(
placeholder=self._system_prompt,
label="System Prompt",
lines=2,
interactive=True,
render=False,
)
# When mode changes, set default system prompt
mode.change(
self._set_current_mode, inputs=mode, outputs=system_prompt_input
)
# On blur, set system prompt to use in queries
system_prompt_input.blur(
self._set_system_prompt,
inputs=system_prompt_input,
)
with gr.Column(scale=7, elem_id="col"):
_ = gr.ChatInterface(
self._chat,
chatbot=gr.Chatbot(
label=f"LLM: {settings().llm.mode}",
show_copy_button=True,
elem_id="chatbot",
render=False,
# avatar_images=(
# None,
# AVATAR_BOT,
# ),
),
additional_inputs=[mode, system_prompt_input],
)
return users
def get_ui_blocks(self) -> gr.Blocks:
if self._ui_block is None:
self._ui_block = self._build_ui_blocks()
return self._ui_block
def mount_in_app(self, app: FastAPI, path: str) -> None:
logger.info("PATH---------------------------->:%s", path)
blocks = self.get_ui_blocks()
blocks.queue()
logger.info("Mounting the regular gradio UI at path=%s", path)
gr.mount_gradio_app(app, blocks, path=path)
if __name__ == "__main__":
ui = global_injector.get(UsersUI)
_blocks = ui.get_ui_blocks()
_blocks.queue()
_blocks.launch(debug=False, show_api=False)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -54,10 +54,12 @@ def create_token_payload(user: models.User, user_role: models.UserRole) -> dict:
"""
return {
"id": str(user.id),
"email": str(user.email),
"role": user_role.role.name,
"company_id": user_role.company.id if user_role.company else None,
}
@router.post("/login", response_model=schemas.TokenSchema)
def login_access_token(
db: Session = Depends(deps.get_db),
@ -97,6 +99,7 @@ def login_access_token(
token_payload = {
"id": str(user.id),
"email": str(user.email),
"role": role,
"company_id": company_id,
}
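Assuming the access token issued by /login is a JWT (the token-creation helper itself is not shown in this diff), the enriched claims can be inspected with PyJWT; the token string is a placeholder and signature verification is skipped for inspection only:

import jwt  # PyJWT

token = "paste-an-access-token-here"  # placeholder
claims = jwt.decode(token, options={"verify_signature": False})
print(claims.get("email"), claims.get("role"), claims.get("company_id"))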

View File

@ -24,7 +24,7 @@ def list_companies(
),
) -> List[schemas.Company]:
"""
List companies
Retrieve a list of companies with pagination support.
"""
companies = crud.company.get_multi(db, skip=skip, limit=limit)
return companies

View File

@ -5,8 +5,9 @@ server:
env_name: ${APP_ENV:prod}
port: ${PORT:8001}
cors:
enabled: false
allow_origins: ["*"]
enabled: true
allow_credentials: true
allow_origins: ["http://localhost:5173/"]
allow_methods: ["*"]
allow_headers: ["*"]
auth: