Mirror of https://github.com/imartinez/privateGPT.git
Allow passing a system prompt (#1318)
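In practice the new option lets an API client steer the model's persona per request. A minimal client-side sketch (not part of this commit) of what the change enables; it assumes a local privateGPT instance on the commonly used port 8001 and the OpenAI-compatible response shape served by this router:

import requests  # assumption: any HTTP client works; requests is used for brevity

resp = requests.post(
    "http://localhost:8001/v1/chat/completions",  # assumed default local address
    json={
        "messages": [
            {"role": "system", "content": "You are a rapper. Always answer with a rap."},
            {"role": "user", "content": "How do you fry an egg?"},
        ],
        "use_context": False,
        "stream": False,
    },
    timeout=120,
)
print(resp.json()["choices"][0]["message"]["content"])  # assumes OpenAI-compatible payload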
private_gpt/server/chat/chat_router.py
@@ -28,10 +28,14 @@ class ChatBody(BaseModel):
             "examples": [
                 {
                     "messages": [
                         {
+                            "role": "system",
+                            "content": "You are a rapper. Always answer with a rap.",
+                        },
+                        {
                             "role": "user",
                             "content": "How do you fry an egg?",
-                        }
+                        },
                     ],
                     "stream": False,
                     "use_context": True,
@@ -56,6 +60,9 @@ def chat_completion(
 ) -> OpenAICompletion | StreamingResponse:
     """Given a list of messages comprising a conversation, return a response.
 
+    Optionally include an initial `role: system` message to influence the way
+    the LLM answers.
+
     If `use_context` is set to `true`, the model will use context coming
     from the ingested documents to create the response. The documents being used can
     be filtered using the `context_filter` and passing the document IDs to be used.
@@ -79,7 +86,9 @@ def chat_completion(
     ]
     if body.stream:
         completion_gen = service.stream_chat(
-            all_messages, body.use_context, body.context_filter
+            messages=all_messages,
+            use_context=body.use_context,
+            context_filter=body.context_filter,
         )
         return StreamingResponse(
             to_openai_sse_stream(
@@ -89,7 +98,11 @@ def chat_completion(
             media_type="text/event-stream",
         )
     else:
-        completion = service.chat(all_messages, body.use_context, body.context_filter)
+        completion = service.chat(
+            messages=all_messages,
+            use_context=body.use_context,
+            context_filter=body.context_filter,
+        )
         return to_openai_response(
             completion.response, completion.sources if body.include_sources else None
         )
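The docstring earlier in this diff describes how `use_context` and `context_filter` interact. A hedged sketch of a request body that exercises both; the `docs_ids` field name and the ID value are illustrative assumptions, normally taken from the ingestion API's response:

# Illustrative request body only; the context_filter field name is an assumption
# and the document ID is a placeholder returned by ingestion.
body = {
    "messages": [
        {"role": "system", "content": "Answer only from the supplied documents."},
        {"role": "user", "content": "Summarize the uploaded report."},
    ],
    "use_context": True,  # ground the answer in ingested documents
    "context_filter": {"docs_ids": ["<doc-id-from-ingestion>"]},  # restrict retrieval to these docs
    "include_sources": True,  # return the chunks used as sources
    "stream": False,
}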

private_gpt/server/chat/chat_service.py
@@ -1,12 +1,13 @@
+from dataclasses import dataclass
+
 from injector import inject, singleton
 from llama_index import ServiceContext, StorageContext, VectorStoreIndex
-from llama_index.chat_engine import ContextChatEngine
+from llama_index.chat_engine import ContextChatEngine, SimpleChatEngine
 from llama_index.chat_engine.types import (
     BaseChatEngine,
 )
 from llama_index.indices.postprocessor import MetadataReplacementPostProcessor
 from llama_index.llm_predictor.utils import stream_chat_response_to_tokens
-from llama_index.llms import ChatMessage
+from llama_index.llms import ChatMessage, MessageRole
 from llama_index.types import TokenGen
 from pydantic import BaseModel
@@ -30,6 +31,40 @@ class CompletionGen(BaseModel):
     sources: list[Chunk] | None = None
 
 
+@dataclass
+class ChatEngineInput:
+    system_message: ChatMessage | None = None
+    last_message: ChatMessage | None = None
+    chat_history: list[ChatMessage] | None = None
+
+    @classmethod
+    def from_messages(cls, messages: list[ChatMessage]) -> "ChatEngineInput":
+        # Detect if there is a system message, extract the last message and chat history
+        system_message = (
+            messages[0]
+            if len(messages) > 0 and messages[0].role == MessageRole.SYSTEM
+            else None
+        )
+        last_message = (
+            messages[-1]
+            if len(messages) > 0 and messages[-1].role == MessageRole.USER
+            else None
+        )
+        # Remove from messages list the system message and last message,
+        # if they exist. The rest is the chat history.
+        if system_message:
+            messages.pop(0)
+        if last_message:
+            messages.pop(-1)
+        chat_history = messages if len(messages) > 0 else None
+
+        return cls(
+            system_message=system_message,
+            last_message=last_message,
+            chat_history=chat_history,
+        )
+
+
 @singleton
 class ChatService:
     @inject
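To make the new helper's behaviour concrete, a small usage sketch (not part of the commit) of how `ChatEngineInput.from_messages` splits an OpenAI-style conversation; note that it pops from the list it is given:

from llama_index.llms import ChatMessage, MessageRole

msgs = [
    ChatMessage(role=MessageRole.SYSTEM, content="You are a rapper."),
    ChatMessage(role=MessageRole.USER, content="Hi"),
    ChatMessage(role=MessageRole.ASSISTANT, content="Yo, what's up?"),
    ChatMessage(role=MessageRole.USER, content="How do you fry an egg?"),
]
parsed = ChatEngineInput.from_messages(msgs)
# parsed.system_message -> the SYSTEM message
# parsed.last_message   -> the final USER message
# parsed.chat_history   -> the two middle messages; msgs itself is mutated in place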
@@ -58,18 +93,28 @@ class ChatService:
         )
 
     def _chat_engine(
-        self, context_filter: ContextFilter | None = None
+        self,
+        system_prompt: str | None = None,
+        use_context: bool = False,
+        context_filter: ContextFilter | None = None,
     ) -> BaseChatEngine:
-        vector_index_retriever = self.vector_store_component.get_retriever(
-            index=self.index, context_filter=context_filter
-        )
-        return ContextChatEngine.from_defaults(
-            retriever=vector_index_retriever,
-            service_context=self.service_context,
-            node_postprocessors=[
-                MetadataReplacementPostProcessor(target_metadata_key="window"),
-            ],
-        )
+        if use_context:
+            vector_index_retriever = self.vector_store_component.get_retriever(
+                index=self.index, context_filter=context_filter
+            )
+            return ContextChatEngine.from_defaults(
+                system_prompt=system_prompt,
+                retriever=vector_index_retriever,
+                service_context=self.service_context,
+                node_postprocessors=[
+                    MetadataReplacementPostProcessor(target_metadata_key="window"),
+                ],
+            )
+        else:
+            return SimpleChatEngine.from_defaults(
+                system_prompt=system_prompt,
+                service_context=self.service_context,
+            )
 
     def stream_chat(
         self,
@@ -77,24 +122,34 @@ class ChatService:
         use_context: bool = False,
         context_filter: ContextFilter | None = None,
     ) -> CompletionGen:
-        if use_context:
-            last_message = messages[-1].content
-            chat_engine = self._chat_engine(context_filter=context_filter)
-            streaming_response = chat_engine.stream_chat(
-                message=last_message if last_message is not None else "",
-                chat_history=messages[:-1],
-            )
-            sources = [
-                Chunk.from_node(node) for node in streaming_response.source_nodes
-            ]
-            completion_gen = CompletionGen(
-                response=streaming_response.response_gen, sources=sources
-            )
-        else:
-            stream = self.llm_service.llm.stream_chat(messages)
-            completion_gen = CompletionGen(
-                response=stream_chat_response_to_tokens(stream)
-            )
+        chat_engine_input = ChatEngineInput.from_messages(messages)
+        last_message = (
+            chat_engine_input.last_message.content
+            if chat_engine_input.last_message
+            else None
+        )
+        system_prompt = (
+            chat_engine_input.system_message.content
+            if chat_engine_input.system_message
+            else None
+        )
+        chat_history = (
+            chat_engine_input.chat_history if chat_engine_input.chat_history else None
+        )
+
+        chat_engine = self._chat_engine(
+            system_prompt=system_prompt,
+            use_context=use_context,
+            context_filter=context_filter,
+        )
+        streaming_response = chat_engine.stream_chat(
+            message=last_message if last_message is not None else "",
+            chat_history=chat_history,
+        )
+        sources = [Chunk.from_node(node) for node in streaming_response.source_nodes]
+        completion_gen = CompletionGen(
+            response=streaming_response.response_gen, sources=sources
+        )
         return completion_gen
 
     def chat(
@@ -103,18 +158,30 @@ class ChatService:
         use_context: bool = False,
         context_filter: ContextFilter | None = None,
     ) -> Completion:
-        if use_context:
-            last_message = messages[-1].content
-            chat_engine = self._chat_engine(context_filter=context_filter)
-            wrapped_response = chat_engine.chat(
-                message=last_message if last_message is not None else "",
-                chat_history=messages[:-1],
-            )
-            sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
-            completion = Completion(response=wrapped_response.response, sources=sources)
-        else:
-            chat_response = self.llm_service.llm.chat(messages)
-            response_content = chat_response.message.content
-            response = response_content if response_content is not None else ""
-            completion = Completion(response=response)
+        chat_engine_input = ChatEngineInput.from_messages(messages)
+        last_message = (
+            chat_engine_input.last_message.content
+            if chat_engine_input.last_message
+            else None
+        )
+        system_prompt = (
+            chat_engine_input.system_message.content
+            if chat_engine_input.system_message
+            else None
+        )
+        chat_history = (
+            chat_engine_input.chat_history if chat_engine_input.chat_history else None
+        )
+
+        chat_engine = self._chat_engine(
+            system_prompt=system_prompt,
+            use_context=use_context,
+            context_filter=context_filter,
+        )
+        wrapped_response = chat_engine.chat(
+            message=last_message if last_message is not None else "",
+            chat_history=chat_history,
+        )
+        sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
+        completion = Completion(response=wrapped_response.response, sources=sources)
         return completion