From 133c1da13a6054541e1f0b307990db5e9b22f25d Mon Sep 17 00:00:00 2001
From: Ingrid Stevens
Date: Wed, 28 Feb 2024 17:01:46 +0100
Subject: [PATCH] refines get_model_label()

refines get_model_label()

removes reliance on PGPT_PROFILES; instead, uses settings().llm.mode.
Possible options: "local", "openai", "openailike", "sagemaker", "mock",
"ollama".
---
 private_gpt/ui/ui.py | 46 +++++++++++++++++++++++++++++++++-----------
 1 file changed, 35 insertions(+), 11 deletions(-)

diff --git a/private_gpt/ui/ui.py b/private_gpt/ui/ui.py
index 96585016..41b4afcc 100644
--- a/private_gpt/ui/ui.py
+++ b/private_gpt/ui/ui.py
@@ -1,7 +1,6 @@
-"""This file should be imported only and only if you want to run the UI locally."""
+"""This file should be imported if and only if you want to run the UI locally."""
 import itertools
 import logging
-import os
 import time
 from collections.abc import Iterable
 from pathlib import Path
@@ -410,15 +409,40 @@ class PrivateGptUi:
                     inputs=system_prompt_input,
                 )
 
-            def get_model_label() -> str | None:
-                # Determine the model label based on PGPT_PROFILES env variable.
-                pgpt_profiles = os.environ.get("PGPT_PROFILES")
-                if pgpt_profiles == "ollama":
-                    return settings().ollama.model
-                elif pgpt_profiles == "vllm":
-                    return settings().openai.model
-                else:
-                    return None
+            def get_model_label() -> str:
+                """Get model label from llm mode setting YAML.
+
+                Raises:
+                    ValueError: If an invalid 'llm_mode' is encountered.
+
+                Returns:
+                    str: The corresponding model label.
+                """
+                # Get model label from llm mode setting YAML
+                # Labels: local, openai, openailike, sagemaker, mock, ollama
+                config_settings = settings()
+                if config_settings is None:
+                    raise ValueError("Settings are not configured.")
+
+                # Get llm_mode from settings
+                llm_mode = config_settings.llm.mode
+
+                # Mapping of 'llm_mode' to corresponding model labels
+                model_mapping = {
+                    "local": config_settings.local.llm_hf_model_file,
+                    "openai": config_settings.openai.model,
+                    "openailike": config_settings.openai.model,
+                    "sagemaker": config_settings.sagemaker.llm_endpoint_name,
+                    "mock": llm_mode,
+                    "ollama": config_settings.ollama.model,
+                }
+
+                try:
+                    return model_mapping[llm_mode]
+                except KeyError:
+                    raise ValueError(
+                        f"Invalid 'llm mode': {llm_mode}"
+                    ) from None
 
             with gr.Column(scale=7, elem_id="col"):
                 # Determine the model label based on the value of PGPT_PROFILES
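
For reviewers, a minimal, self-contained sketch of the lookup pattern this patch adopts is shown below. The LLMSettings/AppSettings dataclasses and the model_name field are hypothetical stand-ins for private_gpt's real settings objects; only the dict lookup with KeyError-to-ValueError translation mirrors the patched code.

```python
# Sketch of the mode -> model-label lookup introduced by this patch.
# NOTE: LLMSettings / AppSettings and their fields are hypothetical
# stand-ins; the real values come from private_gpt's settings YAML.
from dataclasses import dataclass


@dataclass
class LLMSettings:
    mode: str  # "local", "openai", "openailike", "sagemaker", "mock", "ollama"


@dataclass
class AppSettings:
    llm: LLMSettings
    model_name: str = "example-model"  # placeholder for the per-mode fields


def get_model_label(settings: AppSettings) -> str:
    """Resolve the configured llm mode to a display label."""
    llm_mode = settings.llm.mode
    # One entry per supported mode; in the real code each entry reads a
    # different settings field (e.g. ollama.model, sagemaker.llm_endpoint_name).
    model_mapping = {
        "local": settings.model_name,
        "openai": settings.model_name,
        "openailike": settings.model_name,
        "sagemaker": settings.model_name,
        "mock": llm_mode,
        "ollama": settings.model_name,
    }
    try:
        return model_mapping[llm_mode]
    except KeyError:
        # Translate the bare KeyError into a descriptive configuration error.
        raise ValueError(f"Invalid 'llm mode': {llm_mode}") from None


if __name__ == "__main__":
    print(get_model_label(AppSettings(llm=LLMSettings(mode="ollama"))))
    # -> "example-model"
```

The dict keeps the list of supported modes in one place, and re-raising the KeyError as a ValueError (with `from None` to suppress the chained traceback) surfaces a readable error when the YAML contains an unknown mode.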