Compare commits

...

11 Commits

Author SHA1 Message Date
Dev 2049
c353ba0e53 clean 2023-05-22 17:27:56 -07:00
Dev 2049
6629e87c10 cr 2023-05-22 17:19:12 -07:00
Harrison Chase
bd04d408d5 Merge branch 'SalehHindi-serialize_chat_template' into harrison/serialize-chat 2023-05-15 15:55:03 -07:00
Harrison Chase
829847decc stash 2023-05-15 15:54:23 -07:00
Harrison Chase
e8ac6c5134 Merge branch 'serialize_chat_template' of github.com:SalehHindi/langchain into SalehHindi-serialize_chat_template 2023-05-15 15:23:26 -07:00
Saleh Hindi
0bde7b73d1 Merge branch 'master' into serialize_chat_template 2023-05-15 11:35:44 -04:00
Saleh Hindi
9de19bbe9d Add role to HumanMessagePromptTemplate AIMessagePromptTemplate SystemMessagePromptTemplate 2023-05-09 15:44:15 -04:00
Saleh Hindi
5811e0c6eb Run make lint 2023-05-08 21:49:12 -04:00
Saleh Hindi
dbd8588148 Reformat using make format 2023-05-08 21:44:00 -04:00
Saleh Hindi
97cb07b910 Merge branch 'master' of https://github.com/SalehHindi/langchain into serialize_chat_template 2023-05-08 21:34:46 -04:00
Saleh Hindi
8aa7e750be Save chatTemplates with roles attached to the messages. Load chatTemplate messages based on role. 2023-05-08 21:33:21 -04:00
15 changed files with 142 additions and 35 deletions

View File

@@ -7,7 +7,6 @@ from langchain.schema import (
BaseChatMessageHistory, BaseChatMessageHistory,
BaseMessage, BaseMessage,
HumanMessage, HumanMessage,
_message_to_dict,
messages_from_dict, messages_from_dict,
) )
@@ -160,7 +159,7 @@ class CassandraChatMessageHistory(BaseChatMessageHistory):
self.session.execute( self.session.execute(
"""INSERT INTO message_store """INSERT INTO message_store
(id, session_id, history) VALUES (%s, %s, %s);""", (id, session_id, history) VALUES (%s, %s, %s);""",
(uuid.uuid4(), self.session_id, json.dumps(_message_to_dict(message))), (uuid.uuid4(), self.session_id, json.dumps(message.dict())),
) )
except (Unavailable, WriteTimeout, WriteFailure) as error: except (Unavailable, WriteTimeout, WriteFailure) as error:
logger.error("Unable to write chat history messages to cassandra") logger.error("Unable to write chat history messages to cassandra")

View File

@@ -6,7 +6,6 @@ from langchain.schema import (
BaseChatMessageHistory, BaseChatMessageHistory,
BaseMessage, BaseMessage,
HumanMessage, HumanMessage,
_message_to_dict,
messages_from_dict, messages_from_dict,
messages_to_dict, messages_to_dict,
) )
@@ -64,7 +63,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
messages = messages_to_dict(self.messages) messages = messages_to_dict(self.messages)
_message = _message_to_dict(message) _message = message.dict()
messages.append(_message) messages.append(_message)
try: try:

View File

@@ -7,7 +7,6 @@ from langchain.schema import (
BaseChatMessageHistory, BaseChatMessageHistory,
BaseMessage, BaseMessage,
HumanMessage, HumanMessage,
_message_to_dict,
messages_from_dict, messages_from_dict,
) )
@@ -82,7 +81,7 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
self.collection.insert_one( self.collection.insert_one(
{ {
"SessionId": self.session_id, "SessionId": self.session_id,
"History": json.dumps(_message_to_dict(message)), "History": json.dumps(message.dict()),
} }
) )
except errors.WriteError as err: except errors.WriteError as err:

View File

@@ -7,7 +7,6 @@ from langchain.schema import (
BaseChatMessageHistory, BaseChatMessageHistory,
BaseMessage, BaseMessage,
HumanMessage, HumanMessage,
_message_to_dict,
messages_from_dict, messages_from_dict,
) )
@@ -68,9 +67,7 @@ class PostgresChatMessageHistory(BaseChatMessageHistory):
query = sql.SQL("INSERT INTO {} (session_id, message) VALUES (%s, %s);").format( query = sql.SQL("INSERT INTO {} (session_id, message) VALUES (%s, %s);").format(
sql.Identifier(self.table_name) sql.Identifier(self.table_name)
) )
self.cursor.execute( self.cursor.execute(query, (self.session_id, json.dumps(message.dict())))
query, (self.session_id, json.dumps(_message_to_dict(message)))
)
self.connection.commit() self.connection.commit()
def clear(self) -> None: def clear(self) -> None:

