diff --git a/private_gpt/ui/ui.py b/private_gpt/ui/ui.py
index 56136168..d538093d 100644
--- a/private_gpt/ui/ui.py
+++ b/private_gpt/ui/ui.py
@@ -51,6 +51,17 @@ MODES: list[Modes] = [
 ]
 
 
+class Styles(str, Enum):
+    STREAMING = "Streaming"
+    NON_STREAMING = "Non-Streaming"
+
+
+STYLES: list[Styles] = [
+    Styles.STREAMING,
+    Styles.NON_STREAMING
+]
+
+
 class Source(BaseModel):
     file: str
     page: str
@@ -98,12 +109,13 @@ class PrivateGptUi:
 
         self._selected_filename = None
 
-        self._response_style = True
-
         # Initialize system prompt based on default mode
         self.mode = MODES[0]
         self._system_prompt = self._get_default_system_prompt(self.mode)
 
+        # Initialize default response style: Streaming
+        self.response_style = STYLES[0]
+
     def _chat(
         self, message: str, history: list[list[str]], mode: Modes, *_: Any
     ) -> Any:
@@ -184,28 +196,30 @@ class PrivateGptUi:
                             docs_ids.append(ingested_document.doc_id)
                     context_filter = ContextFilter(docs_ids=docs_ids)
 
-                if self._response_style:
-                    query_stream = self._chat_service.stream_chat(
-                        all_messages, use_context=False
-                    )
-                    yield from yield_deltas(query_stream)
-                else:
-                    query_response = self._chat_service.chat(
-                        all_messages, use_context=False
-                    ).response
-                    yield from [query_response]
+                match self.response_style:
+                    case Styles.STREAMING:
+                        query_stream = self._chat_service.stream_chat(
+                            all_messages, use_context=False
+                        )
+                        yield from yield_deltas(query_stream)
+                    case Styles.NON_STREAMING:
+                        query_response = self._chat_service.chat(
+                            all_messages, use_context=False
+                        ).response
+                        yield from [query_response]
 
             case Modes.BASIC_CHAT_MODE:
-                if self._response_style:
-                    llm_stream = self._chat_service.stream_chat(
-                        all_messages, use_context=False
-                    )
-                    yield from yield_deltas(llm_stream)
-                else:
-                    llm_response = self._chat_service.chat(
-                        all_messages, use_context=False
-                    ).response
-                    yield from [llm_response]
+                match self.response_style:
+                    case Styles.STREAMING:
+                        llm_stream = self._chat_service.stream_chat(
+                            all_messages, use_context=False
+                        )
+                        yield from yield_deltas(llm_stream)
+                    case Styles.NON_STREAMING:
+                        llm_response = self._chat_service.chat(
+                            all_messages, use_context=False
+                        ).response
+                        yield from [llm_response]
 
             case Modes.SEARCH_MODE:
                 response = self._chunks_service.retrieve_relevant(
@@ -233,20 +247,21 @@ class PrivateGptUi:
                             docs_ids.append(ingested_document.doc_id)
                     context_filter = ContextFilter(docs_ids=docs_ids)
 
-                if self._response_style:
-                    summary_stream = self._summarize_service.stream_summarize(
-                        use_context=True,
-                        context_filter=context_filter,
-                        instructions=message,
-                    )
-                    yield from yield_tokens(summary_stream)
-                else:
-                    summary_response = self._summarize_service.summarize(
-                        use_context=True,
-                        context_filter=context_filter,
-                        instructions=message,
-                    )
-                    yield from summary_response
+                match self.response_style:
+                    case Styles.STREAMING:
+                        summary_stream = self._summarize_service.stream_summarize(
+                            use_context=True,
+                            context_filter=context_filter,
+                            instructions=message,
+                        )
+                        yield from yield_tokens(summary_stream)
+                    case Styles.NON_STREAMING:
+                        summary_response = self._summarize_service.summarize(
+                            use_context=True,
+                            context_filter=context_filter,
+                            instructions=message,
+                        )
+                        yield from summary_response
 
     # On initialization and on mode change, this function set the system prompt
     # to the default prompt based on the mode (and user settings).
@@ -299,8 +314,8 @@ class PrivateGptUi:
             gr.update(value=self._explanation_mode),
         ]
 
-    def _set_response_style(self, response_style: bool) -> None:
-        self._response_style = response_style
+    def _set_current_response_style(self, response_style: Styles) -> Any:
+        self.response_style = response_style
 
     def _list_ingested_files(self) -> list[list[str]]:
         files = set()
@@ -425,11 +440,14 @@ class PrivateGptUi:
                         max_lines=3,
                         interactive=False,
                     )
-                    response_style = gr.Checkbox(
-                        label="Response Style: Streaming", value=self._response_style
-                    )
-                    response_style.input(
-                        self._set_response_style, inputs=response_style
+                    default_response_style = STYLES[0]
+                    response_style = (
+                        gr.Dropdown(
+                            [response_style.value for response_style in STYLES],
+                            label="Response Style",
+                            value=default_response_style,
+                            interactive=True,
+                        ),
                     )
                     upload_button = gr.components.UploadButton(
                         "Upload File(s)",
@@ -524,6 +542,10 @@ class PrivateGptUi:
                         self._set_system_prompt,
                         inputs=system_prompt_input,
                     )
+                    # When response style changes
+                    response_style[0].change(
+                        self._set_current_response_style, inputs=response_style
+                    )
 
                     def get_model_label() -> str | None:
                         """Get model label from llm mode setting YAML.
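Reviewer note (not part of the diff): a minimal, self-contained sketch of why the new `match self.response_style` dispatch in `_chat` works regardless of whether the attribute holds a `Styles` member or the raw string the Gradio dropdown emits. The `route` helper below is hypothetical and only mirrors the branching; the key point is that `Styles` subclasses `str`, so `"Streaming"` compares equal to `Styles.STREAMING`.

```python
from enum import Enum


class Styles(str, Enum):
    STREAMING = "Streaming"
    NON_STREAMING = "Non-Streaming"


def route(response_style: Styles | str) -> str:
    # Hypothetical stand-in for the dispatch added in _chat(): returns the
    # name of the service call that branch would make instead of calling it.
    match response_style:
        case Styles.STREAMING:
            return "stream_chat"
        case Styles.NON_STREAMING:
            return "chat"
        case _:
            raise ValueError(f"unknown response style: {response_style!r}")


# Enum members and raw dropdown strings take the same branch, because a
# str-backed Enum member compares equal to its string value.
assert route(Styles.STREAMING) == "stream_chat"
assert route("Non-Streaming") == "chat"
```

One related detail from the diff itself: `response_style` is bound to a one-element tuple `(gr.Dropdown(...),)`, which is why the later wiring reads `response_style[0].change(...)` while still passing the whole tuple as `inputs`.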