Add checkbox to choose between streaming and non-streaming response output

Jason
2024-08-26 13:25:09 -04:00
parent ea9fbb4513
commit 079761ea3c


@@ -98,6 +98,8 @@ class PrivateGptUi:
         self._selected_filename = None
+        self._response_style = True
+
         # Initialize system prompt based on default mode
         self.mode = MODES[0]
         self._system_prompt = self._get_default_system_prompt(self.mode)
@@ -168,6 +170,12 @@ class PrivateGptUi:
                     role=MessageRole.SYSTEM,
                 ),
             )
+        def draw_methods(service_type):
+            service = getattr(self, f'_{service_type}_service')
+            return {
+                True: getattr(service, f'stream_{service_type}'),
+                False: getattr(service, f'{service_type}')
+            }
         match mode:
             case Modes.RAG_MODE:
                 # Use only the selected file for the query
@@ -182,18 +190,20 @@ class PrivateGptUi:
                             docs_ids.append(ingested_document.doc_id)
                     context_filter = ContextFilter(docs_ids=docs_ids)
 
-                query_stream = self._chat_service.stream_chat(
+                methods = draw_methods('chat')
+                query_stream = methods.get(self._response_style, self._chat_service.stream_chat)(
                     messages=all_messages,
                     use_context=True,
-                    context_filter=context_filter,
+                    context_filter=context_filter
                 )
-                yield from yield_deltas(query_stream)
+                yield from (yield_deltas(query_stream) if self._response_style else [query_stream.response])
 
             case Modes.BASIC_CHAT_MODE:
-                llm_stream = self._chat_service.stream_chat(
+                methods = draw_methods('chat')
+                llm_stream = methods.get(self._response_style, self._chat_service.stream_chat)(
                     messages=all_messages,
-                    use_context=False,
+                    use_context=False
                 )
-                yield from yield_deltas(llm_stream)
+                yield from (yield_deltas(llm_stream) if self._response_style else [llm_stream.response])
             case Modes.SEARCH_MODE:
                 response = self._chunks_service.retrieve_relevant(
@@ -227,6 +237,15 @@ class PrivateGptUi:
                     instructions=message,
                 )
                 yield from yield_tokens(summary_stream)
+                '''
+                methods = draw_methods('summarize')
+                summary_stream = methods.get(self._response_style, self._summarize_service.stream_summarize)(
+                    use_context=True,
+                    context_filter=context_filter,
+                    instructions=message
+                )
+                yield from (yield_tokens(summary_stream) if self._response_style else [summary_stream])
+                '''
 
     # On initialization and on mode change, this function sets the system prompt
     # to the default prompt based on the mode (and user settings).
@@ -279,6 +298,9 @@ class PrivateGptUi:
             gr.update(value=self._explanation_mode),
         ]
+
+    def _set_response_style(self, response_style: bool) -> None:
+        self._response_style = response_style
 
     def _list_ingested_files(self) -> list[list[str]]:
         files = set()
         for ingested_document in self._ingest_service.list_ingested():
@@ -402,6 +424,14 @@ class PrivateGptUi:
                         max_lines=3,
                         interactive=False,
                     )
+                    response_style = gr.Checkbox(
+                        label="Response Style: Streaming",
+                        value=self._response_style
+                    )
+                    response_style.input(
+                        self._set_response_style,
+                        inputs=response_style
+                    )
                     upload_button = gr.components.UploadButton(
                         "Upload File(s)",
                         type="filepath",