mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-07 12:00:46 +00:00
refactor: Refactor proxy LLM (#1064)
This commit is contained in:
@@ -152,6 +152,17 @@ class LLMModelAdapter(ABC):
|
||||
except Exception:
|
||||
return "\n"
|
||||
|
||||
def get_prompt_roles(self) -> List[str]:
|
||||
"""Get the roles of the prompt
|
||||
|
||||
Returns:
|
||||
List[str]: The roles of the prompt
|
||||
"""
|
||||
roles = [ModelMessageRoleType.HUMAN, ModelMessageRoleType.AI]
|
||||
if self.support_system_message:
|
||||
roles.append(ModelMessageRoleType.SYSTEM)
|
||||
return roles
|
||||
|
||||
def transform_model_messages(
|
||||
self, messages: List[ModelMessage], convert_to_compatible_format: bool = False
|
||||
) -> List[Dict[str, str]]:
|
||||
@@ -185,7 +196,7 @@ class LLMModelAdapter(ABC):
|
||||
# We will not do any transform in the future
|
||||
return self._transform_to_no_system_messages(messages)
|
||||
else:
|
||||
return ModelMessage.to_openai_messages(
|
||||
return ModelMessage.to_common_messages(
|
||||
messages, convert_to_compatible_format=convert_to_compatible_format
|
||||
)
|
||||
|
||||
@@ -216,7 +227,7 @@ class LLMModelAdapter(ABC):
|
||||
Returns:
|
||||
List[Dict[str, str]]: The transformed model messages
|
||||
"""
|
||||
openai_messages = ModelMessage.to_openai_messages(messages)
|
||||
openai_messages = ModelMessage.to_common_messages(messages)
|
||||
system_messages = []
|
||||
return_messages = []
|
||||
for message in openai_messages:
|
||||
@@ -394,6 +405,9 @@ class LLMModelAdapter(ABC):
|
||||
conv.set_system_message("".join(can_use_systems))
|
||||
return conv
|
||||
|
||||
def apply_conv_template(self) -> bool:
|
||||
return self.model_type() != ModelType.PROXY
|
||||
|
||||
def model_adaptation(
|
||||
self,
|
||||
params: Dict,
|
||||
@@ -414,7 +428,11 @@ class LLMModelAdapter(ABC):
|
||||
params["convert_to_compatible_format"] = convert_to_compatible_format
|
||||
|
||||
# Some model context to dbgpt server
|
||||
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
|
||||
model_context = {
|
||||
"prompt_echo_len_char": -1,
|
||||
"has_format_prompt": False,
|
||||
"echo": params.get("echo", True),
|
||||
}
|
||||
if messages:
|
||||
# Dict message to ModelMessage
|
||||
messages = [
|
||||
@@ -422,6 +440,11 @@ class LLMModelAdapter(ABC):
|
||||
for m in messages
|
||||
]
|
||||
params["messages"] = messages
|
||||
params["string_prompt"] = ModelMessage.messages_to_string(messages)
|
||||
|
||||
if not self.apply_conv_template():
|
||||
# No need to apply conversation template, now for proxy LLM
|
||||
return params, model_context
|
||||
|
||||
new_prompt = self.get_str_prompt(
|
||||
params, messages, tokenizer, prompt_template, convert_to_compatible_format
|
||||
@@ -442,7 +465,6 @@ class LLMModelAdapter(ABC):
|
||||
# TODO remote bos token and eos token from tokenizer_config.json of model
|
||||
prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
|
||||
model_context["prompt_echo_len_char"] = prompt_echo_len_char
|
||||
model_context["echo"] = params.get("echo", True)
|
||||
model_context["has_format_prompt"] = True
|
||||
params["prompt"] = new_prompt
|
||||
|
||||
|
Reference in New Issue
Block a user