# mirror of https://github.com/csunny/DB-GPT.git
# synced 2026-01-25 23:04:03 +00:00
from __future__ import annotations
|
|
|
|
from abc import ABC, abstractmethod
|
|
from typing import (
|
|
Any,
|
|
Dict,
|
|
Generic,
|
|
List,
|
|
NamedTuple,
|
|
Optional,
|
|
Sequence,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
from pydantic import BaseModel, Extra, Field, root_validator
|
|
|
|
|
|
class PromptValue(BaseModel, ABC):
    """Abstract container for a prompt that can be rendered two ways."""

    @abstractmethod
    def to_string(self) -> str:
        """Return prompt as string."""

    @abstractmethod
    def to_messages(self) -> List[BaseMessage]:
        """Return prompt as messages."""
class BaseMessage(BaseModel):
    """Message object."""

    # The textual payload of the message.
    content: str
    # Extra provider-specific data attached to the message; empty by default.
    additional_kwargs: dict = Field(default_factory=dict)

    @property
    @abstractmethod
    def type(self) -> str:
        """Type of the message, used for serialization."""
class HumanMessage(BaseMessage):
    """Type of message that is spoken by the human."""

    # NOTE(review): presumably flags few-shot example messages — confirm with callers.
    example: bool = False

    @property
    def type(self) -> str:
        """Serialization tag for human messages."""
        return "human"
class AIMessage(BaseMessage):
    """Type of message that is spoken by the AI."""

    # NOTE(review): presumably flags few-shot example messages — confirm with callers.
    example: bool = False

    @property
    def type(self) -> str:
        """Serialization tag for AI messages."""
        return "ai"
class ViewMessage(BaseMessage):
    """Type of message that represents a rendered view.

    The original docstring was copy-pasted from ``AIMessage`` ("spoken by
    the AI"); this class actually serializes with the distinct tag "view".
    """

    # NOTE(review): presumably flags few-shot example messages — confirm with callers.
    example: bool = False

    @property
    def type(self) -> str:
        """Serialization tag for view messages."""
        return "view"
class SystemMessage(BaseMessage):
    """Type of message that is a system message."""

    @property
    def type(self) -> str:
        """Serialization tag for system messages."""
        return "system"
class Generation(BaseModel):
    """Output of a single generation."""

    # Generated text output.
    text: str
    # Raw generation info response from the provider; may include things
    # like the reason for finishing (e.g. in OpenAI).
    generation_info: Optional[Dict[str, Any]] = None
class ChatGeneration(Generation):
    """A `Generation` backed by a chat message; `text` mirrors its content."""

    # Placeholder; overwritten from `message.content` by the validator below.
    text = ""
    # The underlying chat message that produced this generation.
    message: BaseMessage

    @root_validator
    def set_text(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Keep `text` in sync with the wrapped message's content."""
        values["text"] = values["message"].content
        return values
class ChatResult(BaseModel):
    """Class that contains all relevant information for a Chat Result."""

    # List of the things generated.
    generations: List[ChatGeneration]
    # For arbitrary LLM provider specific output.
    llm_output: Optional[dict] = None
class LLMResult(BaseModel):
    """Class that contains all relevant information for an LLM Result."""

    # List of the things generated. This is List[List[]] because each
    # input could have multiple generations.
    generations: List[List[Generation]]
    # For arbitrary LLM provider specific output.
    llm_output: Optional[dict] = None
def _message_to_dict(message: BaseMessage) -> dict:
    """Serialize a single message into a type-tagged plain dict."""
    payload = message.dict()
    return {"type": message.type, "data": payload}
def messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
    """Serialize a sequence of messages into type-tagged plain dicts."""
    return [_message_to_dict(message) for message in messages]
def _message_from_dict(message: dict) -> BaseMessage:
    """Reconstruct a concrete message object from a type-tagged dict.

    Raises:
        ValueError: if the "type" tag is not a known message kind.
    """
    # Dispatch table replaces the if/elif chain; same four known tags.
    constructors = {
        "human": HumanMessage,
        "ai": AIMessage,
        "system": SystemMessage,
        "view": ViewMessage,
    }
    _type = message["type"]
    message_cls = constructors.get(_type)
    if message_cls is None:
        raise ValueError(f"Got unexpected type: {_type}")
    return message_cls(**message["data"])
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
    """Reconstruct message objects from a sequence of type-tagged dicts."""
    return [_message_from_dict(message) for message in messages]