mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-19 03:44:40 +00:00
Compare commits
11 Commits
langchain-
...
harrison/s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c353ba0e53 | ||
|
|
6629e87c10 | ||
|
|
bd04d408d5 | ||
|
|
829847decc | ||
|
|
e8ac6c5134 | ||
|
|
0bde7b73d1 | ||
|
|
9de19bbe9d | ||
|
|
5811e0c6eb | ||
|
|
dbd8588148 | ||
|
|
97cb07b910 | ||
|
|
8aa7e750be |
@@ -7,7 +7,6 @@ from langchain.schema import (
|
|||||||
BaseChatMessageHistory,
|
BaseChatMessageHistory,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
_message_to_dict,
|
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -160,7 +159,7 @@ class CassandraChatMessageHistory(BaseChatMessageHistory):
|
|||||||
self.session.execute(
|
self.session.execute(
|
||||||
"""INSERT INTO message_store
|
"""INSERT INTO message_store
|
||||||
(id, session_id, history) VALUES (%s, %s, %s);""",
|
(id, session_id, history) VALUES (%s, %s, %s);""",
|
||||||
(uuid.uuid4(), self.session_id, json.dumps(_message_to_dict(message))),
|
(uuid.uuid4(), self.session_id, json.dumps(message.dict())),
|
||||||
)
|
)
|
||||||
except (Unavailable, WriteTimeout, WriteFailure) as error:
|
except (Unavailable, WriteTimeout, WriteFailure) as error:
|
||||||
logger.error("Unable to write chat history messages to cassandra")
|
logger.error("Unable to write chat history messages to cassandra")
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from langchain.schema import (
|
|||||||
BaseChatMessageHistory,
|
BaseChatMessageHistory,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
_message_to_dict,
|
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
messages_to_dict,
|
messages_to_dict,
|
||||||
)
|
)
|
||||||
@@ -64,7 +63,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
|||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
|
|
||||||
messages = messages_to_dict(self.messages)
|
messages = messages_to_dict(self.messages)
|
||||||
_message = _message_to_dict(message)
|
_message = message.dict()
|
||||||
messages.append(_message)
|
messages.append(_message)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from langchain.schema import (
|
|||||||
BaseChatMessageHistory,
|
BaseChatMessageHistory,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
_message_to_dict,
|
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -82,7 +81,7 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
|
|||||||
self.collection.insert_one(
|
self.collection.insert_one(
|
||||||
{
|
{
|
||||||
"SessionId": self.session_id,
|
"SessionId": self.session_id,
|
||||||
"History": json.dumps(_message_to_dict(message)),
|
"History": json.dumps(message.dict()),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
except errors.WriteError as err:
|
except errors.WriteError as err:
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from langchain.schema import (
|
|||||||
BaseChatMessageHistory,
|
BaseChatMessageHistory,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
_message_to_dict,
|
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -68,9 +67,7 @@ class PostgresChatMessageHistory(BaseChatMessageHistory):
|
|||||||
query = sql.SQL("INSERT INTO {} (session_id, message) VALUES (%s, %s);").format(
|
query = sql.SQL("INSERT INTO {} (session_id, message) VALUES (%s, %s);").format(
|
||||||
sql.Identifier(self.table_name)
|
sql.Identifier(self.table_name)
|
||||||
)
|
)
|
||||||
self.cursor.execute(
|
self.cursor.execute(query, (self.session_id, json.dumps(message.dict())))
|
||||||
query, (self.session_id, json.dumps(_message_to_dict(message)))
|
|
||||||
)
|
|
||||||
self.connection.commit()
|
self.connection.commit()
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ from langchain.schema import (
|
|||||||
BaseChatMessageHistory,
|
BaseChatMessageHistory,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
_message_to_dict,
|
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -60,7 +59,7 @@ class RedisChatMessageHistory(BaseChatMessageHistory):
|
|||||||
|
|
||||||
def append(self, message: BaseMessage) -> None:
|
def append(self, message: BaseMessage) -> None:
|
||||||
"""Append the message to the record in Redis"""
|
"""Append the message to the record in Redis"""
|
||||||
self.redis_client.lpush(self.key, json.dumps(_message_to_dict(message)))
|
self.redis_client.lpush(self.key, json.dumps(message.dict()))
|
||||||
if self.ttl:
|
if self.ttl:
|
||||||
self.redis_client.expire(self.key, self.ttl)
|
self.redis_client.expire(self.key, self.ttl)
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from langchain.schema import (
|
|||||||
BaseChatMessageHistory,
|
BaseChatMessageHistory,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
_message_to_dict,
|
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -70,7 +69,7 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
|
|||||||
def append(self, message: BaseMessage) -> None:
|
def append(self, message: BaseMessage) -> None:
|
||||||
"""Append the message to the record in db"""
|
"""Append the message to the record in db"""
|
||||||
with self.Session() as session:
|
with self.Session() as session:
|
||||||
jsonstr = json.dumps(_message_to_dict(message))
|
jsonstr = json.dumps(message.dict())
|
||||||
session.add(self.Message(session_id=self.session_id, message=jsonstr))
|
session.add(self.Message(session_id=self.session_id, message=jsonstr))
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
"""Chat prompt template."""
|
"""Chat prompt template."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
|
from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
|
import yaml
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from langchain.memory.buffer import get_buffer_string
|
from langchain.memory.buffer import get_buffer_string
|
||||||
@@ -95,6 +97,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
|||||||
def input_variables(self) -> List[str]:
|
def input_variables(self) -> List[str]:
|
||||||
return self.prompt.input_variables
|
return self.prompt.input_variables
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def _type(self) -> str:
|
||||||
|
"""The type of MessagePromptTemplate."""
|
||||||
|
|
||||||
|
def dict(self, *args: Any, **kwargs: Any) -> dict:
|
||||||
|
result = super().dict(*args, **kwargs)
|
||||||
|
result["_type"] = self._type
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||||
role: str
|
role: str
|
||||||
@@ -105,24 +117,48 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
|||||||
content=text, role=self.role, additional_kwargs=self.additional_kwargs
|
content=text, role=self.role, additional_kwargs=self.additional_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
"""The type of MessagePromptTemplate."""
|
||||||
|
return "chat-message-prompt-template"
|
||||||
|
|
||||||
|
|
||||||
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||||
def format(self, **kwargs: Any) -> BaseMessage:
|
def format(self, **kwargs: Any) -> BaseMessage:
|
||||||
text = self.prompt.format(**kwargs)
|
text = self.prompt.format(**kwargs)
|
||||||
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
|
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
"""The type of MessagePromptTemplate."""
|
||||||
|
return "human-message-prompt-template"
|
||||||
|
|
||||||
|
|
||||||
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||||
|
role: str = "ai"
|
||||||
|
|
||||||
def format(self, **kwargs: Any) -> BaseMessage:
|
def format(self, **kwargs: Any) -> BaseMessage:
|
||||||
text = self.prompt.format(**kwargs)
|
text = self.prompt.format(**kwargs)
|
||||||
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
|
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
"""The type of MessagePromptTemplate."""
|
||||||
|
return "ai-message-prompt-template"
|
||||||
|
|
||||||
|
|
||||||
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||||
|
role: str = "system"
|
||||||
|
|
||||||
def format(self, **kwargs: Any) -> BaseMessage:
|
def format(self, **kwargs: Any) -> BaseMessage:
|
||||||
text = self.prompt.format(**kwargs)
|
text = self.prompt.format(**kwargs)
|
||||||
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
|
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _type(self) -> str:
|
||||||
|
"""The type of MessagePromptTemplate."""
|
||||||
|
return "system-message-prompt-template"
|
||||||
|
|
||||||
|
|
||||||
class ChatPromptValue(PromptValue):
|
class ChatPromptValue(PromptValue):
|
||||||
messages: List[BaseMessage]
|
messages: List[BaseMessage]
|
||||||
@@ -217,7 +253,26 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _prompt_type(self) -> str:
|
def _prompt_type(self) -> str:
|
||||||
raise NotImplementedError
|
"""Return the prompt type key."""
|
||||||
|
return "chat_prompt"
|
||||||
|
|
||||||
def save(self, file_path: Union[Path, str]) -> None:
|
def save(self, file_path: Union[Path, str]) -> None:
|
||||||
raise NotImplementedError
|
if isinstance(file_path, str):
|
||||||
|
save_path = Path(file_path)
|
||||||
|
else:
|
||||||
|
save_path = file_path
|
||||||
|
|
||||||
|
directory_path = save_path.parent
|
||||||
|
directory_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Fetch dictionary to save
|
||||||
|
prompt_dict = self.dict()
|
||||||
|
|
||||||
|
if save_path.suffix == ".json":
|
||||||
|
with open(file_path, "w") as f:
|
||||||
|
json.dump(prompt_dict, f, indent=4)
|
||||||
|
elif save_path.suffix == ".yaml":
|
||||||
|
with open(file_path, "w") as f:
|
||||||
|
yaml.dump(prompt_dict, f, default_flow_style=False)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"{save_path} must be json or yaml")
|
||||||
|
|||||||
@@ -9,8 +9,16 @@ import yaml
|
|||||||
|
|
||||||
from langchain.output_parsers.regex import RegexParser
|
from langchain.output_parsers.regex import RegexParser
|
||||||
from langchain.prompts.base import BasePromptTemplate
|
from langchain.prompts.base import BasePromptTemplate
|
||||||
|
from langchain.prompts.chat import (
|
||||||
|
AIMessagePromptTemplate,
|
||||||
|
BaseMessagePromptTemplate,
|
||||||
|
ChatPromptTemplate,
|
||||||
|
HumanMessagePromptTemplate,
|
||||||
|
SystemMessagePromptTemplate,
|
||||||
|
)
|
||||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||||
from langchain.prompts.prompt import PromptTemplate
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
|
from langchain.schema import BaseMessage, message_from_dict
|
||||||
from langchain.utilities.loading import try_load_from_hub
|
from langchain.utilities.loading import try_load_from_hub
|
||||||
|
|
||||||
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
|
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
|
||||||
@@ -114,6 +122,35 @@ def _load_prompt(config: dict) -> PromptTemplate:
|
|||||||
return PromptTemplate(**config)
|
return PromptTemplate(**config)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_chat_prompt(config: dict) -> ChatPromptTemplate:
|
||||||
|
"""Load the prompt template from config."""
|
||||||
|
# Load the template from disk if necessary.
|
||||||
|
config = _load_template("template", config)
|
||||||
|
config = _load_output_parser(config)
|
||||||
|
|
||||||
|
messages = []
|
||||||
|
for message in config["messages"]:
|
||||||
|
_type = message.pop("_type")
|
||||||
|
if _type == "human-message-prompt-template":
|
||||||
|
prompt = load_prompt_from_config(message.pop("prompt"))
|
||||||
|
_message: Union[
|
||||||
|
BaseMessagePromptTemplate, BaseMessage
|
||||||
|
] = HumanMessagePromptTemplate(**{"prompt": prompt, **message})
|
||||||
|
elif _type == "ai-message-prompt-template":
|
||||||
|
prompt = load_prompt_from_config(message.pop("prompt"))
|
||||||
|
_message = AIMessagePromptTemplate(**{"prompt": prompt, **message})
|
||||||
|
elif _type == "system-message-prompt-template":
|
||||||
|
prompt = load_prompt_from_config(message.pop("prompt"))
|
||||||
|
_message = SystemMessagePromptTemplate(**{"prompt": prompt, **message})
|
||||||
|
elif _type == "base-message":
|
||||||
|
_message = message_from_dict(message)
|
||||||
|
else: # role == system
|
||||||
|
raise ValueError
|
||||||
|
messages.append(_message)
|
||||||
|
|
||||||
|
return ChatPromptTemplate.from_messages(messages)
|
||||||
|
|
||||||
|
|
||||||
def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
|
def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
|
||||||
"""Unified method for loading a prompt from LangChainHub or local fs."""
|
"""Unified method for loading a prompt from LangChainHub or local fs."""
|
||||||
if hub_result := try_load_from_hub(
|
if hub_result := try_load_from_hub(
|
||||||
@@ -158,6 +195,7 @@ def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
|
|||||||
|
|
||||||
type_to_loader_dict = {
|
type_to_loader_dict = {
|
||||||
"prompt": _load_prompt,
|
"prompt": _load_prompt,
|
||||||
|
"chat_prompt": _load_chat_prompt,
|
||||||
"few_shot": _load_few_shot_prompt,
|
"few_shot": _load_few_shot_prompt,
|
||||||
# "few_shot_with_templates": _load_few_shot_with_templates_prompt,
|
# "few_shot_with_templates": _load_few_shot_with_templates_prompt,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -75,6 +75,13 @@ class BaseMessage(BaseModel):
|
|||||||
def type(self) -> str:
|
def type(self) -> str:
|
||||||
"""Type of the message, used for serialization."""
|
"""Type of the message, used for serialization."""
|
||||||
|
|
||||||
|
def dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": self.type,
|
||||||
|
"data": super().dict(*args, **kwargs),
|
||||||
|
"_type": "base-message",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class HumanMessage(BaseMessage):
|
class HumanMessage(BaseMessage):
|
||||||
"""Type of message that is spoken by the human."""
|
"""Type of message that is spoken by the human."""
|
||||||
@@ -118,15 +125,11 @@ class ChatMessage(BaseMessage):
|
|||||||
return "chat"
|
return "chat"
|
||||||
|
|
||||||
|
|
||||||
def _message_to_dict(message: BaseMessage) -> dict:
|
|
||||||
return {"type": message.type, "data": message.dict()}
|
|
||||||
|
|
||||||
|
|
||||||
def messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
|
def messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
|
||||||
return [_message_to_dict(m) for m in messages]
|
return [m.dict() for m in messages]
|
||||||
|
|
||||||
|
|
||||||
def _message_from_dict(message: dict) -> BaseMessage:
|
def message_from_dict(message: dict) -> BaseMessage:
|
||||||
_type = message["type"]
|
_type = message["type"]
|
||||||
if _type == "human":
|
if _type == "human":
|
||||||
return HumanMessage(**message["data"])
|
return HumanMessage(**message["data"])
|
||||||
@@ -141,7 +144,7 @@ def _message_from_dict(message: dict) -> BaseMessage:
|
|||||||
|
|
||||||
|
|
||||||
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
|
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
|
||||||
return [_message_from_dict(m) for m in messages]
|
return [message_from_dict(m) for m in messages]
|
||||||
|
|
||||||
|
|
||||||
class ChatGeneration(Generation):
|
class ChatGeneration(Generation):
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from langchain.memory import ConversationBufferMemory
|
|||||||
from langchain.memory.chat_message_histories.cassandra import (
|
from langchain.memory.chat_message_histories.cassandra import (
|
||||||
CassandraChatMessageHistory,
|
CassandraChatMessageHistory,
|
||||||
)
|
)
|
||||||
from langchain.schema import _message_to_dict
|
|
||||||
|
|
||||||
# Replace these with your cassandra contact points
|
# Replace these with your cassandra contact points
|
||||||
contact_points = (
|
contact_points = (
|
||||||
@@ -31,7 +30,7 @@ def test_memory_with_message_store() -> None:
|
|||||||
|
|
||||||
# get the message history from the memory store and turn it into a json
|
# get the message history from the memory store and turn it into a json
|
||||||
messages = memory.chat_memory.messages
|
messages = memory.chat_memory.messages
|
||||||
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
messages_json = json.dumps([msg.dict() for msg in messages])
|
||||||
|
|
||||||
assert "This is me, the AI" in messages_json
|
assert "This is me, the AI" in messages_json
|
||||||
assert "This is me, the human" in messages_json
|
assert "This is me, the human" in messages_json
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import os
|
|||||||
|
|
||||||
from langchain.memory import ConversationBufferMemory
|
from langchain.memory import ConversationBufferMemory
|
||||||
from langchain.memory.chat_message_histories import CosmosDBChatMessageHistory
|
from langchain.memory.chat_message_histories import CosmosDBChatMessageHistory
|
||||||
from langchain.schema import _message_to_dict
|
|
||||||
|
|
||||||
# Replace these with your Azure Cosmos DB endpoint and key
|
# Replace these with your Azure Cosmos DB endpoint and key
|
||||||
endpoint = os.environ["COSMOS_DB_ENDPOINT"]
|
endpoint = os.environ["COSMOS_DB_ENDPOINT"]
|
||||||
@@ -33,7 +32,7 @@ def test_memory_with_message_store() -> None:
|
|||||||
|
|
||||||
# get the message history from the memory store and turn it into a json
|
# get the message history from the memory store and turn it into a json
|
||||||
messages = memory.chat_memory.messages
|
messages = memory.chat_memory.messages
|
||||||
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
messages_json = json.dumps([msg.dict() for msg in messages])
|
||||||
|
|
||||||
assert "This is me, the AI" in messages_json
|
assert "This is me, the AI" in messages_json
|
||||||
assert "This is me, the human" in messages_json
|
assert "This is me, the human" in messages_json
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import json
|
|||||||
|
|
||||||
from langchain.memory import ConversationBufferMemory
|
from langchain.memory import ConversationBufferMemory
|
||||||
from langchain.memory.chat_message_histories import FirestoreChatMessageHistory
|
from langchain.memory.chat_message_histories import FirestoreChatMessageHistory
|
||||||
from langchain.schema import _message_to_dict
|
|
||||||
|
|
||||||
|
|
||||||
def test_memory_with_message_store() -> None:
|
def test_memory_with_message_store() -> None:
|
||||||
@@ -32,7 +31,7 @@ def test_memory_with_message_store() -> None:
|
|||||||
memory_key="baz", chat_memory=message_history, return_messages=True
|
memory_key="baz", chat_memory=message_history, return_messages=True
|
||||||
)
|
)
|
||||||
messages = memory.chat_memory.messages
|
messages = memory.chat_memory.messages
|
||||||
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
messages_json = json.dumps([msg.dict() for msg in messages])
|
||||||
|
|
||||||
assert "This is me, the AI" in messages_json
|
assert "This is me, the AI" in messages_json
|
||||||
assert "This is me, the human" in messages_json
|
assert "This is me, the human" in messages_json
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import os
|
|||||||
|
|
||||||
from langchain.memory import ConversationBufferMemory
|
from langchain.memory import ConversationBufferMemory
|
||||||
from langchain.memory.chat_message_histories import MongoDBChatMessageHistory
|
from langchain.memory.chat_message_histories import MongoDBChatMessageHistory
|
||||||
from langchain.schema import _message_to_dict
|
|
||||||
|
|
||||||
# Replace these with your mongodb connection string
|
# Replace these with your mongodb connection string
|
||||||
connection_string = os.environ["MONGODB_CONNECTION_STRING"]
|
connection_string = os.environ["MONGODB_CONNECTION_STRING"]
|
||||||
@@ -25,7 +24,7 @@ def test_memory_with_message_store() -> None:
|
|||||||
|
|
||||||
# get the message history from the memory store and turn it into a json
|
# get the message history from the memory store and turn it into a json
|
||||||
messages = memory.chat_memory.messages
|
messages = memory.chat_memory.messages
|
||||||
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
messages_json = json.dumps([msg.dict() for msg in messages])
|
||||||
|
|
||||||
assert "This is me, the AI" in messages_json
|
assert "This is me, the AI" in messages_json
|
||||||
assert "This is me, the human" in messages_json
|
assert "This is me, the human" in messages_json
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import json
|
|||||||
|
|
||||||
from langchain.memory import ConversationBufferMemory
|
from langchain.memory import ConversationBufferMemory
|
||||||
from langchain.memory.chat_message_histories import RedisChatMessageHistory
|
from langchain.memory.chat_message_histories import RedisChatMessageHistory
|
||||||
from langchain.schema import _message_to_dict
|
|
||||||
|
|
||||||
|
|
||||||
def test_memory_with_message_store() -> None:
|
def test_memory_with_message_store() -> None:
|
||||||
@@ -21,7 +20,7 @@ def test_memory_with_message_store() -> None:
|
|||||||
|
|
||||||
# get the message history from the memory store and turn it into a json
|
# get the message history from the memory store and turn it into a json
|
||||||
messages = memory.chat_memory.messages
|
messages = memory.chat_memory.messages
|
||||||
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
messages_json = json.dumps([msg.dict() for msg in messages])
|
||||||
|
|
||||||
assert "This is me, the AI" in messages_json
|
assert "This is me, the AI" in messages_json
|
||||||
assert "This is me, the human" in messages_json
|
assert "This is me, the human" in messages_json
|
||||||
|
|||||||
@@ -3,12 +3,19 @@
|
|||||||
import os
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterator
|
from typing import Iterator, List, Sequence
|
||||||
|
|
||||||
from langchain.output_parsers import RegexParser
|
from langchain.output_parsers import RegexParser
|
||||||
|
from langchain.prompts.chat import (
|
||||||
|
AIMessagePromptTemplate,
|
||||||
|
ChatPromptTemplate,
|
||||||
|
HumanMessagePromptTemplate,
|
||||||
|
SystemMessagePromptTemplate,
|
||||||
|
)
|
||||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||||
from langchain.prompts.loading import load_prompt
|
from langchain.prompts.loading import load_prompt
|
||||||
from langchain.prompts.prompt import PromptTemplate
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
|
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
@@ -72,6 +79,23 @@ def test_saving_loading_round_trip(tmp_path: Path) -> None:
|
|||||||
assert loaded_prompt == few_shot_prompt
|
assert loaded_prompt == few_shot_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def test_saving_chat_loading_round_trip(tmp_path: Path) -> None:
|
||||||
|
"""Test equality when saving and loading a chat prompt."""
|
||||||
|
message_list: List[Sequence] = [
|
||||||
|
[HumanMessage(content="hi"), HumanMessagePromptTemplate.from_template("{foo}")],
|
||||||
|
[AIMessage(content="hi"), AIMessagePromptTemplate.from_template("{foo}")],
|
||||||
|
[
|
||||||
|
SystemMessage(content="hi"),
|
||||||
|
SystemMessagePromptTemplate.from_template("{foo}"),
|
||||||
|
],
|
||||||
|
]
|
||||||
|
for messages in message_list:
|
||||||
|
simple_prompt = ChatPromptTemplate.from_messages(messages)
|
||||||
|
simple_prompt.save(file_path=tmp_path / "prompt.yaml")
|
||||||
|
loaded_prompt = load_prompt(tmp_path / "prompt.yaml")
|
||||||
|
assert loaded_prompt == simple_prompt
|
||||||
|
|
||||||
|
|
||||||
def test_loading_with_template_as_file() -> None:
|
def test_loading_with_template_as_file() -> None:
|
||||||
"""Test loading when the template is a file."""
|
"""Test loading when the template is a file."""
|
||||||
with change_directory():
|
with change_directory():
|
||||||
|
|||||||
Reference in New Issue
Block a user