View File

@@ -7,7 +7,6 @@ from langchain.schema import (
BaseChatMessageHistory, BaseChatMessageHistory,
BaseMessage, BaseMessage,
HumanMessage, HumanMessage,
_message_to_dict,
messages_from_dict, messages_from_dict,
) )
@@ -60,7 +59,7 @@ class RedisChatMessageHistory(BaseChatMessageHistory):
def append(self, message: BaseMessage) -> None: def append(self, message: BaseMessage) -> None:
"""Append the message to the record in Redis""" """Append the message to the record in Redis"""
self.redis_client.lpush(self.key, json.dumps(_message_to_dict(message))) self.redis_client.lpush(self.key, json.dumps(message.dict()))
if self.ttl: if self.ttl:
self.redis_client.expire(self.key, self.ttl) self.redis_client.expire(self.key, self.ttl)

View File

@@ -11,7 +11,6 @@ from langchain.schema import (
BaseChatMessageHistory, BaseChatMessageHistory,
BaseMessage, BaseMessage,
HumanMessage, HumanMessage,
_message_to_dict,
messages_from_dict, messages_from_dict,
) )
@@ -70,7 +69,7 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
def append(self, message: BaseMessage) -> None: def append(self, message: BaseMessage) -> None:
"""Append the message to the record in db""" """Append the message to the record in db"""
with self.Session() as session: with self.Session() as session:
jsonstr = json.dumps(_message_to_dict(message)) jsonstr = json.dumps(message.dict())
session.add(self.Message(session_id=self.session_id, message=jsonstr)) session.add(self.Message(session_id=self.session_id, message=jsonstr))
session.commit() session.commit()

View File

