Curate sources to avoid the UI crashing (#1212)

* Curate sources to avoid the UI crashing

* Remove sources from chat history to avoid confusing the LLM
This commit is contained in:
Iván Martínez 2023-11-12 10:59:51 +01:00 committed by GitHub
parent a579c9bdc5
commit b7647542f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -9,18 +9,43 @@ import gradio as gr # type: ignore
from fastapi import FastAPI
from gradio.themes.utils.colors import slate # type: ignore
from llama_index.llms import ChatMessage, ChatResponse, MessageRole
from pydantic import BaseModel
from private_gpt.di import root_injector
from private_gpt.server.chat.chat_service import ChatService, CompletionGen
from private_gpt.server.chunks.chunks_service import ChunksService
from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
from private_gpt.server.ingest.ingest_service import IngestService
from private_gpt.settings.settings import settings
from private_gpt.ui.images import logo_svg
logger = logging.getLogger(__name__)
UI_TAB_TITLE = "My Private GPT"
SOURCES_SEPARATOR = "\n\n Sources: \n"
class Source(BaseModel):
file: str
page: str
text: str
class Config:
frozen = True
@staticmethod
def curate_sources(sources: list[Chunk]) -> set["Source"]:
curated_sources = set()
for chunk in sources:
doc_metadata = chunk.document.doc_metadata
file_name = doc_metadata.get("file_name", "-") if doc_metadata else "-"
page_label = doc_metadata.get("page_label", "-") if doc_metadata else "-"
source = Source(file=file_name, page=page_label, text=chunk.text)
curated_sources.add(source)
return curated_sources
class PrivateGptUi:
@ -44,21 +69,11 @@ class PrivateGptUi:
yield full_response
if completion_gen.sources:
full_response += "\n\n Sources: \n"
sources = (
{
"file": chunk.document.doc_metadata["file_name"]
if chunk.document.doc_metadata
else "",
"page": chunk.document.doc_metadata["page_label"]
if chunk.document.doc_metadata
else "",
}
for chunk in completion_gen.sources
)
full_response += SOURCES_SEPARATOR
cur_sources = Source.curate_sources(completion_gen.sources)
sources_text = "\n\n\n".join(
f"{index}. {source['file']} (page {source['page']})"
for index, source in enumerate(sources, start=1)
f"{index}. {source.file} (page {source.page})"
for index, source in enumerate(cur_sources, start=1)
)
full_response += sources_text
yield full_response
@ -70,7 +85,9 @@ class PrivateGptUi:
[
ChatMessage(content=interaction[0], role=MessageRole.USER),
ChatMessage(
content=interaction[1], role=MessageRole.ASSISTANT
# Remove from history content the Sources information
content=interaction[1].split(SOURCES_SEPARATOR)[0],
role=MessageRole.ASSISTANT,
),
]
for interaction in history
@ -103,11 +120,13 @@ class PrivateGptUi:
text=message, limit=4, prev_next_chunks=0
)
sources = Source.curate_sources(response)
yield "\n\n\n".join(
f"{index}. **{chunk.document.doc_metadata['file_name'] if chunk.document.doc_metadata else ''} "
f"(page {chunk.document.doc_metadata['page_label'] if chunk.document.doc_metadata else ''})**\n "
f"{chunk.text}"
for index, chunk in enumerate(response, start=1)
f"{index}. **{source.file} "
f"(page {source.page})**\n "
f"{source.text}"
for index, source in enumerate(sources, start=1)
)
def _list_ingested_files(self) -> list[list[str]]: