commit 168d0ce4c0
Author: J
Date: 2024-12-03 09:16:25 +03:00 (committed by GitHub)

@@ -52,6 +52,14 @@ MODES: list[Modes] = [
]
class Styles(str, Enum):
    STREAMING = "Streaming"
    NON_STREAMING = "Non-Streaming"

STYLES: list[Styles] = [Styles.STREAMING, Styles.NON_STREAMING]

class Source(BaseModel):
    file: str
    page: str
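Because Styles subclasses both str and Enum, each member compares equal to its plain string value; this is what lets the raw string coming back from the UI dropdown drive the match statements further down. A minimal standalone sketch of that behavior (the describe helper is illustrative, not part of the codebase):

from enum import Enum

class Styles(str, Enum):
    STREAMING = "Streaming"
    NON_STREAMING = "Non-Streaming"

# A str-based Enum member is equal to its underlying string...
assert Styles.STREAMING == "Streaming"

# ...so a plain string still matches the enum value patterns.
def describe(style: str) -> str:
    match style:
        case Styles.STREAMING:
            return "deltas are yielded as they arrive"
        case Styles.NON_STREAMING:
            return "the full response is yielded once"
        case _:
            return "unknown style"

assert describe("Non-Streaming") == "the full response is yielded once"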
@@ -106,6 +114,9 @@ class PrivateGptUi:
)
self._system_prompt = self._get_default_system_prompt(self._default_mode)
# Initialize default response style: Streaming
self.response_style = STYLES[0]
def _chat(
    self, message: str, history: list[list[str]], mode: Modes, *_: Any
) -> Any:
@@ -186,18 +197,30 @@ class PrivateGptUi:
docs_ids.append(ingested_document.doc_id)
context_filter = ContextFilter(docs_ids=docs_ids)
query_stream = self._chat_service.stream_chat(
    messages=all_messages,
    use_context=True,
    context_filter=context_filter,
)
yield from yield_deltas(query_stream)
match self.response_style:
    case Styles.STREAMING:
        query_stream = self._chat_service.stream_chat(
            messages=all_messages,
            use_context=True,
            context_filter=context_filter,
        )
        yield from yield_deltas(query_stream)
    case Styles.NON_STREAMING:
        query_response = self._chat_service.chat(
            messages=all_messages,
            use_context=True,
            context_filter=context_filter,
        ).response
        yield from [query_response]
case Modes.BASIC_CHAT_MODE:
llm_stream = self._chat_service.stream_chat(
    messages=all_messages,
    use_context=False,
)
yield from yield_deltas(llm_stream)
match self.response_style:
    case Styles.STREAMING:
        llm_stream = self._chat_service.stream_chat(
            messages=all_messages,
            use_context=False,
        )
        yield from yield_deltas(llm_stream)
    case Styles.NON_STREAMING:
        llm_response = self._chat_service.chat(
            messages=all_messages,
            use_context=False,
        ).response
        yield from [llm_response]
case Modes.SEARCH_MODE:
response = self._chunks_service.retrieve_relevant(
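Both chat branches follow the same split: the streaming case forwards deltas from a generator as they arrive, while the non-streaming case blocks for the complete response and yields it once. A minimal sketch of the pattern, with hypothetical stand-ins for the chat service and yield_deltas:

from collections.abc import Iterator

def fake_stream_chat() -> Iterator[str]:
    # Hypothetical stand-in for a streaming chat call.
    yield from ["Hel", "lo", "!"]

def yield_deltas(stream: Iterator[str]) -> Iterator[str]:
    # Hypothetical stand-in: pass each delta through as it arrives.
    for delta in stream:
        yield delta

def respond(streaming: bool) -> Iterator[str]:
    if streaming:
        # Streaming: emit each delta as soon as it is produced.
        yield from yield_deltas(fake_stream_chat())
    else:
        # Non-streaming: assemble the whole response, then emit it once.
        yield from ["".join(fake_stream_chat())]

assert list(respond(streaming=True)) == ["Hel", "lo", "!"]
assert list(respond(streaming=False)) == ["Hello!"]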
@@ -225,12 +248,21 @@ class PrivateGptUi:
docs_ids.append(ingested_document.doc_id)
context_filter = ContextFilter(docs_ids=docs_ids)
summary_stream = self._summarize_service.stream_summarize(
    use_context=True,
    context_filter=context_filter,
    instructions=message,
)
yield from yield_tokens(summary_stream)
match self.response_style:
    case Styles.STREAMING:
        summary_stream = self._summarize_service.stream_summarize(
            use_context=True,
            context_filter=context_filter,
            instructions=message,
        )
        yield from yield_tokens(summary_stream)
    case Styles.NON_STREAMING:
        summary_response = self._summarize_service.summarize(
            use_context=True,
            context_filter=context_filter,
            instructions=message,
        )
        yield from [summary_response]
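The single-element list in the non-streaming branches is deliberate: yield from over a bare string would emit it one character at a time, while wrapping it in a list emits it in one piece. A quick illustration:

def per_character(s: str):
    yield from s  # iterates the string itself

def as_a_whole(s: str):
    yield from [s]  # iterates a one-element list

assert list(per_character("Hi")) == ["H", "i"]
assert list(as_a_whole("Hi")) == ["Hi"]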
# On initialization and on mode change, this function sets the system prompt
# to the default prompt based on the mode (and user settings).
@@ -283,6 +315,9 @@ class PrivateGptUi:
gr.update(value=self._explanation_mode),
]
def _set_current_response_style(self, response_style: Styles) -> Any:
    self.response_style = response_style

def _list_ingested_files(self) -> list[list[str]]:
    files = set()
    for ingested_document in self._ingest_service.list_ingested():
@@ -406,6 +441,15 @@ class PrivateGptUi:
    max_lines=3,
    interactive=False,
)
default_response_style = STYLES[0]
response_style = gr.Dropdown(
    [style.value for style in STYLES],
    label="Response Style",
    value=default_response_style,
    interactive=True,
)
upload_button = gr.components.UploadButton(
    "Upload File(s)",
    type="filepath",
@@ -499,6 +543,10 @@ class PrivateGptUi:
    self._set_system_prompt,
    inputs=system_prompt_input,
)

# When the response style changes
response_style.change(
    self._set_current_response_style,
    inputs=response_style,
)
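Outside the class, the wiring reduces to a standard Gradio pattern: a Dropdown whose current value is passed to the .change callback. A minimal self-contained sketch, assuming gradio is installed (the names here are illustrative):

import gradio as gr

STYLES = ["Streaming", "Non-Streaming"]

def set_style(style: str) -> None:
    # The UI class stores this on self.response_style; here we just print it.
    print(f"Response style set to: {style}")

with gr.Blocks() as demo:
    style_dropdown = gr.Dropdown(
        STYLES,
        label="Response Style",
        value=STYLES[0],
        interactive=True,
    )
    # .change fires with the dropdown's current value whenever it changes.
    style_dropdown.change(set_style, inputs=style_dropdown)

# demo.launch()  # uncomment to serve the demo locally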
def get_model_label() -> str | None:
    """Get model label from llm mode setting YAML.