mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-25 13:06:53 +00:00
132 lines
3.5 KiB
Python
132 lines
3.5 KiB
Python
from abc import ABC, abstractmethod
|
|
from enum import Enum
|
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
|
|
|
if TYPE_CHECKING:
|
|
from fastchat.conversation import Conversation
|
|
|
|
|
|
class PromptType(str, Enum):
|
|
"""Prompt type."""
|
|
|
|
FSCHAT: str = "fschat"
|
|
DBGPT: str = "dbgpt"
|
|
|
|
|
|
class ConversationAdapter(ABC):
|
|
"""The conversation adapter."""
|
|
|
|
@property
|
|
def prompt_type(self) -> PromptType:
|
|
return PromptType.FSCHAT
|
|
|
|
@property
|
|
@abstractmethod
|
|
def roles(self) -> Tuple[str]:
|
|
"""Get the roles of the conversation.
|
|
|
|
Returns:
|
|
Tuple[str]: The roles of the conversation.
|
|
"""
|
|
|
|
@property
|
|
def sep(self) -> Optional[str]:
|
|
"""Get the separator between messages."""
|
|
return "\n"
|
|
|
|
@property
|
|
def stop_str(self) -> Optional[Union[str, List[str]]]:
|
|
"""Get the stop criteria."""
|
|
return None
|
|
|
|
@property
|
|
def stop_token_ids(self) -> Optional[List[int]]:
|
|
"""Stops generation if meeting any token in this list"""
|
|
return None
|
|
|
|
@abstractmethod
|
|
def get_prompt(self) -> str:
|
|
"""Get the prompt string.
|
|
|
|
Returns:
|
|
str: The prompt string.
|
|
"""
|
|
|
|
@abstractmethod
|
|
def set_system_message(self, system_message: str) -> None:
|
|
"""Set the system message."""
|
|
|
|
@abstractmethod
|
|
def append_message(self, role: str, message: str) -> None:
|
|
"""Append a new message.
|
|
Args:
|
|
role (str): The role of the message.
|
|
message (str): The message content.
|
|
"""
|
|
|
|
@abstractmethod
|
|
def update_last_message(self, message: str) -> None:
|
|
"""Update the last output.
|
|
|
|
The last message is typically set to be None when constructing the prompt,
|
|
so we need to update it in-place after getting the response from a model.
|
|
|
|
Args:
|
|
message (str): The message content.
|
|
"""
|
|
|
|
@abstractmethod
|
|
def copy(self) -> "ConversationAdapter":
|
|
"""Copy the conversation."""
|
|
|
|
|
|
class ConversationAdapterFactory(ABC):
|
|
"""The conversation adapter factory."""
|
|
|
|
def get_by_name(
|
|
self,
|
|
template_name: str,
|
|
prompt_template_type: Optional[PromptType] = PromptType.FSCHAT,
|
|
) -> ConversationAdapter:
|
|
"""Get a conversation adapter by name.
|
|
|
|
Args:
|
|
template_name (str): The name of the template.
|
|
prompt_template_type (Optional[PromptType]): The type of the prompt template, default to be FSCHAT.
|
|
|
|
Returns:
|
|
ConversationAdapter: The conversation adapter.
|
|
"""
|
|
raise NotImplementedError()
|
|
|
|
def get_by_model(self, model_name: str, model_path: str) -> ConversationAdapter:
|
|
"""Get a conversation adapter by model.
|
|
|
|
Args:
|
|
model_name (str): The name of the model.
|
|
model_path (str): The path of the model.
|
|
|
|
Returns:
|
|
ConversationAdapter: The conversation adapter.
|
|
"""
|
|
raise NotImplementedError()
|
|
|
|
|
|
def get_conv_template(name: str) -> ConversationAdapter:
|
|
"""Get a conversation template.
|
|
|
|
Args:
|
|
name (str): The name of the template.
|
|
|
|
Just return the fastchat conversation template for now.
|
|
# TODO: More templates should be supported.
|
|
Returns:
|
|
Conversation: The conversation template.
|
|
"""
|
|
from fastchat.conversation import get_conv_template
|
|
|
|
from dbgpt.model.adapter.fschat_adapter import FschatConversationAdapter
|
|
|
|
conv_template = get_conv_template(name)
|
|
return FschatConversationAdapter(conv_template)
|