This commit is contained in:
Harrison Chase
2023-05-15 15:54:23 -07:00
parent e8ac6c5134
commit 829847decc
9 changed files with 96 additions and 50 deletions

View File

@@ -6,7 +6,6 @@ from langchain.schema import (
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
messages_to_dict,
)
@@ -64,7 +63,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
from botocore.exceptions import ClientError
messages = messages_to_dict(self.messages)
_message = _message_to_dict(message)
_message = message.dict()
messages.append(_message)
try:

View File

@@ -7,7 +7,6 @@ from langchain.schema import (
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
)
@@ -82,7 +81,7 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
self.collection.insert_one(
{
"SessionId": self.session_id,
"History": json.dumps(_message_to_dict(message)),
"History": json.dumps(message.dict()),
}
)
except errors.WriteError as err:

View File

@@ -7,7 +7,6 @@ from langchain.schema import (
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
)
@@ -68,9 +67,7 @@ class PostgresChatMessageHistory(BaseChatMessageHistory):
query = sql.SQL("INSERT INTO {} (session_id, message) VALUES (%s, %s);").format(
sql.Identifier(self.table_name)
)
self.cursor.execute(
query, (self.session_id, json.dumps(_message_to_dict(message)))
)
self.cursor.execute(query, (self.session_id, json.dumps(message.dict())))
self.connection.commit()
def clear(self) -> None:

View File

@@ -7,7 +7,6 @@ from langchain.schema import (
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
)
@@ -60,7 +59,7 @@ class RedisChatMessageHistory(BaseChatMessageHistory):
def append(self, message: BaseMessage) -> None:
"""Append the message to the record in Redis"""
self.redis_client.lpush(self.key, json.dumps(_message_to_dict(message)))
self.redis_client.lpush(self.key, json.dumps(message.dict()))
if self.ttl:
self.redis_client.expire(self.key, self.ttl)

View File

@@ -10,7 +10,6 @@ from langchain.schema import (
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
)
@@ -69,7 +68,7 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
def append(self, message: BaseMessage) -> None:
"""Append the message to the record in db"""
with self.Session() as session:
jsonstr = json.dumps(_message_to_dict(message))
jsonstr = json.dumps(message.dict())
session.add(self.Message(session_id=self.session_id, message=jsonstr))
session.commit()

View File

@@ -2,11 +2,11 @@
from __future__ import annotations
import json
import yaml
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, TypeVar, Union
import yaml
from pydantic import BaseModel, Field
from langchain.memory.buffer import get_buffer_string
@@ -87,6 +87,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
def input_variables(self) -> List[str]:
return self.prompt.input_variables
@property
@abstractmethod
def _type(self) -> str:
"""The type of MessagePromptTemplate."""
def dict(self, *args, **kwargs):
result = super().dict(*args, **kwargs)
result["_type"] = self._type
return result
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
role: str
@@ -97,14 +107,22 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
content=text, role=self.role, additional_kwargs=self.additional_kwargs
)
@property
def _type(self) -> str:
"""The type of MessagePromptTemplate."""
return "chat-message-prompt-template"
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
role: str = "human"
def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs)
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
@property
def _type(self) -> str:
"""The type of MessagePromptTemplate."""
return "human-message-prompt-template"
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
role: str = "ai"
@@ -113,6 +131,11 @@ class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
text = self.prompt.format(**kwargs)
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
@property
def _type(self) -> str:
"""The type of MessagePromptTemplate."""
return "ai-message-prompt-template"
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
role: str = "system"
@@ -121,6 +144,11 @@ class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
text = self.prompt.format(**kwargs)
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
@property
def _type(self) -> str:
"""The type of MessagePromptTemplate."""
return "system-message-prompt-template"
class ChatPromptValue(PromptValue):
messages: List[BaseMessage]
@@ -149,15 +177,7 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC):
class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
input_variables: List[str]
messages: List[
Union[
BaseMessagePromptTemplate,
BaseMessage,
AIMessagePromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
]
]
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
@@ -224,19 +244,17 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
return "chatPrompt"
return "chat_prompt"
def dict(self, **kwargs: Any) -> Dict:
prompt_dict = json.loads(self.json())
prompt_dict["_type"] = self._prompt_type
for i, message in enumerate(self.messages):
if isinstance(message, SystemMessagePromptTemplate):
prompt_dict["messages"][i]["prompt"]["role"] = "system"
elif isinstance(message, HumanMessagePromptTemplate):
prompt_dict["messages"][i]["prompt"]["role"] = "human"
elif isinstance(message, AIMessagePromptTemplate):
prompt_dict["messages"][i]["prompt"]["role"] = "ai"
prompt_dict = super().dict(**kwargs)
# for i, message in enumerate(self.messages):
# if isinstance(message, SystemMessagePromptTemplate):
# prompt_dict["messages"][i]["prompt"]["role"] = "system"
# elif isinstance(message, HumanMessagePromptTemplate):
# prompt_dict["messages"][i]["prompt"]["role"] = "human"
# elif isinstance(message, AIMessagePromptTemplate):
# prompt_dict["messages"][i]["prompt"]["role"] = "ai"
return prompt_dict

