"""This file should be imported if and only if you want to run the UI locally.""" import base64 import logging import time from collections.abc import Iterable from enum import Enum from pathlib import Path from typing import Any import gradio as gr # type: ignore from fastapi import FastAPI from gradio.themes.utils.colors import slate # type: ignore from injector import inject, singleton from llama_index.core.llms import ChatMessage, ChatResponse, MessageRole from llama_index.core.types import TokenGen from pydantic import BaseModel from private_gpt.constants import PROJECT_ROOT_PATH from private_gpt.di import global_injector from private_gpt.open_ai.extensions.context_filter import ContextFilter from private_gpt.server.chat.chat_service import ChatService, CompletionGen from private_gpt.server.chunks.chunks_service import Chunk, ChunksService from private_gpt.server.ingest.ingest_service import IngestService from private_gpt.server.recipes.summarize.summarize_service import SummarizeService from private_gpt.settings.settings import settings from private_gpt.ui.images import logo_svg logger = logging.getLogger(__name__) THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH) # Should be "private_gpt/ui/avatar-bot.ico" AVATAR_BOT = THIS_DIRECTORY_RELATIVE / "avatar-bot.ico" UI_TAB_TITLE = "My Private GPT" SOURCES_SEPARATOR = "
Sources: \n" class Modes(str, Enum): RAG_MODE = "RAG" SEARCH_MODE = "Search" BASIC_CHAT_MODE = "Basic" SUMMARIZE_MODE = "Summarize" MODES: list[Modes] = [ Modes.RAG_MODE, Modes.SEARCH_MODE, Modes.BASIC_CHAT_MODE, Modes.SUMMARIZE_MODE, ] class Source(BaseModel): file: str page: str text: str class Config: frozen = True @staticmethod def curate_sources(sources: list[Chunk]) -> list["Source"]: curated_sources = [] for chunk in sources: doc_metadata = chunk.document.doc_metadata file_name = doc_metadata.get("file_name", "-") if doc_metadata else "-" page_label = doc_metadata.get("page_label", "-") if doc_metadata else "-" source = Source(file=file_name, page=page_label, text=chunk.text) curated_sources.append(source) curated_sources = list( dict.fromkeys(curated_sources).keys() ) # Unique sources only return curated_sources @singleton class PrivateGptUi: @inject def __init__( self, ingest_service: IngestService, chat_service: ChatService, chunks_service: ChunksService, summarizeService: SummarizeService, ) -> None: self._ingest_service = ingest_service self._chat_service = chat_service self._chunks_service = chunks_service self._summarize_service = summarizeService # Cache the UI blocks self._ui_block = None self._selected_filename = None # Initialize system prompt based on default mode default_mode_map = {mode.value: mode for mode in Modes} self._default_mode = default_mode_map.get( settings().ui.default_mode, Modes.RAG_MODE ) self._system_prompt = self._get_default_system_prompt(self._default_mode) def _chat( self, message: str, history: list[list[str]], mode: Modes, *_: Any ) -> Any: def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]: full_response: str = "" stream = completion_gen.response for delta in stream: if isinstance(delta, str): full_response += str(delta) elif isinstance(delta, ChatResponse): full_response += delta.delta or "" yield full_response time.sleep(0.02) if completion_gen.sources: full_response += SOURCES_SEPARATOR cur_sources = Source.curate_sources(completion_gen.sources) sources_text = "\n\n\n" used_files = set() for index, source in enumerate(cur_sources, start=1): if f"{source.file}-{source.page}" not in used_files: sources_text = ( sources_text + f"{index}. {source.file} (page {source.page}) \n\n" ) used_files.add(f"{source.file}-{source.page}") sources_text += "
\n\n" full_response += sources_text yield full_response def yield_tokens(token_gen: TokenGen) -> Iterable[str]: full_response: str = "" for token in token_gen: full_response += str(token) yield full_response def build_history() -> list[ChatMessage]: history_messages: list[ChatMessage] = [] for interaction in history: history_messages.append( ChatMessage(content=interaction[0], role=MessageRole.USER) ) if len(interaction) > 1 and interaction[1] is not None: history_messages.append( ChatMessage( # Remove from history content the Sources information content=interaction[1].split(SOURCES_SEPARATOR)[0], role=MessageRole.ASSISTANT, ) ) # max 20 messages to try to avoid context overflow return history_messages[:20] new_message = ChatMessage(content=message, role=MessageRole.USER) all_messages = [*build_history(), new_message] # If a system prompt is set, add it as a system message if self._system_prompt: all_messages.insert( 0, ChatMessage( content=self._system_prompt, role=MessageRole.SYSTEM, ), ) match mode: case Modes.RAG_MODE: # Use only the selected file for the query context_filter = None if self._selected_filename is not None: docs_ids = [] for ingested_document in self._ingest_service.list_ingested(): if ( ingested_document.doc_metadata["file_name"] == self._selected_filename ): docs_ids.append(ingested_document.doc_id) context_filter = ContextFilter(docs_ids=docs_ids) query_stream = self._chat_service.stream_chat( messages=all_messages, use_context=True, context_filter=context_filter, ) yield from yield_deltas(query_stream) case Modes.BASIC_CHAT_MODE: llm_stream = self._chat_service.stream_chat( messages=all_messages, use_context=False, ) yield from yield_deltas(llm_stream) case Modes.SEARCH_MODE: response = self._chunks_service.retrieve_relevant( text=message, limit=4, prev_next_chunks=0 ) sources = Source.curate_sources(response) yield "\n\n\n".join( f"{index}. **{source.file} " f"(page {source.page})**\n " f"{source.text}" for index, source in enumerate(sources, start=1) ) case Modes.SUMMARIZE_MODE: # Summarize the given message, optionally using selected files context_filter = None if self._selected_filename: docs_ids = [] for ingested_document in self._ingest_service.list_ingested(): if ( ingested_document.doc_metadata["file_name"] == self._selected_filename ): docs_ids.append(ingested_document.doc_id) context_filter = ContextFilter(docs_ids=docs_ids) summary_stream = self._summarize_service.stream_summarize( use_context=True, context_filter=context_filter, instructions=message, ) yield from yield_tokens(summary_stream) # On initialization and on mode change, this function set the system prompt # to the default prompt based on the mode (and user settings). @staticmethod def _get_default_system_prompt(mode: Modes) -> str: p = "" match mode: # For query chat mode, obtain default system prompt from settings case Modes.RAG_MODE: p = settings().ui.default_query_system_prompt # For chat mode, obtain default system prompt from settings case Modes.BASIC_CHAT_MODE: p = settings().ui.default_chat_system_prompt # For summarization mode, obtain default system prompt from settings case Modes.SUMMARIZE_MODE: p = settings().ui.default_summarization_system_prompt # For any other mode, clear the system prompt case _: p = "" return p @staticmethod def _get_default_mode_explanation(mode: Modes) -> str: match mode: case Modes.RAG_MODE: return "Get contextualized answers from selected files." case Modes.SEARCH_MODE: return "Find relevant chunks of text in selected files." 
            case Modes.BASIC_CHAT_MODE:
                return "Chat with the LLM using its training data. Files are ignored."
            case Modes.SUMMARIZE_MODE:
                return (
                    "Generate a summary of the selected files. "
                    "Prompt to customize the result."
                )
            case _:
                return ""

    def _set_system_prompt(self, system_prompt_input: str) -> None:
        logger.info(f"Setting system prompt to: {system_prompt_input}")
        self._system_prompt = system_prompt_input

    def _set_explanation_mode(self, explanation_mode: str) -> None:
        self._explanation_mode = explanation_mode

    def _set_current_mode(self, mode: Modes) -> Any:
        self.mode = mode
        self._set_system_prompt(self._get_default_system_prompt(mode))
        self._set_explanation_mode(self._get_default_mode_explanation(mode))
        interactive = self._system_prompt is not None
        return [
            gr.update(placeholder=self._system_prompt, interactive=interactive),
            gr.update(value=self._explanation_mode),
        ]

    def _list_ingested_files(self) -> list[list[str]]:
        files = set()
        for ingested_document in self._ingest_service.list_ingested():
            if ingested_document.doc_metadata is None:
                # Skipping documents without metadata
                continue
            file_name = ingested_document.doc_metadata.get(
                "file_name", "[FILE NAME MISSING]"
            )
            files.add(file_name)
        return [[row] for row in files]

    def _upload_file(self, files: list[str]) -> None:
        logger.debug("Loading count=%s files", len(files))
        paths = [Path(file) for file in files]

        # Remove all existing documents whose name matches a newly uploaded file
        file_names = [path.name for path in paths]
        doc_ids_to_delete = []
        for ingested_document in self._ingest_service.list_ingested():
            if (
                ingested_document.doc_metadata
                and ingested_document.doc_metadata["file_name"] in file_names
            ):
                doc_ids_to_delete.append(ingested_document.doc_id)
        if len(doc_ids_to_delete) > 0:
            logger.info(
                "Uploading file(s) which were already ingested: %s document(s) "
                "will be replaced.",
                len(doc_ids_to_delete),
            )
            for doc_id in doc_ids_to_delete:
                self._ingest_service.delete(doc_id)

        self._ingest_service.bulk_ingest([(str(path.name), path) for path in paths])

    def _delete_all_files(self) -> Any:
        ingested_files = self._ingest_service.list_ingested()
        logger.debug("Deleting count=%s files", len(ingested_files))
        for ingested_document in ingested_files:
            self._ingest_service.delete(ingested_document.doc_id)
        return [
            gr.List(self._list_ingested_files()),
            gr.components.Button(interactive=False),
            gr.components.Button(interactive=False),
            gr.components.Textbox("All files"),
        ]

    def _delete_selected_file(self) -> Any:
        logger.debug("Deleting selected %s", self._selected_filename)
        # Note: keep looping for PDFs (each page becomes its own Document)
        for ingested_document in self._ingest_service.list_ingested():
            if (
                ingested_document.doc_metadata
                and ingested_document.doc_metadata["file_name"]
                == self._selected_filename
            ):
                self._ingest_service.delete(ingested_document.doc_id)
        return [
            gr.List(self._list_ingested_files()),
            gr.components.Button(interactive=False),
            gr.components.Button(interactive=False),
            gr.components.Textbox("All files"),
        ]

    def _deselect_selected_file(self) -> Any:
        self._selected_filename = None
        return [
            gr.components.Button(interactive=False),
            gr.components.Button(interactive=False),
            gr.components.Textbox("All files"),
        ]

    def _selected_a_file(self, select_data: gr.SelectData) -> Any:
        self._selected_filename = select_data.value
        return [
            gr.components.Button(interactive=True),
            gr.components.Button(interactive=True),
            gr.components.Textbox(self._selected_filename),
        ]
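
    # Each file-management handler above returns a list of Gradio component updates
    # (ingested-files list, the two action buttons, and the selected-file textbox);
    # the order must match the outputs of the event bindings that are presumably
    # wired up when the Blocks layout is assembled below.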

    def _build_ui_blocks(self) -> gr.Blocks:
        logger.debug("Creating the UI blocks")
        with gr.Blocks(
            title=UI_TAB_TITLE,
            theme=gr.themes.Soft(primary_hue=slate),
            css=".logo { "
            "display:flex;"
            "background-color: #C7BAFF;"
            "height: 80px;"
            "border-radius: 8px;"
            "align-content: center;"
            "justify-content: center;"
            "align-items: center;"
            "}"
            ".logo img { height: 25% }"
            ".contain { display: flex !important; flex-direction: column !important; }"
            "#component-0, #component-3, #component-10, #component-8 { height: 100% !important; }"
            "#chatbot { flex-grow: 1 !important; overflow: auto !important;}"
            "#col { height: calc(100vh - 112px - 16px) !important; }"
            "hr { margin-top: 1em; margin-bottom: 1em; border: 0; border-top: 1px solid #FFF; }"
            ".avatar-image { background-color: antiquewhite; border-radius: 2px; }"
            ".footer { text-align: center; margin-top: 20px; font-size: 14px; display: flex; align-items: center; justify-content: center; }"
            ".footer-zylon-link { display:flex; margin-left: 5px; text-decoration: auto; color: var(--body-text-color); }"
            ".footer-zylon-link:hover { color: #C7BAFF; }"
            ".footer-zylon-ico { height: 20px; margin-left: 5px; background-color: antiquewhite; border-radius: 2px; }",
        ) as blocks:
            with gr.Row():
                gr.HTML(f"