mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-25 19:44:59 +00:00
多场景对架构一期
This commit is contained in:
51
pilot/prompts/base.py
Normal file
51
pilot/prompts/base.py
Normal file
@@ -0,0 +1,51 @@
|
||||
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from pilot.scene.base_message import BaseMessage,HumanMessage,AIMessage, SystemMessage
|
||||
|
||||
|
||||
def get_buffer_string(
|
||||
messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
|
||||
) -> str:
|
||||
"""Get buffer string of messages."""
|
||||
string_messages = []
|
||||
for m in messages:
|
||||
if isinstance(m, HumanMessage):
|
||||
role = human_prefix
|
||||
elif isinstance(m, AIMessage):
|
||||
role = ai_prefix
|
||||
elif isinstance(m, SystemMessage):
|
||||
role = "System"
|
||||
else:
|
||||
raise ValueError(f"Got unsupported message type: {m}")
|
||||
string_messages.append(f"{role}: {m.content}")
|
||||
return "\n".join(string_messages)
|
||||
|
||||
|
||||
|
||||
class PromptValue(BaseModel, ABC):
|
||||
@abstractmethod
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt as string."""
|
||||
|
||||
@abstractmethod
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as messages."""
|
||||
|
||||
class ChatPromptValue(PromptValue):
|
||||
messages: List[BaseMessage]
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Return prompt as string."""
|
||||
return get_buffer_string(self.messages)
|
||||
|
||||
def to_messages(self) -> List[BaseMessage]:
|
||||
"""Return prompt as messages."""
|
||||
return self.messages
|
Reference in New Issue
Block a user