mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-25 19:44:59 +00:00
50 lines
1.4 KiB
Python
50 lines
1.4 KiB
Python
import json
|
|
from abc import ABC, abstractmethod
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
|
|
|
import yaml
|
|
from pydantic import BaseModel, Extra, Field, root_validator
|
|
|
|
from pilot.scene.base_message import BaseMessage, HumanMessage, AIMessage, SystemMessage
|
|
|
|
|
|
def get_buffer_string(
    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
    """Render a list of messages as newline-separated ``"<role>: <content>"`` lines.

    Args:
        messages: Messages to render, in the order they should appear.
        human_prefix: Role label used for ``HumanMessage`` instances.
        ai_prefix: Role label used for ``AIMessage`` instances.

    Returns:
        One line per message joined with ``"\\n"``; an empty string when
        ``messages`` is empty. ``SystemMessage`` instances are labelled
        ``"System"``.

    Raises:
        ValueError: If a message is not a ``HumanMessage``, ``AIMessage`` or
            ``SystemMessage``.
    """
    # Checked top to bottom; the first class a message is an instance of wins.
    role_table = (
        (HumanMessage, human_prefix),
        (AIMessage, ai_prefix),
        (SystemMessage, "System"),
    )

    rendered_lines = []
    for message in messages:
        for message_type, label in role_table:
            if isinstance(message, message_type):
                break
        else:
            # No entry in the table matched this message.
            raise ValueError(f"Got unsupported message type: {message}")
        rendered_lines.append(f"{label}: {message.content}")

    return "\n".join(rendered_lines)
|
|
|
|
|
|
class PromptValue(BaseModel, ABC):
    # Abstract interface for a prompt value: subclasses must be able to
    # present themselves both as a single string and as a list of messages.

    @abstractmethod
    def to_string(self) -> str:
        """Return this prompt rendered as a single string."""

    @abstractmethod
    def to_messages(self) -> List[BaseMessage]:
        """Return this prompt as a list of ``BaseMessage`` objects."""
|
|
|
|
|
|
class ChatPromptValue(PromptValue):
    # Prompt value backed directly by a list of chat messages.

    # The messages that make up this prompt.
    messages: List[BaseMessage]

    def to_string(self) -> str:
        """Return the messages rendered by ``get_buffer_string`` with its default prefixes."""
        rendered = get_buffer_string(self.messages)
        return rendered

    def to_messages(self) -> List[BaseMessage]:
        """Return the stored ``messages`` list itself (not a copy)."""
        return self.messages
|