Compare commits

...

11 Commits

Author SHA1 Message Date
Dev 2049
c353ba0e53 clean 2023-05-22 17:27:56 -07:00
Dev 2049
6629e87c10 cr 2023-05-22 17:19:12 -07:00
Harrison Chase
bd04d408d5 Merge branch 'SalehHindi-serialize_chat_template' into harrison/serialize-chat 2023-05-15 15:55:03 -07:00
Harrison Chase
829847decc stash 2023-05-15 15:54:23 -07:00
Harrison Chase
e8ac6c5134 Merge branch 'serialize_chat_template' of github.com:SalehHindi/langchain into SalehHindi-serialize_chat_template 2023-05-15 15:23:26 -07:00
Saleh Hindi
0bde7b73d1 Merge branch 'master' into serialize_chat_template 2023-05-15 11:35:44 -04:00
Saleh Hindi
9de19bbe9d Add role to HumanMessagePromptTemplate AIMessagePromptTemplate SystemMessagePromptTemplate 2023-05-09 15:44:15 -04:00
Saleh Hindi
5811e0c6eb Run make lint 2023-05-08 21:49:12 -04:00
Saleh Hindi
dbd8588148 Reformat using make format 2023-05-08 21:44:00 -04:00
Saleh Hindi
97cb07b910 Merge branch 'master' of https://github.com/SalehHindi/langchain into serialize_chat_template 2023-05-08 21:34:46 -04:00
Saleh Hindi
8aa7e750be Save chatTemplates with roles attached to the messages. Load chatTemplate messages based on role. 2023-05-08 21:33:21 -04:00
15 changed files with 142 additions and 35 deletions

View File

@@ -7,7 +7,6 @@ from langchain.schema import (
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
)
@@ -160,7 +159,7 @@ class CassandraChatMessageHistory(BaseChatMessageHistory):
self.session.execute(
"""INSERT INTO message_store
(id, session_id, history) VALUES (%s, %s, %s);""",
(uuid.uuid4(), self.session_id, json.dumps(_message_to_dict(message))),
(uuid.uuid4(), self.session_id, json.dumps(message.dict())),
)
except (Unavailable, WriteTimeout, WriteFailure) as error:
logger.error("Unable to write chat history messages to cassandra")

View File

@@ -6,7 +6,6 @@ from langchain.schema import (
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
messages_to_dict,
)
@@ -64,7 +63,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
from botocore.exceptions import ClientError
messages = messages_to_dict(self.messages)
_message = _message_to_dict(message)
_message = message.dict()
messages.append(_message)
try:

View File

@@ -7,7 +7,6 @@ from langchain.schema import (
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
)
@@ -82,7 +81,7 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
self.collection.insert_one(
{
"SessionId": self.session_id,
"History": json.dumps(_message_to_dict(message)),
"History": json.dumps(message.dict()),
}
)
except errors.WriteError as err:

View File

@@ -7,7 +7,6 @@ from langchain.schema import (
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
)
@@ -68,9 +67,7 @@ class PostgresChatMessageHistory(BaseChatMessageHistory):
query = sql.SQL("INSERT INTO {} (session_id, message) VALUES (%s, %s);").format(
sql.Identifier(self.table_name)
)
self.cursor.execute(
query, (self.session_id, json.dumps(_message_to_dict(message)))
)
self.cursor.execute(query, (self.session_id, json.dumps(message.dict())))
self.connection.commit()
def clear(self) -> None:

View File

@@ -7,7 +7,6 @@ from langchain.schema import (
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
)
@@ -60,7 +59,7 @@ class RedisChatMessageHistory(BaseChatMessageHistory):
def append(self, message: BaseMessage) -> None:
"""Append the message to the record in Redis"""
self.redis_client.lpush(self.key, json.dumps(_message_to_dict(message)))
self.redis_client.lpush(self.key, json.dumps(message.dict()))
if self.ttl:
self.redis_client.expire(self.key, self.ttl)

View File

