mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-21 20:13:39 +00:00
save messages (#1653)
@yakigac this is my alternative to https://github.com/hwchase17/langchain/pull/1648 - thoughts?
This commit is contained in:
parent
63aa28e2a6
commit
362586fe8b
@ -40,24 +40,75 @@ class BaseMessage(BaseModel):
|
|||||||
content: str
|
content: str
|
||||||
additional_kwargs: dict = Field(default_factory=dict)
|
additional_kwargs: dict = Field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def type(self) -> str:
|
||||||
|
"""Type of the message, used for serialization."""
|
||||||
|
|
||||||
|
|
||||||
class HumanMessage(BaseMessage):
|
class HumanMessage(BaseMessage):
|
||||||
"""Type of message that is spoken by the human."""
|
"""Type of message that is spoken by the human."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self) -> str:
|
||||||
|
"""Type of the message, used for serialization."""
|
||||||
|
return "human"
|
||||||
|
|
||||||
|
|
||||||
class AIMessage(BaseMessage):
|
class AIMessage(BaseMessage):
|
||||||
"""Type of message that is spoken by the AI."""
|
"""Type of message that is spoken by the AI."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self) -> str:
|
||||||
|
"""Type of the message, used for serialization."""
|
||||||
|
return "ai"
|
||||||
|
|
||||||
|
|
||||||
class SystemMessage(BaseMessage):
|
class SystemMessage(BaseMessage):
|
||||||
"""Type of message that is a system message."""
|
"""Type of message that is a system message."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self) -> str:
|
||||||
|
"""Type of the message, used for serialization."""
|
||||||
|
return "system"
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseMessage):
|
class ChatMessage(BaseMessage):
|
||||||
"""Type of message with arbitrary speaker."""
|
"""Type of message with arbitrary speaker."""
|
||||||
|
|
||||||
role: str
|
role: str
|
||||||
|
|
||||||
|
@property
|
||||||
|
def type(self) -> str:
|
||||||
|
"""Type of the message, used for serialization."""
|
||||||
|
return "chat"
|
||||||
|
|
||||||
|
|
||||||
|
def _message_to_json(message: BaseMessage) -> dict:
|
||||||
|
return {"type": message.type, "data": message.dict()}
|
||||||
|
|
||||||
|
|
||||||
|
def messages_to_json(messages: List[BaseMessage]) -> List[dict]:
|
||||||
|
return [_message_to_json(m) for m in messages]
|
||||||
|
|
||||||
|
|
||||||
|
def _message_from_json(message: dict) -> BaseMessage:
|
||||||
|
_type = message["type"]
|
||||||
|
if _type == "human":
|
||||||
|
return HumanMessage(**message["data"])
|
||||||
|
elif _type == "ai":
|
||||||
|
return AIMessage(**message["data"])
|
||||||
|
elif _type == "system":
|
||||||
|
return SystemMessage(**message["data"])
|
||||||
|
elif _type == "chat":
|
||||||
|
return ChatMessage(**message["data"])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Got unexpected type: {_type}")
|
||||||
|
|
||||||
|
|
||||||
|
def messages_from_json(messages: List[dict]) -> List[BaseMessage]:
|
||||||
|
return [_message_from_json(m) for m in messages]
|
||||||
|
|
||||||
|
|
||||||
class ChatGeneration(Generation):
|
class ChatGeneration(Generation):
|
||||||
"""Output of a single generation."""
|
"""Output of a single generation."""
|
||||||
|
Loading…
Reference in New Issue
Block a user