From b7647542f4a09c42caf2e1e5cc4942af5475a2da Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Iv=C3=A1n=20Mart=C3=ADnez?=
Date: Sun, 12 Nov 2023 10:59:51 +0100
Subject: [PATCH] Curate sources to avoid the UI crashing (#1212)

* Curate sources to avoid the UI crashing

* Remove sources from chat history to avoid confusing the LLM
---
 private_gpt/ui/ui.py | 61 +++++++++++++++++++++++++++++---------------
 1 file changed, 40 insertions(+), 21 deletions(-)

diff --git a/private_gpt/ui/ui.py b/private_gpt/ui/ui.py
index d9d96d91..f4a1431a 100644
--- a/private_gpt/ui/ui.py
+++ b/private_gpt/ui/ui.py
@@ -9,18 +9,43 @@ import gradio as gr  # type: ignore
 from fastapi import FastAPI
 from gradio.themes.utils.colors import slate  # type: ignore
 from llama_index.llms import ChatMessage, ChatResponse, MessageRole
+from pydantic import BaseModel
 
 from private_gpt.di import root_injector
 from private_gpt.server.chat.chat_service import ChatService, CompletionGen
-from private_gpt.server.chunks.chunks_service import ChunksService
+from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
 from private_gpt.server.ingest.ingest_service import IngestService
 from private_gpt.settings.settings import settings
 from private_gpt.ui.images import logo_svg
 
 logger = logging.getLogger(__name__)
 
-
 UI_TAB_TITLE = "My Private GPT"
 
+SOURCES_SEPARATOR = "\n\n Sources: \n"
+
+
+class Source(BaseModel):
+    file: str
+    page: str
+    text: str
+
+    class Config:
+        frozen = True
+
+    @staticmethod
+    def curate_sources(sources: list[Chunk]) -> set["Source"]:
+        curated_sources = set()
+
+        for chunk in sources:
+            doc_metadata = chunk.document.doc_metadata
+
+            file_name = doc_metadata.get("file_name", "-") if doc_metadata else "-"
+            page_label = doc_metadata.get("page_label", "-") if doc_metadata else "-"
+
+            source = Source(file=file_name, page=page_label, text=chunk.text)
+            curated_sources.add(source)
+
+        return curated_sources
 
 
 class PrivateGptUi:
@@ -44,21 +69,11 @@ class PrivateGptUi:
             yield full_response
 
             if completion_gen.sources:
-                full_response += "\n\n Sources: \n"
-                sources = (
-                    {
-                        "file": chunk.document.doc_metadata["file_name"]
-                        if chunk.document.doc_metadata
-                        else "",
-                        "page": chunk.document.doc_metadata["page_label"]
-                        if chunk.document.doc_metadata
-                        else "",
-                    }
-                    for chunk in completion_gen.sources
-                )
+                full_response += SOURCES_SEPARATOR
+                cur_sources = Source.curate_sources(completion_gen.sources)
                 sources_text = "\n\n\n".join(
-                    f"{index}. {source['file']} (page {source['page']})"
-                    for index, source in enumerate(sources, start=1)
+                    f"{index}. {source.file} (page {source.page})"
+                    for index, source in enumerate(cur_sources, start=1)
                 )
                 full_response += sources_text
                 yield full_response
@@ -70,7 +85,9 @@
                     [
                         ChatMessage(content=interaction[0], role=MessageRole.USER),
                         ChatMessage(
-                            content=interaction[1], role=MessageRole.ASSISTANT
+                            # Remove from history content the Sources information
+                            content=interaction[1].split(SOURCES_SEPARATOR)[0],
+                            role=MessageRole.ASSISTANT,
                         ),
                     ]
                     for interaction in history
@@ -103,11 +120,13 @@
                 text=message, limit=4, prev_next_chunks=0
             )
 
+            sources = Source.curate_sources(response)
+
             yield "\n\n\n".join(
-                f"{index}. **{chunk.document.doc_metadata['file_name'] if chunk.document.doc_metadata else ''} "
-                f"(page {chunk.document.doc_metadata['page_label'] if chunk.document.doc_metadata else ''})**\n "
-                f"{chunk.text}"
-                for index, chunk in enumerate(response, start=1)
+                f"{index}. **{source.file} "
+                f"(page {source.page})**\n "
+                f"{source.text}"
+                for index, source in enumerate(sources, start=1)
             )
 
     def _list_ingested_files(self) -> list[list[str]]:
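
Illustration (not part of the applied diff): the three behavior changes above are easy to miss in diff form, so here is a standalone Python sketch of each. The Source model is copied from the patch; pydantic v2 semantics are assumed, and the sample file names and strings are invented for the example.

from pydantic import BaseModel

SOURCES_SEPARATOR = "\n\n Sources: \n"


class Source(BaseModel):
    file: str
    page: str
    text: str

    class Config:
        frozen = True  # instances become immutable and hashable


# 1. Deduplication: because frozen models are hashable, curate_sources can
#    return a set, and repeated chunks from the same file/page collapse to
#    a single sources entry instead of flooding the UI.
a = Source(file="manual.pdf", page="3", text="same chunk text")
b = Source(file="manual.pdf", page="3", text="same chunk text")
assert len({a, b}) == 1

# 2. Crash fix: metadata may be present but incomplete. The old code indexed
#    doc_metadata["page_label"] directly and raised KeyError in that case;
#    .get() with a "-" placeholder cannot crash.
doc_metadata = {"file_name": "manual.pdf"}  # no "page_label" key
page_label = doc_metadata.get("page_label", "-") if doc_metadata else "-"
assert page_label == "-"

# 3. History curation: everything after SOURCES_SEPARATOR is display-only,
#    so the assistant message is cut at the separator before the interaction
#    is replayed to the LLM as chat history.
full_response = "The answer." + SOURCES_SEPARATOR + "1. manual.pdf (page 3)"
assert full_response.split(SOURCES_SEPARATOR)[0] == "The answer."

A non-frozen pydantic model would make curated_sources.add(source) fail with an unhashable-type TypeError, which is why frozen = True is part of the fix rather than a stylistic choice.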