refactor: Refactor proxy LLM (#1064)

Fangyin Cheng
2024-01-14 21:01:37 +08:00
committed by GitHub
parent a035433170
commit 22bfd01c4b
95 changed files with 2049 additions and 1294 deletions


@@ -152,6 +152,17 @@ class LLMModelAdapter(ABC):
         except Exception:
             return "\n"
 
+    def get_prompt_roles(self) -> List[str]:
+        """Get the roles of the prompt
+
+        Returns:
+            List[str]: The roles of the prompt
+        """
+        roles = [ModelMessageRoleType.HUMAN, ModelMessageRoleType.AI]
+        if self.support_system_message:
+            roles.append(ModelMessageRoleType.SYSTEM)
+        return roles
+
     def transform_model_messages(
         self, messages: List[ModelMessage], convert_to_compatible_format: bool = False
     ) -> List[Dict[str, str]]:
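The new get_prompt_roles helper gives callers a way to ask an adapter which message roles it will accept before a request is built. A minimal sketch of that kind of pre-flight check, assuming the returned role strings match the role field on ModelMessage; check_roles and its arguments are hypothetical, not part of the diff:

# Hypothetical pre-flight check built on the new helper.
def check_roles(adapter: "LLMModelAdapter", messages: List[ModelMessage]) -> None:
    supported = adapter.get_prompt_roles()
    for m in messages:
        if m.role not in supported:
            # e.g. a system message sent to an adapter with support_system_message=False
            raise ValueError(f"Unsupported role {m.role!r} for this adapter")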
@@ -185,7 +196,7 @@ class LLMModelAdapter(ABC):
             # We will not do any transform in the future
             return self._transform_to_no_system_messages(messages)
         else:
-            return ModelMessage.to_openai_messages(
+            return ModelMessage.to_common_messages(
                 messages, convert_to_compatible_format=convert_to_compatible_format
             )
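Renaming to_openai_messages to to_common_messages signals that the role/content dict list is now treated as a provider-neutral format rather than something OpenAI-specific. A sketch of the shape the call is expected to return, inferred from the old name; the literal example values are made up:

# Assumed output shape of ModelMessage.to_common_messages(messages):
# a list of plain dicts, one per message, e.g.
# [
#     {"role": "system", "content": "You are a helpful assistant."},
#     {"role": "user", "content": "Hello"},
# ]
common_messages = ModelMessage.to_common_messages(
    messages, convert_to_compatible_format=False
)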
@@ -216,7 +227,7 @@ class LLMModelAdapter(ABC):
         Returns:
             List[Dict[str, str]]: The transformed model messages
         """
-        openai_messages = ModelMessage.to_openai_messages(messages)
+        openai_messages = ModelMessage.to_common_messages(messages)
         system_messages = []
         return_messages = []
         for message in openai_messages:
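_transform_to_no_system_messages splits system content away from the human/AI turns because the adapter reports support_system_message as False. The diff only shows the setup; a simplified sketch of the general idea (not the method's exact continuation) is:

# Simplified sketch of a no-system-message transform: fold system content
# into the first non-system turn instead of sending a "system" role.
def fold_system_messages(common_messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    system_parts = [m["content"] for m in common_messages if m["role"] == "system"]
    rest = [dict(m) for m in common_messages if m["role"] != "system"]
    if system_parts and rest:
        rest[0]["content"] = "\n".join(system_parts) + "\n" + rest[0]["content"]
    return rest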
@@ -394,6 +405,9 @@ class LLMModelAdapter(ABC):
         conv.set_system_message("".join(can_use_systems))
         return conv
 
+    def apply_conv_template(self) -> bool:
+        return self.model_type() != ModelType.PROXY
+
     def model_adaptation(
         self,
         params: Dict,
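apply_conv_template is the switch this refactor hangs on: only non-proxy models still need a local conversation template rendered into a single prompt string. A sketch of how a proxy adapter opts out; the class name is hypothetical and the other abstract methods are omitted:

# Hypothetical proxy adapter: reporting ModelType.PROXY makes
# apply_conv_template() return False, so model_adaptation() below returns
# early with structured messages instead of a rendered prompt.
class SomeProxyAdapter(LLMModelAdapter):
    def model_type(self):
        return ModelType.PROXY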
@@ -414,7 +428,11 @@ class LLMModelAdapter(ABC):
params["convert_to_compatible_format"] = convert_to_compatible_format
# Some model context to dbgpt server
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
model_context = {
"prompt_echo_len_char": -1,
"has_format_prompt": False,
"echo": params.get("echo", True),
}
if messages:
# Dict message to ModelMessage
messages = [
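The echo flag is now captured in model_context up front, so it is present even on the proxy path that returns before the prompt-formatting code further down. One plausible downstream use (assumed, not shown in this diff) is stripping the echoed prompt from generated text; strip_echo and output_text are hypothetical:

# Assumed consumer-side sketch: drop the echoed prompt when echo is disabled.
def strip_echo(output_text: str, model_context: Dict) -> str:
    if not model_context.get("echo", True):
        echo_len = model_context.get("prompt_echo_len_char", -1)
        if echo_len > 0:
            return output_text[echo_len:]
    return output_text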
@@ -422,6 +440,11 @@ class LLMModelAdapter(ABC):
                 for m in messages
             ]
             params["messages"] = messages
+            params["string_prompt"] = ModelMessage.messages_to_string(messages)
+
+        if not self.apply_conv_template():
+            # No need to apply conversation template, now for proxy LLM
+            return params, model_context
 
         new_prompt = self.get_str_prompt(
             params, messages, tokenizer, prompt_template, convert_to_compatible_format
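With the early return above, a proxy request leaves model_adaptation carrying the conversation in two forms: the structured messages list and a flattened string_prompt built by ModelMessage.messages_to_string. A sketch of what the proxy side can then pick from (the key names are the ones set in this diff; nothing else is implied):

# After model_adaptation() returns on the proxy path:
structured = params["messages"]        # List[ModelMessage], roles and content preserved
flattened = params["string_prompt"]    # one string, via ModelMessage.messages_to_string
# params["prompt"] is only set further down, on the non-proxy template path.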
@@ -442,7 +465,6 @@ class LLMModelAdapter(ABC):
         # TODO remote bos token and eos token from tokenizer_config.json of model
         prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
         model_context["prompt_echo_len_char"] = prompt_echo_len_char
-        model_context["echo"] = params.get("echo", True)
         model_context["has_format_prompt"] = True
         params["prompt"] = new_prompt