import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

from dbgpt.core.interface.message import ModelMessage, ModelMessageRoleType
from dbgpt.model.adapter.template import (
    ConversationAdapter,
    ConversationAdapterFactory,
    get_conv_template,
)
from dbgpt.model.base import ModelType
from dbgpt.model.parameter import (
    BaseModelParameters,
    LlamaCppModelParameters,
    ModelParameters,
    ProxyModelParameters,
)

logger = logging.getLogger(__name__)


class LLMModelAdapter(ABC):
    """New adapter for DB-GPT LLM models"""

    model_name: Optional[str] = None
    model_path: Optional[str] = None
    conv_factory: Optional[ConversationAdapterFactory] = None
    # TODO: more flexible quantization config
    support_4bit: bool = False
    support_8bit: bool = False
    support_system_message: bool = True

    def __repr__(self) -> str:
        return (
            f"<{self.__class__.__name__} model_name={self.model_name} "
            f"model_path={self.model_path}>"
        )

    def __str__(self):
        return self.__repr__()

    @abstractmethod
    def new_adapter(self, **kwargs) -> "LLMModelAdapter":
        """Create a new adapter instance

        Args:
            **kwargs: The parameters of the new adapter instance

        Returns:
            LLMModelAdapter: The new adapter instance
        """

    def use_fast_tokenizer(self) -> bool:
        """Whether to use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index)
        if it is supported for a given model.
        """
        return False

    def model_type(self) -> str:
        return ModelType.HF

    def model_param_class(
        self, model_type: Optional[str] = None
    ) -> Type[BaseModelParameters]:
        """Get the startup parameters class of the model

        Args:
            model_type (str, optional): The type of model. Defaults to None.

        Returns:
            Type[BaseModelParameters]: The startup parameters class of the model
        """
        model_type = model_type if model_type else self.model_type()
        if model_type == ModelType.LLAMA_CPP:
            return LlamaCppModelParameters
        elif model_type == ModelType.PROXY:
            return ProxyModelParameters
        return ModelParameters

    def match(
        self,
        model_type: str,
        model_name: Optional[str] = None,
        model_path: Optional[str] = None,
    ) -> bool:
        """Whether the model adapter can load the given model

        Args:
            model_type (str): The type of model
            model_name (Optional[str], optional): The name of the model. Defaults to None.
            model_path (Optional[str], optional): The path of the model. Defaults to None.
        """
        return False

    def support_quantization_4bit(self) -> bool:
        """Whether the model adapter can load a 4-bit model

        If it is True, we will load the 4-bit model with :meth:`~LLMModelAdapter.load`

        Returns:
            bool: Whether the model adapter can load a 4-bit model, default is False
        """
        return self.support_4bit

    def support_quantization_8bit(self) -> bool:
        """Whether the model adapter can load an 8-bit model

        If it is True, we will load the 8-bit model with :meth:`~LLMModelAdapter.load`

        Returns:
            bool: Whether the model adapter can load an 8-bit model, default is False
        """
        return self.support_8bit

    def load(self, model_path: str, from_pretrained_kwargs: dict):
        """Load model and tokenizer"""
        raise NotImplementedError
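    # A minimal sketch of what a concrete ``load`` could look like for a
    # huggingface model; this is illustrative only, not the adapter's actual
    # implementation (real adapters may need model-specific handling, device
    # placement, quantization, etc.):
    #
    #     from transformers import AutoModelForCausalLM, AutoTokenizer
    #
    #     def load(self, model_path: str, from_pretrained_kwargs: dict):
    #         tokenizer = AutoTokenizer.from_pretrained(
    #             model_path, use_fast=self.use_fast_tokenizer()
    #         )
    #         model = AutoModelForCausalLM.from_pretrained(
    #             model_path, **from_pretrained_kwargs
    #         )
    #         return model, tokenizer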
    def parse_max_length(self, model, tokenizer) -> Optional[int]:
        """Parse the max_length of the model.

        Returns:
            Optional[int]: The max_length of the model
        """
        if not (tokenizer or model):
            return None
        try:
            model_max_length = None
            if tokenizer and hasattr(tokenizer, "model_max_length"):
                model_max_length = tokenizer.model_max_length
            if model_max_length and model_max_length < 100000000:
                # Can't be too large
                return model_max_length
            if model and hasattr(model, "config"):
                model_config = model.config
                if hasattr(model_config, "max_sequence_length"):
                    return model_config.max_sequence_length
                if hasattr(model_config, "max_position_embeddings"):
                    return model_config.max_position_embeddings
            return None
        except Exception:
            return None

    def load_from_params(self, params):
        """Load the model and tokenizer according to the given parameters"""
        raise NotImplementedError

    def support_async(self) -> bool:
        """Whether the loaded model supports asynchronous calls"""
        return False

    def get_generate_stream_function(self, model, model_path: str):
        """Get the generate stream function of the model"""
        raise NotImplementedError

    def get_async_generate_stream_function(self, model, model_path: str):
        """Get the asynchronous generate stream function of the model"""
        raise NotImplementedError

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> Optional[ConversationAdapter]:
        """Get the default conversation template

        Args:
            model_name (str): The name of the model.
            model_path (str): The path of the model.

        Returns:
            Optional[ConversationAdapter]: The conversation template.
        """
        raise NotImplementedError

    def get_default_message_separator(self) -> str:
        """Get the default message separator"""
        try:
            conv_template = self.get_default_conv_template(
                self.model_name, self.model_path
            )
            return conv_template.sep
        except Exception:
            return "\n"

    def get_prompt_roles(self) -> List[str]:
        """Get the roles of the prompt

        Returns:
            List[str]: The roles of the prompt
        """
        roles = [ModelMessageRoleType.HUMAN, ModelMessageRoleType.AI]
        if self.support_system_message:
            roles.append(ModelMessageRoleType.SYSTEM)
        return roles

    def transform_model_messages(
        self, messages: List[ModelMessage], convert_to_compatible_format: bool = False
    ) -> List[Dict[str, str]]:
        """Transform the model messages

        The default is the OpenAI format, for example:

        .. code-block:: python

            return_messages = [
                {"role": "system", "content": "You are a helpful assistant"},
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi"},
            ]

        But some models may need the messages transformed to another format
        (e.g. one without system messages), such as:

        .. code-block:: python

            return_messages = [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi"},
            ]

        Args:
            messages (List[ModelMessage]): The model messages
            convert_to_compatible_format (bool, optional): Whether to convert to
                the compatible format. Defaults to False.

        Returns:
            List[Dict[str, str]]: The transformed model messages
        """
        logger.info(f"support_system_message: {self.support_system_message}")
        if not self.support_system_message and convert_to_compatible_format:
            # We will not do any transformation in the future
            return self._transform_to_no_system_messages(messages)
        else:
            return ModelMessage.to_common_messages(
                messages, convert_to_compatible_format=convert_to_compatible_format
            )
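    # An illustrative call (assumes an adapter whose ``support_system_message``
    # is False and the default "\n" separator): the system message is merged
    # into the last user message.
    #
    #     messages = [
    #         ModelMessage(role=ModelMessageRoleType.SYSTEM, content="Be brief"),
    #         ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello"),
    #     ]
    #     adapter.transform_model_messages(
    #         messages, convert_to_compatible_format=True
    #     )
    #     # => [{"role": "user", "content": "Be brief\nHello"}]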
    def _transform_to_no_system_messages(
        self, messages: List[ModelMessage]
    ) -> List[Dict[str, str]]:
        """Transform the model messages to messages without a system role

        Some open-source chat models do not support system messages, so we
        should transform the messages to contain no system message.

        Merge the system messages into the last user message, for example:

        .. code-block:: python

            return_messages = [
                {"role": "system", "content": "You are a helpful assistant"},
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi"},
            ]

            # =>

            return_messages = [
                {"role": "user", "content": "You are a helpful assistant\nHello"},
                {"role": "assistant", "content": "Hi"},
            ]

        Args:
            messages (List[ModelMessage]): The model messages

        Returns:
            List[Dict[str, str]]: The transformed model messages
        """
        openai_messages = ModelMessage.to_common_messages(messages)
        system_messages = []
        return_messages = []
        for message in openai_messages:
            if message["role"] == "system":
                system_messages.append(message["content"])
            else:
                return_messages.append(message)
        if len(system_messages) > 1:
            # More than one system message warrants a warning
            logger.warning("Your messages contain more than one system message")
        if system_messages:
            sep = self.get_default_message_separator()
            str_system_messages = ",".join(system_messages)
            # Update the last user message
            return_messages[-1]["content"] = (
                str_system_messages + sep + return_messages[-1]["content"]
            )
        return return_messages

    def get_str_prompt(
        self,
        params: Dict,
        messages: List[ModelMessage],
        tokenizer: Any,
        prompt_template: Optional[str] = None,
        convert_to_compatible_format: bool = False,
    ) -> Optional[str]:
        """Get the string prompt from the given parameters and messages

        If the return value is not None, we will skip
        :meth:`~LLMModelAdapter.get_prompt_with_template` and use the return
        value directly.

        Args:
            params (Dict): The parameters
            messages (List[ModelMessage]): The model messages
            tokenizer (Any): The tokenizer of the model; for a huggingface chat
                model, we can create the prompt with the tokenizer
            prompt_template (str, optional): The prompt template. Defaults to None.
            convert_to_compatible_format (bool, optional): Whether to convert to
                the compatible format. Defaults to False.

        Returns:
            Optional[str]: The string prompt
        """
        return None

    def get_prompt_with_template(
        self,
        params: Dict,
        messages: List[ModelMessage],
        model_name: str,
        model_path: str,
        model_context: Dict,
        prompt_template: Optional[str] = None,
    ):
        convert_to_compatible_format = params.get("convert_to_compatible_format")
        conv: ConversationAdapter = self.get_default_conv_template(
            model_name, model_path
        )

        if prompt_template:
            logger.info(f"Use prompt template {prompt_template} from config")
            conv = get_conv_template(prompt_template)
        if not conv or not messages:
            # Nothing to do
            logger.info(
                f"No conv from model_path {model_path} or no messages in params, {self}"
            )
            return None, None, None

        conv = conv.copy()
        if convert_to_compatible_format:
            # In the old version, we convert the messages to the compatible format
            conv = self._set_conv_converted_messages(conv, messages)
        else:
            # In the new version, we use the messages directly
            conv = self._set_conv_messages(conv, messages)

        # Add a blank message for the assistant.
        conv.append_message(conv.roles[1], None)
        new_prompt = conv.get_prompt()
        return new_prompt, conv.stop_str, conv.stop_token_ids
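    # A hedged usage sketch (the rendering below is illustrative; the exact
    # prompt text depends entirely on the conversation template in use):
    #
    #     prompt, stop_str, stop_token_ids = adapter.get_prompt_with_template(
    #         params, messages, model_name, model_path, model_context={},
    #         prompt_template=None,
    #     )
    #     # With a vicuna-style template, ``prompt`` might end with an empty
    #     # assistant turn such as "... USER: Hello ASSISTANT:".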
    def _set_conv_messages(
        self, conv: ConversationAdapter, messages: List[ModelMessage]
    ) -> ConversationAdapter:
        """Set the messages in the conversation template

        Args:
            conv (ConversationAdapter): The conversation template
            messages (List[ModelMessage]): The model messages

        Returns:
            ConversationAdapter: The conversation template with messages
        """
        system_messages = []
        for message in messages:
            if isinstance(message, ModelMessage):
                role = message.role
                content = message.content
            elif isinstance(message, dict):
                role = message["role"]
                content = message["content"]
            else:
                raise ValueError(f"Invalid message type: {message}")

            if role == ModelMessageRoleType.SYSTEM:
                system_messages.append(content)
            elif role == ModelMessageRoleType.HUMAN:
                conv.append_message(conv.roles[0], content)
            elif role == ModelMessageRoleType.AI:
                conv.append_message(conv.roles[1], content)
            else:
                raise ValueError(f"Unknown role: {role}")

        if len(system_messages) > 1:
            raise ValueError(
                f"Your system messages have more than one message: {system_messages}"
            )
        if system_messages:
            conv.set_system_message(system_messages[0])
        return conv

    def _set_conv_converted_messages(
        self, conv: ConversationAdapter, messages: List[ModelMessage]
    ) -> ConversationAdapter:
        """Set the messages in the conversation template

        In the old version, we convert the messages to the compatible format.
        This method will be deprecated in the future.

        Args:
            conv (ConversationAdapter): The conversation template
            messages (List[ModelMessage]): The model messages

        Returns:
            ConversationAdapter: The conversation template with messages
        """
        system_messages = []
        user_messages = []
        ai_messages = []

        for message in messages:
            if isinstance(message, ModelMessage):
                role = message.role
                content = message.content
            elif isinstance(message, dict):
                role = message["role"]
                content = message["content"]
            else:
                raise ValueError(f"Invalid message type: {message}")

            if role == ModelMessageRoleType.SYSTEM:
                # Support multiple system messages
                system_messages.append(content)
            elif role == ModelMessageRoleType.HUMAN:
                user_messages.append(content)
            elif role == ModelMessageRoleType.AI:
                ai_messages.append(content)
            else:
                raise ValueError(f"Unknown role: {role}")

        can_use_systems: List[str] = []
        if system_messages:
            if len(system_messages) > 1:
                # For compatibility with complex DB-GPT scenarios: the last
                # system message carries the most complete information entered
                # by the current user, so it replaces the last user message
                user_messages[-1] = system_messages[-1]
                can_use_systems = system_messages[:-1]
            else:
                can_use_systems = system_messages

        for i in range(len(user_messages)):
            conv.append_message(conv.roles[0], user_messages[i])
            if i < len(ai_messages):
                conv.append_message(conv.roles[1], ai_messages[i])

        # TODO: joining all system messages may not be a good idea
        conv.set_system_message("".join(can_use_systems))
        return conv

    def apply_conv_template(self) -> bool:
        return self.model_type() != ModelType.PROXY
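    # A hedged usage sketch of ``model_adaptation`` (the values below are
    # illustrative, not prescribed by the adapter):
    #
    #     params = {
    #         "messages": [{"role": "human", "content": "Hello"}],
    #         "version": "v2",
    #     }
    #     params, model_context = adapter.model_adaptation(
    #         params, model_name, model_path, tokenizer, prompt_template=None
    #     )
    #     # For non-proxy models, params["prompt"] now holds the rendered
    #     # prompt and model_context carries echo/format metadata.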
    def model_adaptation(
        self,
        params: Dict,
        model_name: str,
        model_path: str,
        tokenizer: Any,
        prompt_template: Optional[str] = None,
    ) -> Tuple[Dict, Dict]:
        """Params adaptation"""
        messages = params.get("messages")
        convert_to_compatible_format = params.get("convert_to_compatible_format")
        message_version = params.get("version", "v2").lower()
        logger.info(f"Message version is {message_version}")
        if convert_to_compatible_format is None:
            # Convert the messages to the compatible format when the message
            # version is v1
            convert_to_compatible_format = message_version == "v1"
        # Save to params
        params["convert_to_compatible_format"] = convert_to_compatible_format

        # Some model context for the dbgpt server
        model_context = {
            "prompt_echo_len_char": -1,
            "has_format_prompt": False,
            "echo": params.get("echo", True),
        }
        if messages:
            # Convert dict messages to ModelMessage
            messages = [
                m if isinstance(m, ModelMessage) else ModelMessage(**m)
                for m in messages
            ]
            params["messages"] = messages
            params["string_prompt"] = ModelMessage.messages_to_string(messages)

        if not self.apply_conv_template():
            # No need to apply the conversation template, currently for proxy LLMs
            return params, model_context

        new_prompt = self.get_str_prompt(
            params, messages, tokenizer, prompt_template, convert_to_compatible_format
        )
        conv_stop_str, conv_stop_token_ids = None, None
        if not new_prompt:
            (
                new_prompt,
                conv_stop_str,
                conv_stop_token_ids,
            ) = self.get_prompt_with_template(
                params,
                messages,
                model_name,
                model_path,
                model_context,
                prompt_template,
            )
        if not new_prompt:
            return params, model_context

        # Overwrite the original prompt
        # TODO: remove bos token and eos token from tokenizer_config.json of model
        prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
        model_context["prompt_echo_len_char"] = prompt_echo_len_char
        model_context["has_format_prompt"] = True
        params["prompt"] = new_prompt

        custom_stop = params.get("stop")
        custom_stop_token_ids = params.get("stop_token_ids")

        # Prefer the values passed in from the input parameters
        params["stop"] = custom_stop or conv_stop_str
        params["stop_token_ids"] = custom_stop_token_ids or conv_stop_token_ids

        return params, model_context


class AdapterEntry:
    """The entry of a model adapter"""

    def __init__(
        self,
        model_adapter: LLMModelAdapter,
        match_funcs: Optional[List[Callable[[str, str, str], bool]]] = None,
    ):
        self.model_adapter = model_adapter
        self.match_funcs = match_funcs or []


model_adapters: List[AdapterEntry] = []


def register_model_adapter(
    model_adapter_cls: Type[LLMModelAdapter],
    match_funcs: Optional[List[Callable[[str, str, str], bool]]] = None,
) -> None:
    """Register a model adapter.

    Args:
        model_adapter_cls (Type[LLMModelAdapter]): The model adapter class.
        match_funcs (List[Callable[[str, str, str], bool]], optional): The match
            functions. Defaults to None.
    """
    model_adapters.append(AdapterEntry(model_adapter_cls(), match_funcs))


def get_model_adapter(
    model_type: str,
    model_name: str,
    model_path: str,
    conv_factory: Optional[ConversationAdapterFactory] = None,
) -> Optional[LLMModelAdapter]:
    """Get a model adapter.

    Args:
        model_type (str): The type of the model.
        model_name (str): The name of the model.
        model_path (str): The path of the model.
        conv_factory (Optional[ConversationAdapterFactory], optional): The
            conversation adapter factory. Defaults to None.

    Returns:
        Optional[LLMModelAdapter]: The model adapter.
    """
    adapter = None
    # First, find an adapter by model_name
    for adapter_entry in model_adapters[::-1]:
        if adapter_entry.model_adapter.match(model_type, model_name, None):
            adapter = adapter_entry.model_adapter
            break
    # Then, find an adapter by model_path
    for adapter_entry in model_adapters[::-1]:
        if adapter_entry.model_adapter.match(model_type, None, model_path):
            adapter = adapter_entry.model_adapter
            break
    if adapter:
        new_adapter = adapter.new_adapter()
        new_adapter.model_name = model_name
        new_adapter.model_path = model_path
        if conv_factory:
            new_adapter.conv_factory = conv_factory
        return new_adapter
    return None
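# A minimal registration sketch. ``MyAdapter`` is hypothetical, not part of
# DB-GPT; a concrete adapter must implement at least ``new_adapter`` and
# should override ``match`` so it can be discovered:
#
#     class MyAdapter(LLMModelAdapter):
#         def new_adapter(self, **kwargs) -> "LLMModelAdapter":
#             return MyAdapter()
#
#         def match(self, model_type, model_name=None, model_path=None) -> bool:
#             return bool(model_name and "my-model" in model_name)
#
#     register_model_adapter(MyAdapter)
#     adapter = get_model_adapter(
#         ModelType.HF, "my-model-7b", "/models/my-model-7b"
#     )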