@@ -1,10 +1,12 @@
"""Chat prompt template.""" """Chat prompt template."""
from __future__ import annotations from __future__ import annotations
import json
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, TypeVar, Union
import yaml
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from langchain.memory.buffer import get_buffer_string from langchain.memory.buffer import get_buffer_string
@@ -95,6 +97,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
def input_variables(self) -> List[str]: def input_variables(self) -> List[str]:
return self.prompt.input_variables return self.prompt.input_variables
@property
@abstractmethod
def _type(self) -> str:
"""The type of MessagePromptTemplate."""
def dict(self, *args: Any, **kwargs: Any) -> dict:
result = super().dict(*args, **kwargs)
result["_type"] = self._type
return result
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate): class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
role: str role: str
@@ -105,24 +117,48 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
content=text, role=self.role, additional_kwargs=self.additional_kwargs content=text, role=self.role, additional_kwargs=self.additional_kwargs
) )
@property
def _type(self) -> str:
"""The type of MessagePromptTemplate."""
return "chat-message-prompt-template"
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate): class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
def format(self, **kwargs: Any) -> BaseMessage: def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs) text = self.prompt.format(**kwargs)
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs) return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
@property
def _type(self) -> str:
"""The type of MessagePromptTemplate."""
return "human-message-prompt-template"
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate): class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
role: str = "ai"
def format(self, **kwargs: Any) -> BaseMessage: def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs) text = self.prompt.format(**kwargs)
return AIMessage(content=text, additional_kwargs=self.additional_kwargs) return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
@property
def _type(self) -> str:
"""The type of MessagePromptTemplate."""
return "ai-message-prompt-template"
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate): class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
role: str = "system"
def format(self, **kwargs: Any) -> BaseMessage: def format(self, **kwargs: Any) -> BaseMessage:
text = self.prompt.format(**kwargs) text = self.prompt.format(**kwargs)
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs) return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
@property
def _type(self) -> str:
"""The type of MessagePromptTemplate."""
return "system-message-prompt-template"
class ChatPromptValue(PromptValue): class ChatPromptValue(PromptValue):
messages: List[BaseMessage] messages: List[BaseMessage]
@@ -217,7 +253,26 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
@property @property
def _prompt_type(self) -> str: def _prompt_type(self) -> str:
raise NotImplementedError """Return the prompt type key."""
return "chat_prompt"
def save(self, file_path: Union[Path, str]) -> None: def save(self, file_path: Union[Path, str]) -> None:
raise NotImplementedError if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
prompt_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(prompt_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(prompt_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")

View File

@@ -9,8 +9,16 @@ import yaml
from langchain.output_parsers.regex import RegexParser from langchain.output_parsers.regex import RegexParser
from langchain.prompts.base import BasePromptTemplate from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
AIMessagePromptTemplate,
BaseMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseMessage, message_from_dict
from langchain.utilities.loading import try_load_from_hub from langchain.utilities.loading import try_load_from_hub
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/" URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
@@ -114,6 +122,35 @@ def _load_prompt(config: dict) -> PromptTemplate:
return PromptTemplate(**config) return PromptTemplate(**config)
def _load_chat_prompt(config: dict) -> ChatPromptTemplate:
"""Load the prompt template from config."""
# Load the template from disk if necessary.
config = _load_template("template", config)
config = _load_output_parser(config)
messages = []
for message in config["messages"]:
_type = message.pop("_type")
if _type == "human-message-prompt-template":
prompt = load_prompt_from_config(message.pop("prompt"))
_message: Union[
BaseMessagePromptTemplate, BaseMessage
] = HumanMessagePromptTemplate(**{"prompt": prompt, **message})
elif _type == "ai-message-prompt-template":
prompt = load_prompt_from_config(message.pop("prompt"))
_message = AIMessagePromptTemplate(**{"prompt": prompt, **message})
elif _type == "system-message-prompt-template":
prompt = load_prompt_from_config(message.pop("prompt"))
_message = SystemMessagePromptTemplate(**{"prompt": prompt, **message})
elif _type == "base-message":
_message = message_from_dict(message)
else: # role == system
raise ValueError
messages.append(_message)
return ChatPromptTemplate.from_messages(messages)
def load_prompt(path: Union[str, Path]) -> BasePromptTemplate: def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
"""Unified method for loading a prompt from LangChainHub or local fs.""" """Unified method for loading a prompt from LangChainHub or local fs."""
if hub_result := try_load_from_hub( if hub_result := try_load_from_hub(
@@ -158,6 +195,7 @@ def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
type_to_loader_dict = { type_to_loader_dict = {
"prompt": _load_prompt, "prompt": _load_prompt,
"chat_prompt": _load_chat_prompt,
"few_shot": _load_few_shot_prompt, "few_shot": _load_few_shot_prompt,
# "few_shot_with_templates": _load_few_shot_with_templates_prompt, # "few_shot_with_templates": _load_few_shot_with_templates_prompt,
} }

View File

@@ -75,6 +75,13 @@ class BaseMessage(BaseModel):
def type(self) -> str: def type(self) -> str:
"""Type of the message, used for serialization.""" """Type of the message, used for serialization."""
def dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
return {
"type": self.type,
"data": super().dict(*args, **kwargs),
"_type": "base-message",
}
class HumanMessage(BaseMessage): class HumanMessage(BaseMessage):
"""Type of message that is spoken by the human.""" """Type of message that is spoken by the human."""
@@ -118,15 +125,11 @@ class ChatMessage(BaseMessage):
return "chat" return "chat"
def _message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}
def messages_to_dict(messages: List[BaseMessage]) -> List[dict]: def messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
return [_message_to_dict(m) for m in messages] return [m.dict() for m in messages]
def _message_from_dict(message: dict) -> BaseMessage: def message_from_dict(message: dict) -> BaseMessage:
_type = message["type"] _type = message["type"]
if _type == "human": if _type == "human":
return HumanMessage(**message["data"]) return HumanMessage(**message["data"])
@@ -141,7 +144,7 @@ def _message_from_dict(message: dict) -> BaseMessage:
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]: def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
return [_message_from_dict(m) for m in messages] return [message_from_dict(m) for m in messages]
class ChatGeneration(Generation): class ChatGeneration(Generation):

View File

@@ -5,7 +5,6 @@ from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories.cassandra import ( from langchain.memory.chat_message_histories.cassandra import (
CassandraChatMessageHistory, CassandraChatMessageHistory,
) )
from langchain.schema import _message_to_dict
# Replace these with your cassandra contact points # Replace these with your cassandra contact points
contact_points = ( contact_points = (
@@ -31,7 +30,7 @@ def test_memory_with_message_store() -> None:
# get the message history from the memory store and turn it into a json # get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages]) messages_json = json.dumps([msg.dict() for msg in messages])
assert "This is me, the AI" in messages_json assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json assert "This is me, the human" in messages_json

View File

@@ -3,7 +3,6 @@ import os
from langchain.memory import ConversationBufferMemory from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import CosmosDBChatMessageHistory from langchain.memory.chat_message_histories import CosmosDBChatMessageHistory
from langchain.schema import _message_to_dict
# Replace these with your Azure Cosmos DB endpoint and key # Replace these with your Azure Cosmos DB endpoint and key
endpoint = os.environ["COSMOS_DB_ENDPOINT"] endpoint = os.environ["COSMOS_DB_ENDPOINT"]
@@ -33,7 +32,7 @@ def test_memory_with_message_store() -> None:
# get the message history from the memory store and turn it into a json # get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages]) messages_json = json.dumps([msg.dict() for msg in messages])
assert "This is me, the AI" in messages_json assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json assert "This is me, the human" in messages_json

View File

@@ -2,7 +2,6 @@ import json
from langchain.memory import ConversationBufferMemory from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import FirestoreChatMessageHistory from langchain.memory.chat_message_histories import FirestoreChatMessageHistory
from langchain.schema import _message_to_dict
def test_memory_with_message_store() -> None: def test_memory_with_message_store() -> None:
@@ -32,7 +31,7 @@ def test_memory_with_message_store() -> None:
memory_key="baz", chat_memory=message_history, return_messages=True memory_key="baz", chat_memory=message_history, return_messages=True
) )
messages = memory.chat_memory.messages messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages]) messages_json = json.dumps([msg.dict() for msg in messages])
assert "This is me, the AI" in messages_json assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json assert "This is me, the human" in messages_json

