diff --git a/private_gpt/ui/ui.py b/private_gpt/ui/ui.py index 2c1dcd3e..4ad5d679 100644 --- a/private_gpt/ui/ui.py +++ b/private_gpt/ui/ui.py @@ -98,6 +98,8 @@ class PrivateGptUi: self._selected_filename = None + self._response_style = True + # Initialize system prompt based on default mode self.mode = MODES[0] self._system_prompt = self._get_default_system_prompt(self.mode) @@ -168,6 +170,12 @@ class PrivateGptUi: role=MessageRole.SYSTEM, ), ) + def draw_methods(service_type): + service = getattr(self, f'_{service_type}_service') + return { + True: getattr(service, f'stream_{service_type}'), + False: getattr(service, f'{service_type}') + } match mode: case Modes.RAG_MODE: # Use only the selected file for the query @@ -182,18 +190,20 @@ class PrivateGptUi: docs_ids.append(ingested_document.doc_id) context_filter = ContextFilter(docs_ids=docs_ids) - query_stream = self._chat_service.stream_chat( + methods = draw_methods('chat') + query_stream = methods.get(self._response_style, self._chat_service.stream_chat)( messages=all_messages, use_context=True, - context_filter=context_filter, + context_filter=context_filter ) - yield from yield_deltas(query_stream) + yield from (yield_deltas(query_stream) if self._response_style else [query_stream.response]) case Modes.BASIC_CHAT_MODE: - llm_stream = self._chat_service.stream_chat( + methods = draw_methods('chat') + llm_stream = methods.get(self._response_style, self._chat_service.stream_chat)( messages=all_messages, - use_context=False, + use_context=False ) - yield from yield_deltas(llm_stream) + yield from (yield_deltas(llm_stream) if self._response_style else [llm_stream.response]) case Modes.SEARCH_MODE: response = self._chunks_service.retrieve_relevant( @@ -227,6 +237,15 @@ class PrivateGptUi: instructions=message, ) yield from yield_tokens(summary_stream) + ''' + methods = draw_methods('summarize') + summary_stream = methods.get(self._response_style, self._summarize_service.stream_summarize)( + use_context=True, + context_filter=context_filter, + instructions=message + ) + yield from yield_tokens(summary_stream) if response_style else summary_stream + ''' # On initialization and on mode change, this function set the system prompt # to the default prompt based on the mode (and user settings). @@ -279,6 +298,9 @@ class PrivateGptUi: gr.update(value=self._explanation_mode), ] + def _set_response_style(self, response_style: str) -> None: + self._response_style = response_style + def _list_ingested_files(self) -> list[list[str]]: files = set() for ingested_document in self._ingest_service.list_ingested(): @@ -402,6 +424,14 @@ class PrivateGptUi: max_lines=3, interactive=False, ) + response_style = gr.Checkbox( + label="Response Style: Streaming", + value=self._response_style + ) + response_style.input( + self._set_response_style, + inputs=response_style + ) upload_button = gr.components.UploadButton( "Upload File(s)", type="filepath",