Add sources to completions APIs and UI (#1206)

Author: Iván Martínez
Date: 2023-11-11 21:39:15 +01:00 (committed by GitHub)
Parent: dbd99e7b4b
Commit: a22969ad1f
7 changed files with 159 additions and 70 deletions

File diff suppressed because one or more lines are too long

View File

@@ -5,6 +5,8 @@ from collections.abc import Iterator
 from llama_index.llms import ChatResponse, CompletionResponse
 from pydantic import BaseModel, Field
+from private_gpt.server.chunks.chunks_service import Chunk

 class OpenAIDelta(BaseModel):
     """A piece of completion that needs to be concatenated to get the full message."""
@@ -27,11 +29,13 @@ class OpenAIChoice(BaseModel):
     """Response from AI.

     Either the delta or the message will be present, but never both.
+    Sources used will be returned in case context retrieval was enabled.
     """

     finish_reason: str | None = Field(examples=["stop"])
     delta: OpenAIDelta | None = None
     message: OpenAIMessage | None = None
+    sources: list[Chunk] | None = None
     index: int = 0
@@ -49,7 +53,10 @@ class OpenAICompletion(BaseModel):
     @classmethod
     def from_text(
-        cls, text: str | None, finish_reason: str | None = None
+        cls,
+        text: str | None,
+        finish_reason: str | None = None,
+        sources: list[Chunk] | None = None,
     ) -> "OpenAICompletion":
         return OpenAICompletion(
             id=str(uuid.uuid4()),
@@ -60,13 +67,18 @@ class OpenAICompletion(BaseModel):
                 OpenAIChoice(
                     message=OpenAIMessage(role="assistant", content=text),
                     finish_reason=finish_reason,
+                    sources=sources,
                 )
             ],
         )

     @classmethod
     def json_from_delta(
-        cls, *, text: str | None, finish_reason: str | None = None
+        cls,
+        *,
+        text: str | None,
+        finish_reason: str | None = None,
+        sources: list[Chunk] | None = None,
     ) -> str:
         chunk = OpenAICompletion(
             id=str(uuid.uuid4()),
@@ -77,6 +89,7 @@ class OpenAICompletion(BaseModel):
                 OpenAIChoice(
                     delta=OpenAIDelta(content=text),
                     finish_reason=finish_reason,
+                    sources=sources,
                 )
             ],
         )
@@ -84,20 +97,25 @@ class OpenAICompletion(BaseModel):
         return chunk.model_dump_json()

-def to_openai_response(response: str | ChatResponse) -> OpenAICompletion:
+def to_openai_response(
+    response: str | ChatResponse, sources: list[Chunk] | None = None
+) -> OpenAICompletion:
     if isinstance(response, ChatResponse):
         return OpenAICompletion.from_text(response.delta, finish_reason="stop")
     else:
-        return OpenAICompletion.from_text(response, finish_reason="stop")
+        return OpenAICompletion.from_text(
+            response, finish_reason="stop", sources=sources
+        )

 def to_openai_sse_stream(
     response_generator: Iterator[str | CompletionResponse | ChatResponse],
+    sources: list[Chunk] | None = None,
 ) -> Iterator[str]:
     for response in response_generator:
         if isinstance(response, CompletionResponse | ChatResponse):
             yield f"data: {OpenAICompletion.json_from_delta(text=response.delta)}\n\n"
         else:
-            yield f"data: {OpenAICompletion.json_from_delta(text=response)}\n\n"
+            yield f"data: {OpenAICompletion.json_from_delta(text=response, sources=sources)}\n\n"
     yield f"data: {OpenAICompletion.json_from_delta(text=None, finish_reason='stop')}\n\n"
     yield "data: [DONE]\n\n"

View File

@@ -20,6 +20,7 @@ class ChatBody(BaseModel):
     messages: list[OpenAIMessage]
     use_context: bool = False
     context_filter: ContextFilter | None = None
+    include_sources: bool = True
     stream: bool = False

     model_config = {
@@ -34,6 +35,7 @@ class ChatBody(BaseModel):
                 ],
                 "stream": False,
                 "use_context": True,
+                "include_sources": True,
                 "context_filter": {
                     "docs_ids": ["c202d5e6-7b69-4869-81cc-dd574ee8ee11"]
                 },
@@ -58,6 +60,9 @@ def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
     Ingested documents IDs can be found using `/ingest/list` endpoint. If you want
     all ingested documents to be used, remove `context_filter` altogether.

+    When using `'include_sources': true`, the API will return the source Chunks used
+    to create the response, which come from the context provided.
+
     When using `'stream': true`, the API will return data chunks following [OpenAI's
     streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
     ```
@@ -71,12 +76,18 @@ def chat_completion(body: ChatBody) -> OpenAICompletion | StreamingResponse:
         ChatMessage(content=m.content, role=MessageRole(m.role)) for m in body.messages
     ]
     if body.stream:
-        stream = service.stream_chat(
+        completion_gen = service.stream_chat(
            all_messages, body.use_context, body.context_filter
         )
         return StreamingResponse(
-            to_openai_sse_stream(stream), media_type="text/event-stream"
+            to_openai_sse_stream(
+                completion_gen.response,
+                completion_gen.sources if body.include_sources else None,
+            ),
+            media_type="text/event-stream",
         )
     else:
-        response = service.chat(all_messages, body.use_context, body.context_filter)
-        return to_openai_response(response)
+        completion = service.chat(all_messages, body.use_context, body.context_filter)
+        return to_openai_response(
+            completion.response, completion.sources if body.include_sources else None
+        )
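
Since `include_sources` defaults to `True`, a client only needs to read `choices[0].sources` from the response. A minimal client sketch, assuming a locally running PrivateGPT server and the usual `/v1/chat/completions` route (both the URL and the port are assumptions, not part of this diff):

```python
# Hypothetical client call; endpoint URL and port are assumptions.
import requests

body = {
    "messages": [{"role": "user", "content": "What did the sales report say?"}],
    "use_context": True,
    "include_sources": True,  # flag introduced by this change
    "stream": False,
}
resp = requests.post("http://localhost:8001/v1/chat/completions", json=body, timeout=60)
choice = resp.json()["choices"][0]
print(choice["message"]["content"])
for chunk in choice.get("sources") or []:
    # each source is a Chunk: document metadata plus the retrieved text
    print(chunk["document"]["doc_id"], "-", chunk["text"][:80])
```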

View File

@@ -1,13 +1,14 @@
-from collections.abc import Sequence
-from typing import TYPE_CHECKING, Any
 from injector import inject, singleton
 from llama_index import ServiceContext, StorageContext, VectorStoreIndex
 from llama_index.chat_engine import ContextChatEngine
+from llama_index.chat_engine.types import (
+    BaseChatEngine,
+)
 from llama_index.indices.postprocessor import MetadataReplacementPostProcessor
 from llama_index.llm_predictor.utils import stream_chat_response_to_tokens
 from llama_index.llms import ChatMessage
 from llama_index.types import TokenGen
+from pydantic import BaseModel

 from private_gpt.components.embedding.embedding_component import EmbeddingComponent
 from private_gpt.components.llm.llm_component import LLMComponent
@@ -16,12 +17,17 @@ from private_gpt.components.vector_store.vector_store_component import (
     VectorStoreComponent,
 )
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
+from private_gpt.server.chunks.chunks_service import Chunk

-if TYPE_CHECKING:
-    from llama_index.chat_engine.types import (
-        AgentChatResponse,
-        StreamingAgentChatResponse,
-    )
+
+class Completion(BaseModel):
+    response: str
+    sources: list[Chunk] | None = None
+
+
+class CompletionGen(BaseModel):
+    response: TokenGen
+    sources: list[Chunk] | None = None

 @singleton
@@ -51,66 +57,64 @@ class ChatService:
             show_progress=True,
         )

-    def _chat_with_contex(
-        self,
-        message: str,
-        context_filter: ContextFilter | None = None,
-        chat_history: Sequence[ChatMessage] | None = None,
-        streaming: bool = False,
-    ) -> Any:
+    def _chat_engine(
+        self, context_filter: ContextFilter | None = None
+    ) -> BaseChatEngine:
         vector_index_retriever = self.vector_store_component.get_retriever(
             index=self.index, context_filter=context_filter
         )
-        chat_engine = ContextChatEngine.from_defaults(
+        return ContextChatEngine.from_defaults(
             retriever=vector_index_retriever,
             service_context=self.service_context,
             node_postprocessors=[
                 MetadataReplacementPostProcessor(target_metadata_key="window"),
             ],
         )
-        if streaming:
-            result = chat_engine.stream_chat(message, chat_history)
-        else:
-            result = chat_engine.chat(message, chat_history)
-        return result

     def stream_chat(
         self,
         messages: list[ChatMessage],
         use_context: bool = False,
         context_filter: ContextFilter | None = None,
-    ) -> TokenGen:
+    ) -> CompletionGen:
         if use_context:
             last_message = messages[-1].content
-            response: StreamingAgentChatResponse = self._chat_with_contex(
+            chat_engine = self._chat_engine(context_filter=context_filter)
+            streaming_response = chat_engine.stream_chat(
                 message=last_message if last_message is not None else "",
                 chat_history=messages[:-1],
-                context_filter=context_filter,
-                streaming=True,
             )
-            response_gen = response.response_gen
+            sources = [
+                Chunk.from_node(node) for node in streaming_response.source_nodes
+            ]
+            completion_gen = CompletionGen(
+                response=streaming_response.response_gen, sources=sources
+            )
         else:
             stream = self.llm_service.llm.stream_chat(messages)
-            response_gen = stream_chat_response_to_tokens(stream)
-        return response_gen
+            completion_gen = CompletionGen(
+                response=stream_chat_response_to_tokens(stream)
+            )
+        return completion_gen

     def chat(
         self,
         messages: list[ChatMessage],
         use_context: bool = False,
         context_filter: ContextFilter | None = None,
-    ) -> str:
+    ) -> Completion:
         if use_context:
             last_message = messages[-1].content
-            wrapped_response: AgentChatResponse = self._chat_with_contex(
+            chat_engine = self._chat_engine(context_filter=context_filter)
+            wrapped_response = chat_engine.chat(
                 message=last_message if last_message is not None else "",
                 chat_history=messages[:-1],
-                context_filter=context_filter,
-                streaming=False,
             )
-            response = wrapped_response.response
+            sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
+            completion = Completion(response=wrapped_response.response, sources=sources)
         else:
             chat_response = self.llm_service.llm.chat(messages)
             response_content = chat_response.message.content
             response = response_content if response_content is not None else ""
-        return response
+            completion = Completion(response=response)
+        return completion
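
The service now returns `Completion` / `CompletionGen` objects instead of a bare string or token generator, so callers get the answer text and the supporting chunks from a single value. A sketch of how a caller might use it (assuming an already-injected `ChatService` instance named `chat_service`; the message content is illustrative):

```python
# Sketch only: consuming the refactored ChatService API.
from llama_index.llms import ChatMessage, MessageRole

messages = [ChatMessage(role=MessageRole.USER, content="Summarise the sales report")]

# Non-streaming: Completion bundles the full answer with the chunks used as context.
completion = chat_service.chat(messages, use_context=True)
print(completion.response)
print([chunk.document.doc_id for chunk in completion.sources or []])

# Streaming: CompletionGen exposes the token generator plus the sources up front,
# so the caller can render citations without waiting for the stream to finish.
completion_gen = chat_service.stream_chat(messages, use_context=True)
for token in completion_gen.response:
    print(token, end="", flush=True)
```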

View File

@@ -24,15 +24,31 @@ class Chunk(BaseModel):
     document: IngestedDoc
     text: str = Field(examples=["Outbound sales increased 20%, driven by new leads."])
     previous_texts: list[str] | None = Field(
-        examples=[["SALES REPORT 2023", "Inbound didn't show major changes."]]
+        default=None,
+        examples=[["SALES REPORT 2023", "Inbound didn't show major changes."]],
     )
     next_texts: list[str] | None = Field(
+        default=None,
         examples=[
             [
                 "New leads came from Google Ads campaign.",
                 "The campaign was run by the Marketing Department",
             ]
-        ]
+        ],
     )
+
+    @classmethod
+    def from_node(cls: type["Chunk"], node: NodeWithScore) -> "Chunk":
+        doc_id = node.node.ref_doc_id if node.node.ref_doc_id is not None else "-"
+        return cls(
+            object="context.chunk",
+            score=node.score or 0.0,
+            document=IngestedDoc(
+                object="ingest.document",
+                doc_id=doc_id,
+                doc_metadata=node.metadata,
+            ),
+            text=node.get_content(),
+        )
@@ -98,22 +114,11 @@ class ChunksService:
         retrieved_nodes = []
         for node in nodes:
-            doc_id = node.node.ref_doc_id if node.node.ref_doc_id is not None else "-"
-            retrieved_nodes.append(
-                Chunk(
-                    object="context.chunk",
-                    score=node.score or 0.0,
-                    document=IngestedDoc(
-                        object="ingest.document",
-                        doc_id=doc_id,
-                        doc_metadata=node.metadata,
-                    ),
-                    text=node.get_content(),
-                    previous_texts=self._get_sibling_nodes_text(
-                        node, prev_next_chunks, False
-                    ),
-                    next_texts=self._get_sibling_nodes_text(node, prev_next_chunks),
-                )
-            )
+            chunk = Chunk.from_node(node)
+            chunk.previous_texts = self._get_sibling_nodes_text(
+                node, prev_next_chunks, False
+            )
+            chunk.next_texts = self._get_sibling_nodes_text(node, prev_next_chunks)
+            retrieved_nodes.append(chunk)
         return retrieved_nodes
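
`Chunk.from_node` centralises the `NodeWithScore` → `Chunk` conversion that used to be inlined in `ChunksService`, which is what lets the chat service reuse it for sources. A small sketch of the mapping it performs (assuming `node` is a `NodeWithScore` produced by a llama_index retriever):

```python
# Sketch of the field mapping done by Chunk.from_node; `node` is assumed to be
# a llama_index NodeWithScore obtained from a retriever.
from private_gpt.server.chunks.chunks_service import Chunk

chunk = Chunk.from_node(node)
assert chunk.object == "context.chunk"
assert chunk.score == (node.score or 0.0)
assert chunk.document.doc_id == (
    node.node.ref_doc_id if node.node.ref_doc_id is not None else "-"
)
assert chunk.text == node.get_content()
# previous_texts / next_texts now default to None; ChunksService fills them in
# afterwards via _get_sibling_nodes_text, as shown in the diff above.
```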

View File

@@ -16,6 +16,7 @@ class CompletionsBody(BaseModel):
     prompt: str
     use_context: bool = False
     context_filter: ContextFilter | None = None
+    include_sources: bool = True
     stream: bool = False

     model_config = {
@@ -25,6 +26,7 @@ class CompletionsBody(BaseModel):
                 "prompt": "How do you fry an egg?",
                 "stream": False,
                 "use_context": False,
+                "include_sources": False,
             }
         ]
     }
@@ -48,6 +50,9 @@ def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResponse:
     can be found using `/ingest/list` endpoint. If you want all ingested documents to
     be used, remove `context_filter` altogether.

+    When using `'include_sources': true`, the API will return the source Chunks used
+    to create the response, which come from the context provided.
+
     When using `'stream': true`, the API will return data chunks following [OpenAI's
     streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
     ```
@@ -61,6 +66,7 @@ def prompt_completion(body: CompletionsBody) -> OpenAICompletion | StreamingResponse:
         messages=[message],
         use_context=body.use_context,
         stream=body.stream,
+        include_sources=body.include_sources,
         context_filter=body.context_filter,
     )
     return chat_completion(chat_body)
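
Because the completions endpoint simply forwards `include_sources` into a `ChatBody`, the flag behaves the same for plain prompts, including when streaming. A streaming client sketch (the `/v1/completions` path and the port are assumptions about the running server, not part of this diff):

```python
# Hypothetical streaming client; URL and port are assumptions.
import json
import requests

body = {
    "prompt": "How do you fry an egg?",
    "use_context": True,
    "include_sources": True,
    "stream": True,
}
with requests.post("http://localhost:8001/v1/completions", json=body, stream=True) as resp:
    for raw in resp.iter_lines():
        if not raw:
            continue
        payload = raw.decode().removeprefix("data: ")
        if payload == "[DONE]":
            break
        choice = json.loads(payload)["choices"][0]
        print((choice["delta"] or {}).get("content") or "", end="", flush=True)
        # choice.get("sources") carries the retrieved chunks on the delta events
```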

View File

@@ -11,7 +11,7 @@ from gradio.themes.utils.colors import slate  # type: ignore
 from llama_index.llms import ChatMessage, ChatResponse, MessageRole

 from private_gpt.di import root_injector
-from private_gpt.server.chat.chat_service import ChatService
+from private_gpt.server.chat.chat_service import ChatService, CompletionGen
 from private_gpt.server.chunks.chunks_service import ChunksService
 from private_gpt.server.ingest.ingest_service import IngestService
 from private_gpt.settings.settings import settings
@@ -33,8 +33,9 @@ class PrivateGptUi:
         self._ui_block = None

     def _chat(self, message: str, history: list[list[str]], mode: str, *_: Any) -> Any:
-        def yield_deltas(stream: Iterable[ChatResponse | str]) -> Iterable[str]:
+        def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
             full_response: str = ""
+            stream = completion_gen.response
             for delta in stream:
                 if isinstance(delta, str):
                     full_response += str(delta)
@@ -42,6 +43,26 @@ class PrivateGptUi:
                     full_response += delta.delta or ""
                 yield full_response

+            if completion_gen.sources:
+                full_response += "\n\n Sources: \n"
+                sources = (
+                    {
+                        "file": chunk.document.doc_metadata["file_name"]
+                        if chunk.document.doc_metadata
+                        else "",
+                        "page": chunk.document.doc_metadata["page_label"]
+                        if chunk.document.doc_metadata
+                        else "",
+                    }
+                    for chunk in completion_gen.sources
+                )
+                sources_text = "\n\n\n".join(
+                    f"{index}. {source['file']} (page {source['page']})"
+                    for index, source in enumerate(sources, start=1)
+                )
+                full_response += sources_text
+                yield full_response
+
         def build_history() -> list[ChatMessage]:
             history_messages: list[ChatMessage] = list(
                 itertools.chain(