View File

@@ -17,6 +17,7 @@ from langchain.prompts.chat import (
)
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import message_from_dict
from langchain.utilities.loading import try_load_from_hub
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
@@ -129,14 +130,21 @@ def _load_chat_prompt(config: dict) -> ChatPromptTemplate:
messages = []
for message in config["messages"]:
role = message["prompt"].get("role")
template = message["prompt"]["template"]
if role == "human":
messages.append(HumanMessagePromptTemplate.from_template(template))
elif role == "ai":
messages.append(AIMessagePromptTemplate.from_template(template))
_type = message.pop("_type")
if _type == "human-message-prompt-template":
prompt = load_prompt_from_config(message.pop("prompt"))
_message = HumanMessagePromptTemplate(**{"prompt": prompt, **message})
elif _type == "ai-message-prompt-template":
prompt = load_prompt_from_config(message.pop("prompt"))
_message = AIMessagePromptTemplate(**{"prompt": prompt, **message})
elif _type == "system-message-prompt-template":
prompt = load_prompt_from_config(message.pop("prompt"))
_message = SystemMessagePromptTemplate(**{"prompt": prompt, **message})
elif _type == "base-message":
_message = message_from_dict(message)
else: # role == system
messages.append(SystemMessagePromptTemplate.from_template(template))
raise ValueError
messages.append(_message)
return ChatPromptTemplate.from_messages(messages)
@@ -185,7 +193,7 @@ def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
type_to_loader_dict = {
"prompt": _load_prompt,
"chatPrompt": _load_chat_prompt,
"chat_prompt": _load_chat_prompt,
"few_shot": _load_few_shot_prompt,
# "few_shot_with_templates": _load_few_shot_with_templates_prompt,
}

View File

@@ -75,6 +75,13 @@ class BaseMessage(BaseModel):
def type(self) -> str:
"""Type of the message, used for serialization."""
def dict(self, *args, **kwargs):
return {
"type": self.type,
"data": super().dict(*args, **kwargs),
"_type": "base-message",
}
class HumanMessage(BaseMessage):
"""Type of message that is spoken by the human."""
@@ -118,15 +125,11 @@ class ChatMessage(BaseMessage):
return "chat"
def _message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}
def messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
return [_message_to_dict(m) for m in messages]
return [m.dict() for m in messages]
def _message_from_dict(message: dict) -> BaseMessage:
def message_from_dict(message: dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
return HumanMessage(**message["data"])
@@ -141,7 +144,7 @@ def _message_from_dict(message: dict) -> BaseMessage:
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
return [_message_from_dict(m) for m in messages]
return [message_from_dict(m) for m in messages]
class ChatGeneration(Generation):

View File

@@ -5,9 +5,16 @@ from contextlib import contextmanager
from pathlib import Path
from typing import Iterator
from langchain.prompts.chat import (
AIMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.loading import load_prompt
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import AIMessage, HumanMessage, SystemMessage
@contextmanager
@@ -71,6 +78,23 @@ def test_saving_loading_round_trip(tmp_path: Path) -> None:
assert loaded_prompt == few_shot_prompt
def test_saving_chat_loading_round_trip(tmp_path: Path) -> None:
"""Test equality when saving and loading a chat prompt."""
message_list = [
[HumanMessage(content="hi"), HumanMessagePromptTemplate.from_template("{foo}")],
[AIMessage(content="hi"), AIMessagePromptTemplate.from_template("{foo}")],
[
SystemMessage(content="hi"),
SystemMessagePromptTemplate.from_template("{foo}"),
],
]
for messages in message_list:
simple_prompt = ChatPromptTemplate.from_messages(messages)
simple_prompt.save(file_path=tmp_path / "prompt.yaml")
loaded_prompt = load_prompt(tmp_path / "prompt.yaml")
assert loaded_prompt == simple_prompt
def test_loading_with_template_as_file() -> None:
"""Test loading when the template is a file."""
with change_directory():