DB-GPT/dbgpt/model/adapter/template.py
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

if TYPE_CHECKING:
    from fastchat.conversation import Conversation


class PromptType(str, Enum):
    """Prompt type."""

    FSCHAT: str = "fschat"
    DBGPT: str = "dbgpt"


class ConversationAdapter(ABC):
    """The conversation adapter."""

    @property
    def prompt_type(self) -> PromptType:
        return PromptType.FSCHAT

    @property
    @abstractmethod
    def roles(self) -> Tuple[str]:
        """Get the roles of the conversation.

        Returns:
            Tuple[str]: The roles of the conversation.
        """

    @property
    def sep(self) -> Optional[str]:
        """Get the separator between messages."""
        return "\n"

    @property
    def stop_str(self) -> Optional[Union[str, List[str]]]:
        """Get the stop criteria."""
        return None

    @property
    def stop_token_ids(self) -> Optional[List[int]]:
        """Stop generation when any token in this list is produced."""
        return None

    @abstractmethod
    def get_prompt(self) -> str:
        """Get the prompt string.

        Returns:
            str: The prompt string.
        """

    @abstractmethod
    def set_system_message(self, system_message: str) -> None:
        """Set the system message."""

    @abstractmethod
    def append_message(self, role: str, message: str) -> None:
        """Append a new message.

        Args:
            role (str): The role of the message.
            message (str): The message content.
        """

    @abstractmethod
    def update_last_message(self, message: str) -> None:
        """Update the last output.

        The last message is typically set to None when constructing the prompt,
        so we need to update it in place after getting the response from a model.

        Args:
            message (str): The message content.
        """

    @abstractmethod
    def copy(self) -> "ConversationAdapter":
        """Copy the conversation."""


class ConversationAdapterFactory(ABC):
    """The conversation adapter factory."""

    def get_by_name(
        self,
        template_name: str,
        prompt_template_type: Optional[PromptType] = PromptType.FSCHAT,
    ) -> ConversationAdapter:
        """Get a conversation adapter by name.

        Args:
            template_name (str): The name of the template.
            prompt_template_type (Optional[PromptType]): The type of the prompt
                template. Defaults to PromptType.FSCHAT.

        Returns:
            ConversationAdapter: The conversation adapter.
        """
        raise NotImplementedError()

    def get_by_model(self, model_name: str, model_path: str) -> ConversationAdapter:
        """Get a conversation adapter by model.

        Args:
            model_name (str): The name of the model.
            model_path (str): The path of the model.

        Returns:
            ConversationAdapter: The conversation adapter.
        """
        raise NotImplementedError()
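

# Illustrative sketch (an assumption, not DB-GPT's registry): a dict-backed
# factory that hands out copies of pre-registered adapters by template name.
# The class name and ``register`` helper are invented for demonstration;
# DB-GPT wires up its own factory implementation elsewhere.
class _DictConversationAdapterFactory(ConversationAdapterFactory):
    def __init__(self) -> None:
        # Maps template_name -> ConversationAdapter
        self._templates = {}

    def register(self, template_name: str, adapter: ConversationAdapter) -> None:
        self._templates[template_name] = adapter

    def get_by_name(
        self,
        template_name: str,
        prompt_template_type: Optional[PromptType] = PromptType.FSCHAT,
    ) -> ConversationAdapter:
        # Return a copy so callers can append messages without mutating the
        # registered template.
        return self._templates[template_name].copy()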


def get_conv_template(name: str) -> ConversationAdapter:
    """Get a conversation template by name.

    For now this just wraps the fastchat conversation template.
    TODO: More templates should be supported.

    Args:
        name (str): The name of the template.

    Returns:
        ConversationAdapter: The adapter wrapping the conversation template.
    """
    from fastchat.conversation import get_conv_template

    from dbgpt.model.adapter.fschat_adapter import FschatConversationAdapter

    conv_template = get_conv_template(name)
    return FschatConversationAdapter(conv_template)
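

# Illustrative usage sketch: assumes fastchat is installed and ships a
# conversation template named "vicuna_v1.1". Guarded so it only runs when the
# module is executed directly, not on import.
if __name__ == "__main__":
    conv = get_conv_template("vicuna_v1.1")
    conv.set_system_message("You are a helpful assistant.")
    conv.append_message(conv.roles[0], "Hello, who are you?")
    conv.append_message(conv.roles[1], None)  # left open for the model to fill
    print(conv.get_prompt())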