diff --git a/private_gpt/ui/ui.py b/private_gpt/ui/ui.py
index abdfb0c6..52276f12 100644
--- a/private_gpt/ui/ui.py
+++ b/private_gpt/ui/ui.py
@@ -52,6 +52,14 @@ MODES: list[Modes] = [
 ]
 
 
+class Styles(str, Enum):
+    STREAMING = "Streaming"
+    NON_STREAMING = "Non-Streaming"
+
+
+STYLES: list[Styles] = [Styles.STREAMING, Styles.NON_STREAMING]
+
+
 class Source(BaseModel):
     file: str
     page: str
@@ -106,6 +114,9 @@ class PrivateGptUi:
         )
         self._system_prompt = self._get_default_system_prompt(self._default_mode)
 
+        # Initialize default response style: Streaming
+        self.response_style = STYLES[0]
+
     def _chat(
         self, message: str, history: list[list[str]], mode: Modes, *_: Any
     ) -> Any:
@@ -186,18 +197,30 @@ class PrivateGptUi:
                             docs_ids.append(ingested_document.doc_id)
                     context_filter = ContextFilter(docs_ids=docs_ids)
 
-                query_stream = self._chat_service.stream_chat(
-                    messages=all_messages,
-                    use_context=True,
-                    context_filter=context_filter,
-                )
-                yield from yield_deltas(query_stream)
+                match self.response_style:
+                    case Styles.STREAMING:
+                        query_stream = self._chat_service.stream_chat(
+                            all_messages, use_context=True, context_filter=context_filter
+                        )
+                        yield from yield_deltas(query_stream)
+                    case Styles.NON_STREAMING:
+                        query_response = self._chat_service.chat(
+                            all_messages, use_context=True, context_filter=context_filter
+                        ).response
+                        yield from [query_response]
+
             case Modes.BASIC_CHAT_MODE:
-                llm_stream = self._chat_service.stream_chat(
-                    messages=all_messages,
-                    use_context=False,
-                )
-                yield from yield_deltas(llm_stream)
+                match self.response_style:
+                    case Styles.STREAMING:
+                        llm_stream = self._chat_service.stream_chat(
+                            all_messages, use_context=False
+                        )
+                        yield from yield_deltas(llm_stream)
+                    case Styles.NON_STREAMING:
+                        llm_response = self._chat_service.chat(
+                            all_messages, use_context=False
+                        ).response
+                        yield from [llm_response]
 
             case Modes.SEARCH_MODE:
                 response = self._chunks_service.retrieve_relevant(
@@ -225,12 +248,21 @@ class PrivateGptUi:
                             docs_ids.append(ingested_document.doc_id)
                     context_filter = ContextFilter(docs_ids=docs_ids)
 
-                summary_stream = self._summarize_service.stream_summarize(
-                    use_context=True,
-                    context_filter=context_filter,
-                    instructions=message,
-                )
-                yield from yield_tokens(summary_stream)
+                match self.response_style:
+                    case Styles.STREAMING:
+                        summary_stream = self._summarize_service.stream_summarize(
+                            use_context=True,
+                            context_filter=context_filter,
+                            instructions=message,
+                        )
+                        yield from yield_tokens(summary_stream)
+                    case Styles.NON_STREAMING:
+                        summary_response = self._summarize_service.summarize(
+                            use_context=True,
+                            context_filter=context_filter,
+                            instructions=message,
+                        )
+                        yield from [summary_response]
 
     # On initialization and on mode change, this function set the system prompt
     # to the default prompt based on the mode (and user settings).
@@ -283,6 +315,9 @@ class PrivateGptUi:
             gr.update(value=self._explanation_mode),
         ]
 
+    def _set_current_response_style(self, response_style: Styles) -> Any:
+        self.response_style = response_style
+
     def _list_ingested_files(self) -> list[list[str]]:
         files = set()
         for ingested_document in self._ingest_service.list_ingested():
@@ -406,6 +441,13 @@ class PrivateGptUi:
                         max_lines=3,
                         interactive=False,
                     )
+                    default_response_style = STYLES[0]
+                    response_style = gr.Dropdown(
+                        [response_style.value for response_style in STYLES],
+                        label="Response Style",
+                        value=default_response_style,
+                        interactive=True,
+                    )
                     upload_button = gr.components.UploadButton(
                         "Upload File(s)",
                         type="filepath",
@@ -499,6 +541,10 @@ class PrivateGptUi:
                         self._set_system_prompt,
                         inputs=system_prompt_input,
                     )
+                    # When response style changes
+                    response_style.change(
+                        self._set_current_response_style, inputs=response_style
+                    )
 
                     def get_model_label() -> str | None:
                         """Get model label from llm mode setting YAML.