From 829847decc31fc402ff55d5bb6f79602e43e18bf Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 15 May 2023 15:54:23 -0700 Subject: [PATCH] stash --- .../memory/chat_message_histories/dynamodb.py | 3 +- .../memory/chat_message_histories/mongodb.py | 3 +- .../memory/chat_message_histories/postgres.py | 5 +- .../memory/chat_message_histories/redis.py | 3 +- .../memory/chat_message_histories/sql.py | 3 +- langchain/prompts/chat.py | 64 ++++++++++++------- langchain/prompts/loading.py | 24 ++++--- langchain/schema.py | 17 +++-- tests/unit_tests/prompts/test_loading.py | 24 +++++++ 9 files changed, 96 insertions(+), 50 deletions(-) diff --git a/langchain/memory/chat_message_histories/dynamodb.py b/langchain/memory/chat_message_histories/dynamodb.py index 413183eac0d..dc5b7c57df2 100644 --- a/langchain/memory/chat_message_histories/dynamodb.py +++ b/langchain/memory/chat_message_histories/dynamodb.py @@ -6,7 +6,6 @@ from langchain.schema import ( BaseChatMessageHistory, BaseMessage, HumanMessage, - _message_to_dict, messages_from_dict, messages_to_dict, ) @@ -64,7 +63,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): from botocore.exceptions import ClientError messages = messages_to_dict(self.messages) - _message = _message_to_dict(message) + _message = message.dict() messages.append(_message) try: diff --git a/langchain/memory/chat_message_histories/mongodb.py b/langchain/memory/chat_message_histories/mongodb.py index 7995609b4d0..268a3fd36ac 100644 --- a/langchain/memory/chat_message_histories/mongodb.py +++ b/langchain/memory/chat_message_histories/mongodb.py @@ -7,7 +7,6 @@ from langchain.schema import ( BaseChatMessageHistory, BaseMessage, HumanMessage, - _message_to_dict, messages_from_dict, ) @@ -82,7 +81,7 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory): self.collection.insert_one( { "SessionId": self.session_id, - "History": json.dumps(_message_to_dict(message)), + "History": json.dumps(message.dict()), } ) except errors.WriteError as err: diff --git a/langchain/memory/chat_message_histories/postgres.py b/langchain/memory/chat_message_histories/postgres.py index ddca84443ce..44699bdfee8 100644 --- a/langchain/memory/chat_message_histories/postgres.py +++ b/langchain/memory/chat_message_histories/postgres.py @@ -7,7 +7,6 @@ from langchain.schema import ( BaseChatMessageHistory, BaseMessage, HumanMessage, - _message_to_dict, messages_from_dict, ) @@ -68,9 +67,7 @@ class PostgresChatMessageHistory(BaseChatMessageHistory): query = sql.SQL("INSERT INTO {} (session_id, message) VALUES (%s, %s);").format( sql.Identifier(self.table_name) ) - self.cursor.execute( - query, (self.session_id, json.dumps(_message_to_dict(message))) - ) + self.cursor.execute(query, (self.session_id, json.dumps(message.dict()))) self.connection.commit() def clear(self) -> None: diff --git a/langchain/memory/chat_message_histories/redis.py b/langchain/memory/chat_message_histories/redis.py index dad0c303633..9e148b3b4e6 100644 --- a/langchain/memory/chat_message_histories/redis.py +++ b/langchain/memory/chat_message_histories/redis.py @@ -7,7 +7,6 @@ from langchain.schema import ( BaseChatMessageHistory, BaseMessage, HumanMessage, - _message_to_dict, messages_from_dict, ) @@ -60,7 +59,7 @@ class RedisChatMessageHistory(BaseChatMessageHistory): def append(self, message: BaseMessage) -> None: """Append the message to the record in Redis""" - self.redis_client.lpush(self.key, json.dumps(_message_to_dict(message))) + self.redis_client.lpush(self.key, json.dumps(message.dict())) if self.ttl: self.redis_client.expire(self.key, self.ttl) diff --git a/langchain/memory/chat_message_histories/sql.py b/langchain/memory/chat_message_histories/sql.py index e3770133b2e..54a53e4a9c0 100644 --- a/langchain/memory/chat_message_histories/sql.py +++ b/langchain/memory/chat_message_histories/sql.py @@ -10,7 +10,6 @@ from langchain.schema import ( BaseChatMessageHistory, BaseMessage, HumanMessage, - _message_to_dict, messages_from_dict, ) @@ -69,7 +68,7 @@ class SQLChatMessageHistory(BaseChatMessageHistory): def append(self, message: BaseMessage) -> None: """Append the message to the record in db""" with self.Session() as session: - jsonstr = json.dumps(_message_to_dict(message)) + jsonstr = json.dumps(message.dict()) session.add(self.Message(session_id=self.session_id, message=jsonstr)) session.commit() diff --git a/langchain/prompts/chat.py b/langchain/prompts/chat.py index ca665866d51..cefdeb7762d 100644 --- a/langchain/prompts/chat.py +++ b/langchain/prompts/chat.py @@ -2,11 +2,11 @@ from __future__ import annotations import json -import yaml from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, TypeVar, Union +import yaml from pydantic import BaseModel, Field from langchain.memory.buffer import get_buffer_string @@ -87,6 +87,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC): def input_variables(self) -> List[str]: return self.prompt.input_variables + @property + @abstractmethod + def _type(self) -> str: + """The type of MessagePromptTemplate.""" + + def dict(self, *args, **kwargs): + result = super().dict(*args, **kwargs) + result["_type"] = self._type + return result + class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate): role: str @@ -97,14 +107,22 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate): content=text, role=self.role, additional_kwargs=self.additional_kwargs ) + @property + def _type(self) -> str: + """The type of MessagePromptTemplate.""" + return "chat-message-prompt-template" + class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate): - role: str = "human" - def format(self, **kwargs: Any) -> BaseMessage: text = self.prompt.format(**kwargs) return HumanMessage(content=text, additional_kwargs=self.additional_kwargs) + @property + def _type(self) -> str: + """The type of MessagePromptTemplate.""" + return "human-message-prompt-template" + class AIMessagePromptTemplate(BaseStringMessagePromptTemplate): role: str = "ai" @@ -113,6 +131,11 @@ class AIMessagePromptTemplate(BaseStringMessagePromptTemplate): text = self.prompt.format(**kwargs) return AIMessage(content=text, additional_kwargs=self.additional_kwargs) + @property + def _type(self) -> str: + """The type of MessagePromptTemplate.""" + return "ai-message-prompt-template" + class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate): role: str = "system" @@ -121,6 +144,11 @@ class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate): text = self.prompt.format(**kwargs) return SystemMessage(content=text, additional_kwargs=self.additional_kwargs) + @property + def _type(self) -> str: + """The type of MessagePromptTemplate.""" + return "system-message-prompt-template" + class ChatPromptValue(PromptValue): messages: List[BaseMessage] @@ -149,15 +177,7 @@ class BaseChatPromptTemplate(BasePromptTemplate, ABC): class ChatPromptTemplate(BaseChatPromptTemplate, ABC): input_variables: List[str] - messages: List[ - Union[ - BaseMessagePromptTemplate, - BaseMessage, - AIMessagePromptTemplate, - SystemMessagePromptTemplate, - HumanMessagePromptTemplate, - ] - ] + messages: List[Union[BaseMessagePromptTemplate, BaseMessage]] @classmethod def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate: @@ -224,19 +244,17 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC): @property def _prompt_type(self) -> str: """Return the prompt type key.""" - return "chatPrompt" + return "chat_prompt" def dict(self, **kwargs: Any) -> Dict: - prompt_dict = json.loads(self.json()) - - prompt_dict["_type"] = self._prompt_type - for i, message in enumerate(self.messages): - if isinstance(message, SystemMessagePromptTemplate): - prompt_dict["messages"][i]["prompt"]["role"] = "system" - elif isinstance(message, HumanMessagePromptTemplate): - prompt_dict["messages"][i]["prompt"]["role"] = "human" - elif isinstance(message, AIMessagePromptTemplate): - prompt_dict["messages"][i]["prompt"]["role"] = "ai" + prompt_dict = super().dict(**kwargs) + # for i, message in enumerate(self.messages): + # if isinstance(message, SystemMessagePromptTemplate): + # prompt_dict["messages"][i]["prompt"]["role"] = "system" + # elif isinstance(message, HumanMessagePromptTemplate): + # prompt_dict["messages"][i]["prompt"]["role"] = "human" + # elif isinstance(message, AIMessagePromptTemplate): + # prompt_dict["messages"][i]["prompt"]["role"] = "ai" return prompt_dict diff --git a/langchain/prompts/loading.py b/langchain/prompts/loading.py index 0c0f20f808f..0260a526a97 100644 --- a/langchain/prompts/loading.py +++ b/langchain/prompts/loading.py @@ -17,6 +17,7 @@ from langchain.prompts.chat import ( ) from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate +from langchain.schema import message_from_dict from langchain.utilities.loading import try_load_from_hub URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/" @@ -129,14 +130,21 @@ def _load_chat_prompt(config: dict) -> ChatPromptTemplate: messages = [] for message in config["messages"]: - role = message["prompt"].get("role") - template = message["prompt"]["template"] - if role == "human": - messages.append(HumanMessagePromptTemplate.from_template(template)) - elif role == "ai": - messages.append(AIMessagePromptTemplate.from_template(template)) + _type = message.pop("_type") + if _type == "human-message-prompt-template": + prompt = load_prompt_from_config(message.pop("prompt")) + _message = HumanMessagePromptTemplate(**{"prompt": prompt, **message}) + elif _type == "ai-message-prompt-template": + prompt = load_prompt_from_config(message.pop("prompt")) + _message = AIMessagePromptTemplate(**{"prompt": prompt, **message}) + elif _type == "system-message-prompt-template": + prompt = load_prompt_from_config(message.pop("prompt")) + _message = SystemMessagePromptTemplate(**{"prompt": prompt, **message}) + elif _type == "base-message": + _message = message_from_dict(message) else: # role == system - messages.append(SystemMessagePromptTemplate.from_template(template)) + raise ValueError + messages.append(_message) return ChatPromptTemplate.from_messages(messages) @@ -185,7 +193,7 @@ def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate: type_to_loader_dict = { "prompt": _load_prompt, - "chatPrompt": _load_chat_prompt, + "chat_prompt": _load_chat_prompt, "few_shot": _load_few_shot_prompt, # "few_shot_with_templates": _load_few_shot_with_templates_prompt, } diff --git a/langchain/schema.py b/langchain/schema.py index 21552b9b38d..f55ec61a238 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -75,6 +75,13 @@ class BaseMessage(BaseModel): def type(self) -> str: """Type of the message, used for serialization.""" + def dict(self, *args, **kwargs): + return { + "type": self.type, + "data": super().dict(*args, **kwargs), + "_type": "base-message", + } + class HumanMessage(BaseMessage): """Type of message that is spoken by the human.""" @@ -118,15 +125,11 @@ class ChatMessage(BaseMessage): return "chat" -def _message_to_dict(message: BaseMessage) -> dict: - return {"type": message.type, "data": message.dict()} - - def messages_to_dict(messages: List[BaseMessage]) -> List[dict]: - return [_message_to_dict(m) for m in messages] + return [m.dict() for m in messages] -def _message_from_dict(message: dict) -> BaseMessage: +def message_from_dict(message: dict) -> BaseMessage: _type = message["type"] if _type == "human": return HumanMessage(**message["data"]) @@ -141,7 +144,7 @@ def _message_from_dict(message: dict) -> BaseMessage: def messages_from_dict(messages: List[dict]) -> List[BaseMessage]: - return [_message_from_dict(m) for m in messages] + return [message_from_dict(m) for m in messages] class ChatGeneration(Generation): diff --git a/tests/unit_tests/prompts/test_loading.py b/tests/unit_tests/prompts/test_loading.py index 717eecb8311..e062cc436fe 100644 --- a/tests/unit_tests/prompts/test_loading.py +++ b/tests/unit_tests/prompts/test_loading.py @@ -5,9 +5,16 @@ from contextlib import contextmanager from pathlib import Path from typing import Iterator +from langchain.prompts.chat import ( + AIMessagePromptTemplate, + ChatPromptTemplate, + HumanMessagePromptTemplate, + SystemMessagePromptTemplate, +) from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.loading import load_prompt from langchain.prompts.prompt import PromptTemplate +from langchain.schema import AIMessage, HumanMessage, SystemMessage @contextmanager @@ -71,6 +78,23 @@ def test_saving_loading_round_trip(tmp_path: Path) -> None: assert loaded_prompt == few_shot_prompt +def test_saving_chat_loading_round_trip(tmp_path: Path) -> None: + """Test equality when saving and loading a chat prompt.""" + message_list = [ + [HumanMessage(content="hi"), HumanMessagePromptTemplate.from_template("{foo}")], + [AIMessage(content="hi"), AIMessagePromptTemplate.from_template("{foo}")], + [ + SystemMessage(content="hi"), + SystemMessagePromptTemplate.from_template("{foo}"), + ], + ] + for messages in message_list: + simple_prompt = ChatPromptTemplate.from_messages(messages) + simple_prompt.save(file_path=tmp_path / "prompt.yaml") + loaded_prompt = load_prompt(tmp_path / "prompt.yaml") + assert loaded_prompt == simple_prompt + + def test_loading_with_template_as_file() -> None: """Test loading when the template is a file.""" with change_directory():