Mirror of https://github.com/csunny/DB-GPT.git
refactor: Refactor proxy LLM (#1064)
@@ -152,6 +152,17 @@ class LLMModelAdapter(ABC):
         except Exception:
             return "\n"
 
+    def get_prompt_roles(self) -> List[str]:
+        """Get the roles of the prompt
+
+        Returns:
+            List[str]: The roles of the prompt
+        """
+        roles = [ModelMessageRoleType.HUMAN, ModelMessageRoleType.AI]
+        if self.support_system_message:
+            roles.append(ModelMessageRoleType.SYSTEM)
+        return roles
+
     def transform_model_messages(
         self, messages: List[ModelMessage], convert_to_compatible_format: bool = False
     ) -> List[Dict[str, str]]:
@@ -185,7 +196,7 @@ class LLMModelAdapter(ABC):
             # We will not do any transform in the future
             return self._transform_to_no_system_messages(messages)
         else:
-            return ModelMessage.to_openai_messages(
+            return ModelMessage.to_common_messages(
                 messages, convert_to_compatible_format=convert_to_compatible_format
             )
 
@@ -216,7 +227,7 @@ class LLMModelAdapter(ABC):
         Returns:
             List[Dict[str, str]]: The transformed model messages
         """
-        openai_messages = ModelMessage.to_openai_messages(messages)
+        openai_messages = ModelMessage.to_common_messages(messages)
         system_messages = []
         return_messages = []
         for message in openai_messages:
@@ -394,6 +405,9 @@ class LLMModelAdapter(ABC):
         conv.set_system_message("".join(can_use_systems))
         return conv
 
+    def apply_conv_template(self) -> bool:
+        return self.model_type() != ModelType.PROXY
+
     def model_adaptation(
         self,
         params: Dict,
@@ -414,7 +428,11 @@ class LLMModelAdapter(ABC):
         params["convert_to_compatible_format"] = convert_to_compatible_format
 
         # Some model context to dbgpt server
-        model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
+        model_context = {
+            "prompt_echo_len_char": -1,
+            "has_format_prompt": False,
+            "echo": params.get("echo", True),
+        }
         if messages:
             # Dict message to ModelMessage
             messages = [
@@ -422,6 +440,11 @@ class LLMModelAdapter(ABC):
                 for m in messages
             ]
             params["messages"] = messages
+            params["string_prompt"] = ModelMessage.messages_to_string(messages)
+
+        if not self.apply_conv_template():
+            # No need to apply conversation template, now for proxy LLM
+            return params, model_context
 
         new_prompt = self.get_str_prompt(
             params, messages, tokenizer, prompt_template, convert_to_compatible_format
@@ -442,7 +465,6 @@ class LLMModelAdapter(ABC):
         # TODO remote bos token and eos token from tokenizer_config.json of model
         prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
         model_context["prompt_echo_len_char"] = prompt_echo_len_char
-        model_context["echo"] = params.get("echo", True)
         model_context["has_format_prompt"] = True
         params["prompt"] = new_prompt
 
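The hunks above all touch LLMModelAdapter, the base adapter class that the new proxy_adapter.py below imports from dbgpt.model.adapter.base; the next two hunks move to the adapter lookup module that owns _OLD_MODELS and get_llm_model_adapter. The key behavioural change here is the early return in model_adaptation: apply_conv_template() is False for ModelType.PROXY, so proxy models now keep the structured message list plus a string_prompt and skip local prompt templating entirely. The snippet below is a minimal sketch of that control flow, not code from this commit; the ModelType enum, the message handling, and the placeholder prompt rendering are simplified stand-ins.

# Minimal sketch of the new routing in model_adaptation (illustrative only).
from enum import Enum
from typing import Dict, List, Tuple


class ModelType(str, Enum):
    HF = "huggingface"
    PROXY = "proxy"


class AdapterSketch:
    def model_type(self) -> str:
        # A proxy adapter reports PROXY; local adapters report other types.
        return ModelType.PROXY

    def apply_conv_template(self) -> bool:
        # Mirrors the new hook: only non-proxy models get a conversation template.
        return self.model_type() != ModelType.PROXY

    def model_adaptation(
        self, params: Dict, messages: List[Dict[str, str]]
    ) -> Tuple[Dict, Dict]:
        model_context = {
            "prompt_echo_len_char": -1,
            "has_format_prompt": False,
            "echo": params.get("echo", True),
        }
        params["messages"] = messages
        # Stand-in for ModelMessage.messages_to_string(messages).
        params["string_prompt"] = "\n".join(m["content"] for m in messages)
        if not self.apply_conv_template():
            # Proxy path: hand the role-based messages straight to the LLM client.
            return params, model_context
        # Local path (unchanged): render one prompt string from a template.
        params["prompt"] = "<templated prompt for a local model>"
        model_context["has_format_prompt"] = True
        return params, model_context

Because the echo flag is now recorded when model_context is first built, the proxy path carries the same context fields as the templated path, which is why the later model_context["echo"] assignment could be dropped.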
@@ -19,7 +19,7 @@ _IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
 
 _OLD_MODELS = [
     "llama-cpp",
-    "proxyllm",
+    # "proxyllm",
     "gptj-6b",
     "codellama-13b-sql-sft",
     "codellama-7b",
@@ -45,6 +45,7 @@ def get_llm_model_adapter(
 
     # Import NewHFChatModelAdapter for it can be registered
     from dbgpt.model.adapter.hf_adapter import NewHFChatModelAdapter
+    from dbgpt.model.adapter.proxy_adapter import ProxyLLMModelAdapter
 
     new_model_adapter = get_model_adapter(
         model_type, model_name, model_path, conv_factory
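These hunks re-enable proxy models in the new adapter path: "proxyllm" is commented out of the _OLD_MODELS fallback list, and dbgpt.model.adapter.proxy_adapter is imported purely for its side effect of running the register_model_adapter(...) calls at the bottom of that file. The selection machinery itself (get_model_adapter) is not part of this diff; the following is only a hedged sketch, with made-up names, of how a register-then-match scheme like this one typically resolves a model name to an adapter.

# Hypothetical registry sketch; names are illustrative, not DB-GPT's API.
from typing import List, Optional, Type


class AdapterBase:
    def match(
        self,
        model_type: str,
        model_name: Optional[str] = None,
        model_path: Optional[str] = None,
    ) -> bool:
        raise NotImplementedError()


_ADAPTERS: List[AdapterBase] = []


def register(adapter_cls: Type[AdapterBase]) -> None:
    # Importing a module that calls register(...) at import time is enough to
    # make its adapters discoverable, hence the new import in the diff above.
    _ADAPTERS.append(adapter_cls())


def find_adapter(model_type: str, model_name: str, model_path: str) -> AdapterBase:
    # The first registered adapter whose match() accepts the name or path wins.
    for adapter in _ADAPTERS:
        if adapter.match(model_type, model_name, model_path):
            return adapter
    raise ValueError(f"no adapter registered for {model_name!r}")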
dbgpt/model/adapter/proxy_adapter.py (new file, 238 lines)
@@ -0,0 +1,238 @@
import dataclasses
import logging
from abc import abstractmethod
from typing import TYPE_CHECKING, List, Optional, Type, Union

from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter
from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory
from dbgpt.model.base import ModelType
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

logger = logging.getLogger(__name__)


class ProxyLLMModelAdapter(LLMModelAdapter):
    def new_adapter(self, **kwargs) -> "LLMModelAdapter":
        return self.__class__()

    def model_type(self) -> str:
        return ModelType.PROXY

    def match(
        self,
        model_type: str,
        model_name: Optional[str] = None,
        model_path: Optional[str] = None,
    ) -> bool:
        model_name = model_name.lower() if model_name else None
        model_path = model_path.lower() if model_path else None
        return self.do_match(model_name) or self.do_match(model_path)

    @abstractmethod
    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        raise NotImplementedError()

    def dynamic_llm_client_class(
        self, params: ProxyModelParameters
    ) -> Optional[Type[ProxyLLMClient]]:
        """Get dynamic llm client class

        Parse the llm_client_class from params and return the class

        Args:
            params (ProxyModelParameters): proxy model parameters

        Returns:
            Optional[Type[ProxyLLMClient]]: llm client class
        """

        if params.llm_client_class:
            from dbgpt.util.module_utils import import_from_checked_string

            worker_cls: Type[ProxyLLMClient] = import_from_checked_string(
                params.llm_client_class, ProxyLLMClient
            )
            return worker_cls
        return None

    def get_llm_client_class(
        self, params: ProxyModelParameters
    ) -> Type[ProxyLLMClient]:
        """Get llm client class"""
        dynamic_llm_client_class = self.dynamic_llm_client_class(params)
        if dynamic_llm_client_class:
            return dynamic_llm_client_class
        raise NotImplementedError()

    def load_from_params(self, params: ProxyModelParameters):
        dynamic_llm_client_class = self.dynamic_llm_client_class(params)
        if not dynamic_llm_client_class:
            dynamic_llm_client_class = self.get_llm_client_class(params)
        logger.info(
            f"Load model from params: {params}, llm client class: {dynamic_llm_client_class}"
        )
        proxy_llm_client = dynamic_llm_client_class.new_client(params)
        model = ProxyModel(params, proxy_llm_client)
        return model, model


class OpenAIProxyLLMModelAdapter(ProxyLLMModelAdapter):
    def support_async(self) -> bool:
        return True

    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return lower_model_name_or_path in ["chatgpt_proxyllm", "proxyllm"]

    def get_llm_client_class(
        self, params: ProxyModelParameters
    ) -> Type[ProxyLLMClient]:
        """Get llm client class"""
        from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient

        return OpenAILLMClient

    def get_async_generate_stream_function(self, model, model_path: str):
        from dbgpt.model.proxy.llms.chatgpt import chatgpt_generate_stream

        return chatgpt_generate_stream


class TongyiProxyLLMModelAdapter(ProxyLLMModelAdapter):
    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return lower_model_name_or_path == "tongyi_proxyllm"

    def get_llm_client_class(
        self, params: ProxyModelParameters
    ) -> Type[ProxyLLMClient]:
        from dbgpt.model.proxy.llms.tongyi import TongyiLLMClient

        return TongyiLLMClient

    def get_generate_stream_function(self, model, model_path: str):
        from dbgpt.model.proxy.llms.tongyi import tongyi_generate_stream

        return tongyi_generate_stream


class ZhipuProxyLLMModelAdapter(ProxyLLMModelAdapter):
    support_system_message = False

    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return lower_model_name_or_path == "zhipu_proxyllm"

    def get_llm_client_class(
        self, params: ProxyModelParameters
    ) -> Type[ProxyLLMClient]:
        from dbgpt.model.proxy.llms.zhipu import ZhipuLLMClient

        return ZhipuLLMClient

    def get_generate_stream_function(self, model, model_path: str):
        from dbgpt.model.proxy.llms.zhipu import zhipu_generate_stream

        return zhipu_generate_stream


class WenxinProxyLLMModelAdapter(ProxyLLMModelAdapter):
    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return lower_model_name_or_path == "wenxin_proxyllm"

    def get_llm_client_class(
        self, params: ProxyModelParameters
    ) -> Type[ProxyLLMClient]:
        from dbgpt.model.proxy.llms.wenxin import WenxinLLMClient

        return WenxinLLMClient

    def get_generate_stream_function(self, model, model_path: str):
        from dbgpt.model.proxy.llms.wenxin import wenxin_generate_stream

        return wenxin_generate_stream


class GeminiProxyLLMModelAdapter(ProxyLLMModelAdapter):
    support_system_message = False

    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return lower_model_name_or_path == "gemini_proxyllm"

    def get_llm_client_class(
        self, params: ProxyModelParameters
    ) -> Type[ProxyLLMClient]:
        from dbgpt.model.proxy.llms.gemini import GeminiLLMClient

        return GeminiLLMClient

    def get_generate_stream_function(self, model, model_path: str):
        from dbgpt.model.proxy.llms.gemini import gemini_generate_stream

        return gemini_generate_stream


class SparkProxyLLMModelAdapter(ProxyLLMModelAdapter):
    support_system_message = False

    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return lower_model_name_or_path == "spark_proxyllm"

    def get_llm_client_class(
        self, params: ProxyModelParameters
    ) -> Type[ProxyLLMClient]:
        from dbgpt.model.proxy.llms.spark import SparkLLMClient

        return SparkLLMClient

    def get_generate_stream_function(self, model, model_path: str):
        from dbgpt.model.proxy.llms.spark import spark_generate_stream

        return spark_generate_stream


class BardProxyLLMModelAdapter(ProxyLLMModelAdapter):
    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return lower_model_name_or_path == "bard_proxyllm"

    def get_llm_client_class(
        self, params: ProxyModelParameters
    ) -> Type[ProxyLLMClient]:
        """Get llm client class"""
        # TODO: Bard proxy LLM not support ProxyLLMClient now, we just return OpenAILLMClient
        from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient

        return OpenAILLMClient

    def get_async_generate_stream_function(self, model, model_path: str):
        from dbgpt.model.proxy.llms.bard import bard_generate_stream

        return bard_generate_stream


class BaichuanProxyLLMModelAdapter(ProxyLLMModelAdapter):
    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        return lower_model_name_or_path == "bc_proxyllm"

    def get_llm_client_class(
        self, params: ProxyModelParameters
    ) -> Type[ProxyLLMClient]:
        """Get llm client class"""
        # TODO: Baichuan proxy LLM not support ProxyLLMClient now, we just return OpenAILLMClient
        from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient

        return OpenAILLMClient

    def get_async_generate_stream_function(self, model, model_path: str):
        from dbgpt.model.proxy.llms.baichuan import baichuan_generate_stream

        return baichuan_generate_stream


register_model_adapter(OpenAIProxyLLMModelAdapter)
register_model_adapter(TongyiProxyLLMModelAdapter)
register_model_adapter(ZhipuProxyLLMModelAdapter)
register_model_adapter(WenxinProxyLLMModelAdapter)
register_model_adapter(GeminiProxyLLMModelAdapter)
register_model_adapter(SparkProxyLLMModelAdapter)
register_model_adapter(BardProxyLLMModelAdapter)
register_model_adapter(BaichuanProxyLLMModelAdapter)
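With the shared behaviour (name matching, model_type() == ModelType.PROXY, client-class resolution, load_from_params) living in ProxyLLMModelAdapter, each provider subclass mainly declares which model names it answers for, which ProxyLLMClient it uses, and which stream function drives generation, then registers itself. Below is a sketch of what adding another provider might look like; my_pkg.MyLLMClient and the name "my_proxyllm" are placeholders, not part of DB-GPT, and the stream-function wiring is omitted.

from typing import Optional, Type

from dbgpt.model.adapter.base import register_model_adapter
from dbgpt.model.adapter.proxy_adapter import ProxyLLMModelAdapter
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient


class MyProxyLLMModelAdapter(ProxyLLMModelAdapter):
    def do_match(self, lower_model_name_or_path: Optional[str] = None):
        # Answer only for the (hypothetical) model name "my_proxyllm".
        return lower_model_name_or_path == "my_proxyllm"

    def get_llm_client_class(
        self, params: ProxyModelParameters
    ) -> Type[ProxyLLMClient]:
        # Placeholder client; it would need to implement ProxyLLMClient.
        from my_pkg import MyLLMClient

        return MyLLMClient


register_model_adapter(MyProxyLLMModelAdapter)

Alternatively, dynamic_llm_client_class offers a no-subclass route: setting llm_client_class on ProxyModelParameters to a dotted class path lets import_from_checked_string load a custom ProxyLLMClient at runtime, with get_llm_client_class serving only as the fallback.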