@@ -11,7 +11,6 @@ from langchain.schema import (
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
)
@@ -70,7 +69,7 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
def append(self, message: BaseMessage) -> None:
"""Append the message to the record in db"""
with self.Session() as session:
jsonstr = json.dumps(_message_to_dict(message))
jsonstr = json.dumps(message.dict())
session.add(self.Message(session_id=self.session_id, message=jsonstr))
session.commit()

View File

@@ -1,10 +1,12 @@
"""Chat prompt template."""
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, TypeVar, Union
import yaml
from pydantic import BaseModel, Field
from langchain.memory.buffer import get_buffer_string
@@ -95,6 +97,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
def input_variables(self) -> List[str]:
return self.prompt.input_variables
@property
@abstractmethod
def _type(self) -> str:
"""The type of MessagePromptTemplate."""
def dict(self, *args: Any, **kwargs: Any) -> dict:
result = super().dict(*args, **kwargs)
result["_type"] = self._type
return result
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
role: str
@@ -105,24 +117,48 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
content=text, role=self.role, additional_kwargs=self.additional_kwargs
)
@property
def _type(self) -> str:
"""The type of MessagePromptTemplate."""
return "chat-message-prompt-template"
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs)
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
@property
def _type(self) -> str:
"""The type of MessagePromptTemplate."""
return "human-message-prompt-template"
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
role: str = "ai"
def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs)
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
@property
def _type(self) -> str:
"""The type of MessagePromptTemplate."""
return "ai-message-prompt-template"
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
role: str = "system"
def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs)
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
@property
def _type(self) -> str:
"""The type of MessagePromptTemplate."""
return "system-message-prompt-template"
class ChatPromptValue(PromptValue):
messages: List[BaseMessage]
@@ -217,7 +253,26 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
@property
def _prompt_type(self) -> str:
raise NotImplementedError
"""Return the prompt type key."""
return "chat_prompt"
def save(self, file_path: Union[Path, str]) -> None:
raise NotImplementedError
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
prompt_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(prompt_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(prompt_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")

View File

@@ -9,8 +9,16 @@ import yaml
from langchain.output_parsers.regex import RegexParser
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
AIMessagePromptTemplate,
BaseMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseMessage, message_from_dict
from langchain.utilities.loading import try_load_from_hub
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
@@ -114,6 +122,35 @@ def _load_prompt(config: dict) -> PromptTemplate:
return PromptTemplate(**config)
def _load_chat_prompt(config: dict) -> ChatPromptTemplate:
"""Load the prompt template from config."""
# Load the template from disk if necessary.
config = _load_template("template", config)
config = _load_output_parser(config)
messages = []
for message in config["messages"]:
_type = message.pop("_type")
if _type == "human-message-prompt-template":
prompt = load_prompt_from_config(message.pop("prompt"))
_message: Union[
BaseMessagePromptTemplate, BaseMessage
] = HumanMessagePromptTemplate(**{"prompt": prompt, **message})
elif _type == "ai-message-prompt-template":
prompt = load_prompt_from_config(message.pop("prompt"))
_message = AIMessagePromptTemplate(**{"prompt": prompt, **message})
elif _type == "system-message-prompt-template":
prompt = load_prompt_from_config(message.pop("prompt"))
_message = SystemMessagePromptTemplate(**{"prompt": prompt, **message})
elif _type == "base-message":
_message = message_from_dict(message)
else: # role == system
raise ValueError
messages.append(_message)
return ChatPromptTemplate.from_messages(messages)
def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
"""Unified method for loading a prompt from LangChainHub or local fs."""
if hub_result := try_load_from_hub(
@@ -158,6 +195,7 @@ def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
type_to_loader_dict = {
"prompt": _load_prompt,
"chat_prompt": _load_chat_prompt,
"few_shot": _load_few_shot_prompt,
# "few_shot_with_templates": _load_few_shot_with_templates_prompt,
}

View File

@@ -75,6 +75,13 @@ class BaseMessage(BaseModel):
def type(self) -> str:
"""Type of the message, used for serialization."""
def dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
return {
"type": self.type,
"data": super().dict(*args, **kwargs),
"_type": "base-message",
}
class HumanMessage(BaseMessage):
"""Type of message that is spoken by the human."""
@@ -118,15 +125,11 @@ class ChatMessage(BaseMessage):
return "chat"
def _message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}
def messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
return [_message_to_dict(m) for m in messages]
return [m.dict() for m in messages]
def _message_from_dict(message: dict) -> BaseMessage:
def message_from_dict(message: dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
return HumanMessage(**message["data"])
@@ -141,7 +144,7 @@ def _message_from_dict(message: dict) -> BaseMessage:
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
return [_message_from_dict(m) for m in messages]
return [message_from_dict(m) for m in messages]
class ChatGeneration(Generation):

View File

@@ -5,7 +5,6 @@ from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories.cassandra import (
CassandraChatMessageHistory,
)
from langchain.schema import _message_to_dict
# Replace these with your cassandra contact points
contact_points = (
@@ -31,7 +30,7 @@ def test_memory_with_message_store() -> None:
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
messages_json = json.dumps([msg.dict() for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json

View File

@@ -3,7 +3,6 @@ import os
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import CosmosDBChatMessageHistory
from langchain.schema import _message_to_dict
# Replace these with your Azure Cosmos DB endpoint and key
endpoint = os.environ["COSMOS_DB_ENDPOINT"]
@@ -33,7 +32,7 @@ def test_memory_with_message_store() -> None:
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
messages_json = json.dumps([msg.dict() for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json

View File

@@ -2,7 +2,6 @@ import json
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import FirestoreChatMessageHistory
from langchain.schema import _message_to_dict
def test_memory_with_message_store() -> None:
@@ -32,7 +31,7 @@ def test_memory_with_message_store() -> None:
memory_key="baz", chat_memory=message_history, return_messages=True
)
messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
messages_json = json.dumps([msg.dict() for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json

View File

@@ -3,7 +3,6 @@ import os
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import MongoDBChatMessageHistory
from langchain.schema import _message_to_dict
# Replace these with your mongodb connection string
connection_string = os.environ["MONGODB_CONNECTION_STRING"]
@@ -25,7 +24,7 @@ def test_memory_with_message_store() -> None:
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
messages_json = json.dumps([msg.dict() for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json

View File

@@ -2,7 +2,6 @@ import json
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import RedisChatMessageHistory
from langchain.schema import _message_to_dict
def test_memory_with_message_store() -> None:
@@ -21,7 +20,7 @@ def test_memory_with_message_store() -> None:
# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
messages_json = json.dumps([msg.dict() for msg in messages])
assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json

View File

@@ -3,12 +3,19 @@
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator
from typing import Iterator, List, Sequence
from langchain.output_parsers import RegexParser
from langchain.prompts.chat import (
AIMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.loading import load_prompt
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import AIMessage, HumanMessage, SystemMessage
@contextmanager
@@ -72,6 +79,23 @@ def test_saving_loading_round_trip(tmp_path: Path) -> None:
assert loaded_prompt == few_shot_prompt
def test_saving_chat_loading_round_trip(tmp_path: Path) -> None:
"""Test equality when saving and loading a chat prompt."""
message_list: List[Sequence] = [
[HumanMessage(content="hi"), HumanMessagePromptTemplate.from_template("{foo}")],
[AIMessage(content="hi"), AIMessagePromptTemplate.from_template("{foo}")],
[
SystemMessage(content="hi"),
SystemMessagePromptTemplate.from_template("{foo}"),
],
]
for messages in message_list:
simple_prompt = ChatPromptTemplate.from_messages(messages)
simple_prompt.save(file_path=tmp_path / "prompt.yaml")
loaded_prompt = load_prompt(tmp_path / "prompt.yaml")
assert loaded_prompt == simple_prompt
def test_loading_with_template_as_file() -> None:
"""Test loading when the template is a file."""
with change_directory():