mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-08 20:39:44 +00:00
feat(model): Support Mixtral-8x7B (#959)
This commit is contained in:
130
dbgpt/model/adapter/template.py
Normal file
130
dbgpt/model/adapter/template.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, Union, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastchat.conversation import Conversation
|
||||
|
||||
|
||||
class PromptType(str, Enum):
|
||||
"""Prompt type."""
|
||||
|
||||
FSCHAT: str = "fschat"
|
||||
DBGPT: str = "dbgpt"
|
||||
|
||||
|
||||
class ConversationAdapter(ABC):
|
||||
"""The conversation adapter."""
|
||||
|
||||
@property
|
||||
def prompt_type(self) -> PromptType:
|
||||
return PromptType.FSCHAT
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def roles(self) -> Tuple[str]:
|
||||
"""Get the roles of the conversation.
|
||||
|
||||
Returns:
|
||||
Tuple[str]: The roles of the conversation.
|
||||
"""
|
||||
|
||||
@property
|
||||
def sep(self) -> Optional[str]:
|
||||
"""Get the separator between messages."""
|
||||
return "\n"
|
||||
|
||||
@property
|
||||
def stop_str(self) -> Optional[Union[str, List[str]]]:
|
||||
"""Get the stop criteria."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def stop_token_ids(self) -> Optional[List[int]]:
|
||||
"""Stops generation if meeting any token in this list"""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def get_prompt(self) -> str:
|
||||
"""Get the prompt string.
|
||||
|
||||
Returns:
|
||||
str: The prompt string.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def set_system_message(self, system_message: str) -> None:
|
||||
"""Set the system message."""
|
||||
|
||||
@abstractmethod
|
||||
def append_message(self, role: str, message: str) -> None:
|
||||
"""Append a new message.
|
||||
Args:
|
||||
role (str): The role of the message.
|
||||
message (str): The message content.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def update_last_message(self, message: str) -> None:
|
||||
"""Update the last output.
|
||||
|
||||
The last message is typically set to be None when constructing the prompt,
|
||||
so we need to update it in-place after getting the response from a model.
|
||||
|
||||
Args:
|
||||
message (str): The message content.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def copy(self) -> "ConversationAdapter":
|
||||
"""Copy the conversation."""
|
||||
|
||||
|
||||
class ConversationAdapterFactory(ABC):
|
||||
"""The conversation adapter factory."""
|
||||
|
||||
def get_by_name(
|
||||
self,
|
||||
template_name: str,
|
||||
prompt_template_type: Optional[PromptType] = PromptType.FSCHAT,
|
||||
) -> ConversationAdapter:
|
||||
"""Get a conversation adapter by name.
|
||||
|
||||
Args:
|
||||
template_name (str): The name of the template.
|
||||
prompt_template_type (Optional[PromptType]): The type of the prompt template, default to be FSCHAT.
|
||||
|
||||
Returns:
|
||||
ConversationAdapter: The conversation adapter.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_by_model(self, model_name: str, model_path: str) -> ConversationAdapter:
|
||||
"""Get a conversation adapter by model.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model.
|
||||
model_path (str): The path of the model.
|
||||
|
||||
Returns:
|
||||
ConversationAdapter: The conversation adapter.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def get_conv_template(name: str) -> ConversationAdapter:
|
||||
"""Get a conversation template.
|
||||
|
||||
Args:
|
||||
name (str): The name of the template.
|
||||
|
||||
Just return the fastchat conversation template for now.
|
||||
# TODO: More templates should be supported.
|
||||
Returns:
|
||||
Conversation: The conversation template.
|
||||
"""
|
||||
from fastchat.conversation import get_conv_template
|
||||
from dbgpt.model.adapter.fschat_adapter import FschatConversationAdapter
|
||||
|
||||
conv_template = get_conv_template(name)
|
||||
return FschatConversationAdapter(conv_template)
|
Reference in New Issue
Block a user