View File

@@ -3,7 +3,6 @@ import os
from langchain.memory import ConversationBufferMemory from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import MongoDBChatMessageHistory from langchain.memory.chat_message_histories import MongoDBChatMessageHistory
from langchain.schema import _message_to_dict
# Replace these with your mongodb connection string # Replace these with your mongodb connection string
connection_string = os.environ["MONGODB_CONNECTION_STRING"] connection_string = os.environ["MONGODB_CONNECTION_STRING"]
@@ -25,7 +24,7 @@ def test_memory_with_message_store() -> None:
# get the message history from the memory store and turn it into a json # get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages]) messages_json = json.dumps([msg.dict() for msg in messages])
assert "This is me, the AI" in messages_json assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json assert "This is me, the human" in messages_json

View File

@@ -2,7 +2,6 @@ import json
from langchain.memory import ConversationBufferMemory from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import RedisChatMessageHistory from langchain.memory.chat_message_histories import RedisChatMessageHistory
from langchain.schema import _message_to_dict
def test_memory_with_message_store() -> None: def test_memory_with_message_store() -> None:
@@ -21,7 +20,7 @@ def test_memory_with_message_store() -> None:
# get the message history from the memory store and turn it into a json # get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages]) messages_json = json.dumps([msg.dict() for msg in messages])
assert "This is me, the AI" in messages_json assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json assert "This is me, the human" in messages_json

View File

@@ -3,12 +3,19 @@
import os import os
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Iterator from typing import Iterator, List, Sequence
from langchain.output_parsers import RegexParser from langchain.output_parsers import RegexParser
from langchain.prompts.chat import (
AIMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.loading import load_prompt from langchain.prompts.loading import load_prompt
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
from langchain.schema import AIMessage, HumanMessage, SystemMessage
@contextmanager @contextmanager
@@ -72,6 +79,23 @@ def test_saving_loading_round_trip(tmp_path: Path) -> None:
assert loaded_prompt == few_shot_prompt assert loaded_prompt == few_shot_prompt
def test_saving_chat_loading_round_trip(tmp_path: Path) -> None:
"""Test equality when saving and loading a chat prompt."""
message_list: List[Sequence] = [
[HumanMessage(content="hi"), HumanMessagePromptTemplate.from_template("{foo}")],
[AIMessage(content="hi"), AIMessagePromptTemplate.from_template("{foo}")],
[
SystemMessage(content="hi"),
SystemMessagePromptTemplate.from_template("{foo}"),
],
]
for messages in message_list:
simple_prompt = ChatPromptTemplate.from_messages(messages)
simple_prompt.save(file_path=tmp_path / "prompt.yaml")
loaded_prompt = load_prompt(tmp_path / "prompt.yaml")
assert loaded_prompt == simple_prompt
def test_loading_with_template_as_file() -> None: def test_loading_with_template_as_file() -> None:
"""Test loading when the template is a file.""" """Test loading when the template is a file."""
with change_directory(): with change_directory():