# Files
# DB-GPT/pilot/prompts/base.py
# 2023-05-25 20:39:04 +08:00

# 50 lines
# 1.4 KiB
# Python

import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
import yaml
from pydantic import BaseModel, Extra, Field, root_validator
from pilot.scene.base_message import BaseMessage, HumanMessage, AIMessage, SystemMessage
def get_buffer_string(
    messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
) -> str:
    """Render a message history as newline-separated "<role>: <content>" lines.

    Args:
        messages: Messages to render, in order.
        human_prefix: Role label used for ``HumanMessage`` entries.
        ai_prefix: Role label used for ``AIMessage`` entries.

    Returns:
        One line per message, joined with newlines; an empty string when
        ``messages`` is empty.

    Raises:
        ValueError: If a message is not a ``HumanMessage``, ``AIMessage`` or
            ``SystemMessage``.
    """
    # Checked in this order, so the first matching type decides the label.
    role_labels = (
        (HumanMessage, human_prefix),
        (AIMessage, ai_prefix),
        (SystemMessage, "System"),
    )

    def label_for(message: BaseMessage) -> str:
        # Resolve the role label for one message, rejecting unknown types.
        for message_type, label in role_labels:
            if isinstance(message, message_type):
                return label
        raise ValueError(f"Got unsupported message type: {message}")

    return "\n".join(
        f"{label_for(message)}: {message.content}" for message in messages
    )
class PromptValue(BaseModel, ABC):
    # Abstract interface for a prompt that can be rendered either as one
    # string or as a list of messages. Subclasses must implement both methods.

    @abstractmethod
    def to_string(self) -> str:
        """Render this prompt as a single string."""

    @abstractmethod
    def to_messages(self) -> List[BaseMessage]:
        """Render this prompt as a list of messages."""
class ChatPromptValue(PromptValue):
    # Prompt value backed directly by a list of chat messages.

    messages: List[BaseMessage]

    def to_string(self) -> str:
        """Flatten the stored messages into one string.

        Delegates to ``get_buffer_string`` with its default role prefixes.
        """
        rendered = get_buffer_string(self.messages)
        return rendered

    def to_messages(self) -> List[BaseMessage]:
        """Return the stored message list as-is (no copy is made)."""
        history = self.messages
        return history