refactor: Refactor proxy LLM (#1064)

Fangyin Cheng
2024-01-14 21:01:37 +08:00
committed by GitHub
parent a035433170
commit 22bfd01c4b
95 changed files with 2049 additions and 1294 deletions

View File

@@ -152,6 +152,17 @@ class LLMModelAdapter(ABC):
         except Exception:
             return "\n"
 
+    def get_prompt_roles(self) -> List[str]:
+        """Get the roles of the prompt
+
+        Returns:
+            List[str]: The roles of the prompt
+        """
+        roles = [ModelMessageRoleType.HUMAN, ModelMessageRoleType.AI]
+        if self.support_system_message:
+            roles.append(ModelMessageRoleType.SYSTEM)
+        return roles
+
     def transform_model_messages(
         self, messages: List[ModelMessage], convert_to_compatible_format: bool = False
     ) -> List[Dict[str, str]]:
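
The new `get_prompt_roles` hook lets callers ask an adapter which message roles it accepts before building a request. A minimal sketch of how a caller might use it (the adapter class is from the new proxy adapter file further down; the filtering loop and import paths are illustrative assumptions, not code from this commit):

from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.model.adapter.proxy_adapter import ZhipuProxyLLMModelAdapter

adapter = ZhipuProxyLLMModelAdapter()  # support_system_message = False
supported = set(adapter.get_prompt_roles())  # HUMAN and AI only, no SYSTEM

messages = [
    ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are helpful."),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi!"),
]
# Hypothetical caller-side filter: keep only roles the adapter can accept.
usable = [m for m in messages if m.role in supported]
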
@@ -185,7 +196,7 @@ class LLMModelAdapter(ABC):
             # We will not do any transform in the future
             return self._transform_to_no_system_messages(messages)
         else:
-            return ModelMessage.to_openai_messages(
+            return ModelMessage.to_common_messages(
                 messages, convert_to_compatible_format=convert_to_compatible_format
             )
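
The rename from `to_openai_messages` to `to_common_messages` signals that the role/content dict format is the common wire format shared by most proxy providers, not an OpenAI-only detail; the call sites show a pure rename. A small usage sketch (import path assumed from the surrounding codebase):

from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType

messages = [
    ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are helpful."),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="What is DB-GPT?"),
]
# Same OpenAI-style role/content dicts the old name produced, e.g.:
# [{"role": "system", "content": "You are helpful."},
#  {"role": "user", "content": "What is DB-GPT?"}]
common = ModelMessage.to_common_messages(messages)
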
@@ -216,7 +227,7 @@ class LLMModelAdapter(ABC):
         Returns:
             List[Dict[str, str]]: The transformed model messages
         """
-        openai_messages = ModelMessage.to_openai_messages(messages)
+        openai_messages = ModelMessage.to_common_messages(messages)
         system_messages = []
         return_messages = []
         for message in openai_messages:
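
Several adapters in the new file set `support_system_message = False` (Zhipu, Gemini, Spark); for those, `transform_model_messages` falls through to `_transform_to_no_system_messages`, which separates system messages from the ordinary turns. An illustrative reduction of that idea (a sketch, not the method body from this commit):

# Fold system content into the first human turn for providers that
# have no system role; details of the real method may differ.
def fold_system_messages(common_messages):
    system_parts = [m["content"] for m in common_messages if m["role"] == "system"]
    rest = [m for m in common_messages if m["role"] != "system"]
    if system_parts and rest and rest[0]["role"] == "user":
        rest[0] = {
            "role": "user",
            "content": "\n".join(system_parts) + "\n" + rest[0]["content"],
        }
    return rest
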
@@ -394,6 +405,9 @@ class LLMModelAdapter(ABC):
             conv.set_system_message("".join(can_use_systems))
         return conv
 
+    def apply_conv_template(self) -> bool:
+        return self.model_type() != ModelType.PROXY
+
     def model_adaptation(
         self,
         params: Dict,
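
`apply_conv_template` is the switch that lets proxy models bypass local prompt templating: only non-proxy model types still render a conversation template. A simplified, self-contained sketch of the resulting control flow in `model_adaptation` (`render_with_conv_template` is a hypothetical stand-in for the real `get_str_prompt`/templating steps):

def render_with_conv_template(messages) -> str:
    # Hypothetical stand-in for conversation-template rendering.
    return "\n".join(f"{m['role']}: {m['content']}" for m in messages)

def adapt(adapter, params: dict, model_context: dict):
    # Proxy LLMs keep structured messages and return early (next hunks).
    if not adapter.apply_conv_template():
        return params, model_context
    # Local LLMs render messages into a single templated prompt string.
    params["prompt"] = render_with_conv_template(params["messages"])
    return params, model_context
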
@@ -414,7 +428,11 @@ class LLMModelAdapter(ABC):
         params["convert_to_compatible_format"] = convert_to_compatible_format
         # Some model context to dbgpt server
-        model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
+        model_context = {
+            "prompt_echo_len_char": -1,
+            "has_format_prompt": False,
+            "echo": params.get("echo", True),
+        }
         if messages:
             # Dict message to ModelMessage
             messages = [
@@ -422,6 +440,11 @@ class LLMModelAdapter(ABC):
                 for m in messages
             ]
             params["messages"] = messages
+            params["string_prompt"] = ModelMessage.messages_to_string(messages)
+
+            if not self.apply_conv_template():
+                # No need to apply conversation template, now for proxy LLM
+                return params, model_context
 
         new_prompt = self.get_str_prompt(
             params, messages, tokenizer, prompt_template, convert_to_compatible_format
@@ -442,7 +465,6 @@ class LLMModelAdapter(ABC):
             # TODO remove bos token and eos token from tokenizer_config.json of model
             prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
             model_context["prompt_echo_len_char"] = prompt_echo_len_char
-            model_context["echo"] = params.get("echo", True)
             model_context["has_format_prompt"] = True
             params["prompt"] = new_prompt

View File

@@ -19,7 +19,7 @@ _IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
_OLD_MODELS = [
"llama-cpp",
"proxyllm",
# "proxyllm",
"gptj-6b",
"codellama-13b-sql-sft",
"codellama-7b",
@@ -45,6 +45,7 @@ def get_llm_model_adapter(
     # Import NewHFChatModelAdapter so that it can be registered
     from dbgpt.model.adapter.hf_adapter import NewHFChatModelAdapter
+    from dbgpt.model.adapter.proxy_adapter import ProxyLLMModelAdapter
 
     new_model_adapter = get_model_adapter(
         model_type, model_name, model_path, conv_factory
View File

@@ -0,0 +1,238 @@
import dataclasses
import logging
from abc import abstractmethod
from typing import TYPE_CHECKING, List, Optional, Type, Union
from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter
from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory
from dbgpt.model.base import ModelType
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel
logger = logging.getLogger(__name__)
class ProxyLLMModelAdapter(LLMModelAdapter):
def new_adapter(self, **kwargs) -> "LLMModelAdapter":
return self.__class__()
def model_type(self) -> str:
return ModelType.PROXY
def match(
self,
model_type: str,
model_name: Optional[str] = None,
model_path: Optional[str] = None,
) -> bool:
model_name = model_name.lower() if model_name else None
model_path = model_path.lower() if model_path else None
return self.do_match(model_name) or self.do_match(model_path)
@abstractmethod
def do_match(self, lower_model_name_or_path: Optional[str] = None):
raise NotImplementedError()
def dynamic_llm_client_class(
self, params: ProxyModelParameters
) -> Optional[Type[ProxyLLMClient]]:
"""Get dynamic llm client class
Parse the llm_client_class from params and return the class
Args:
params (ProxyModelParameters): proxy model parameters
Returns:
Optional[Type[ProxyLLMClient]]: llm client class
"""
if params.llm_client_class:
from dbgpt.util.module_utils import import_from_checked_string
worker_cls: Type[ProxyLLMClient] = import_from_checked_string(
params.llm_client_class, ProxyLLMClient
)
return worker_cls
return None
def get_llm_client_class(
self, params: ProxyModelParameters
) -> Type[ProxyLLMClient]:
"""Get llm client class"""
dynamic_llm_client_class = self.dynamic_llm_client_class(params)
if dynamic_llm_client_class:
return dynamic_llm_client_class
raise NotImplementedError()
def load_from_params(self, params: ProxyModelParameters):
dynamic_llm_client_class = self.dynamic_llm_client_class(params)
if not dynamic_llm_client_class:
dynamic_llm_client_class = self.get_llm_client_class(params)
logger.info(
f"Load model from params: {params}, llm client class: {dynamic_llm_client_class}"
)
proxy_llm_client = dynamic_llm_client_class.new_client(params)
model = ProxyModel(params, proxy_llm_client)
return model, model
class OpenAIProxyLLMModelAdapter(ProxyLLMModelAdapter):
def support_async(self) -> bool:
return True
def do_match(self, lower_model_name_or_path: Optional[str] = None):
return lower_model_name_or_path in ["chatgpt_proxyllm", "proxyllm"]
def get_llm_client_class(
self, params: ProxyModelParameters
) -> Type[ProxyLLMClient]:
"""Get llm client class"""
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
return OpenAILLMClient
def get_async_generate_stream_function(self, model, model_path: str):
from dbgpt.model.proxy.llms.chatgpt import chatgpt_generate_stream
return chatgpt_generate_stream
class TongyiProxyLLMModelAdapter(ProxyLLMModelAdapter):
def do_match(self, lower_model_name_or_path: Optional[str] = None):
return lower_model_name_or_path == "tongyi_proxyllm"
def get_llm_client_class(
self, params: ProxyModelParameters
) -> Type[ProxyLLMClient]:
from dbgpt.model.proxy.llms.tongyi import TongyiLLMClient
return TongyiLLMClient
def get_generate_stream_function(self, model, model_path: str):
from dbgpt.model.proxy.llms.tongyi import tongyi_generate_stream
return tongyi_generate_stream
class ZhipuProxyLLMModelAdapter(ProxyLLMModelAdapter):
support_system_message = False
def do_match(self, lower_model_name_or_path: Optional[str] = None):
return lower_model_name_or_path == "zhipu_proxyllm"
def get_llm_client_class(
self, params: ProxyModelParameters
) -> Type[ProxyLLMClient]:
from dbgpt.model.proxy.llms.zhipu import ZhipuLLMClient
return ZhipuLLMClient
def get_generate_stream_function(self, model, model_path: str):
from dbgpt.model.proxy.llms.zhipu import zhipu_generate_stream
return zhipu_generate_stream
class WenxinProxyLLMModelAdapter(ProxyLLMModelAdapter):
def do_match(self, lower_model_name_or_path: Optional[str] = None):
return lower_model_name_or_path == "wenxin_proxyllm"
def get_llm_client_class(
self, params: ProxyModelParameters
) -> Type[ProxyLLMClient]:
from dbgpt.model.proxy.llms.wenxin import WenxinLLMClient
return WenxinLLMClient
def get_generate_stream_function(self, model, model_path: str):
from dbgpt.model.proxy.llms.wenxin import wenxin_generate_stream
return wenxin_generate_stream
class GeminiProxyLLMModelAdapter(ProxyLLMModelAdapter):
support_system_message = False
def do_match(self, lower_model_name_or_path: Optional[str] = None):
return lower_model_name_or_path == "gemini_proxyllm"
def get_llm_client_class(
self, params: ProxyModelParameters
) -> Type[ProxyLLMClient]:
from dbgpt.model.proxy.llms.gemini import GeminiLLMClient
return GeminiLLMClient
def get_generate_stream_function(self, model, model_path: str):
from dbgpt.model.proxy.llms.gemini import gemini_generate_stream
return gemini_generate_stream
class SparkProxyLLMModelAdapter(ProxyLLMModelAdapter):
support_system_message = False
def do_match(self, lower_model_name_or_path: Optional[str] = None):
return lower_model_name_or_path == "spark_proxyllm"
def get_llm_client_class(
self, params: ProxyModelParameters
) -> Type[ProxyLLMClient]:
from dbgpt.model.proxy.llms.spark import SparkLLMClient
return SparkLLMClient
def get_generate_stream_function(self, model, model_path: str):
from dbgpt.model.proxy.llms.spark import spark_generate_stream
return spark_generate_stream
class BardProxyLLMModelAdapter(ProxyLLMModelAdapter):
def do_match(self, lower_model_name_or_path: Optional[str] = None):
return lower_model_name_or_path == "bard_proxyllm"
def get_llm_client_class(
self, params: ProxyModelParameters
) -> Type[ProxyLLMClient]:
"""Get llm client class"""
# TODO: Bard proxy LLM not support ProxyLLMClient now, we just return OpenAILLMClient
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
return OpenAILLMClient
def get_async_generate_stream_function(self, model, model_path: str):
from dbgpt.model.proxy.llms.bard import bard_generate_stream
return bard_generate_stream
class BaichuanProxyLLMModelAdapter(ProxyLLMModelAdapter):
def do_match(self, lower_model_name_or_path: Optional[str] = None):
return lower_model_name_or_path == "bc_proxyllm"
def get_llm_client_class(
self, params: ProxyModelParameters
) -> Type[ProxyLLMClient]:
"""Get llm client class"""
# TODO: Baichuan proxy LLM not support ProxyLLMClient now, we just return OpenAILLMClient
from dbgpt.model.proxy.llms.chatgpt import OpenAILLMClient
return OpenAILLMClient
def get_async_generate_stream_function(self, model, model_path: str):
from dbgpt.model.proxy.llms.baichuan import baichuan_generate_stream
return baichuan_generate_stream
register_model_adapter(OpenAIProxyLLMModelAdapter)
register_model_adapter(TongyiProxyLLMModelAdapter)
register_model_adapter(ZhipuProxyLLMModelAdapter)
register_model_adapter(WenxinProxyLLMModelAdapter)
register_model_adapter(GeminiProxyLLMModelAdapter)
register_model_adapter(SparkProxyLLMModelAdapter)
register_model_adapter(BardProxyLLMModelAdapter)
register_model_adapter(BaichuanProxyLLMModelAdapter)
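
Because `ProxyLLMModelAdapter.dynamic_llm_client_class` resolves `params.llm_client_class` through `import_from_checked_string`, any of these adapters can be pointed at a user-supplied `ProxyLLMClient` without writing a new adapter. A hedged usage sketch: `mypkg.clients.MyLLMClient` is hypothetical, and only the fields shown are assumed to exist on `ProxyModelParameters`:

from dbgpt.model.adapter.proxy_adapter import OpenAIProxyLLMModelAdapter
from dbgpt.model.parameter import ProxyModelParameters

# Hypothetical custom client: must subclass ProxyLLMClient, which
# import_from_checked_string verifies before returning the class.
params = ProxyModelParameters(
    model_name="chatgpt_proxyllm",
    model_path="chatgpt_proxyllm",
    llm_client_class="mypkg.clients.MyLLMClient",
)

adapter = OpenAIProxyLLMModelAdapter()
# load_from_params imports MyLLMClient, calls its new_client(params)
# classmethod, and wraps the client in a ProxyModel.
model, _ = adapter.load_from_params(params)
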