mirror of
https://github.com/hwchase17/langchain.git
synced 2026-01-23 05:09:12 +00:00
Compare commits
11 Commits
langchain-
...
harrison/s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c353ba0e53 | ||
|
|
6629e87c10 | ||
|
|
bd04d408d5 | ||
|
|
829847decc | ||
|
|
e8ac6c5134 | ||
|
|
0bde7b73d1 | ||
|
|
9de19bbe9d | ||
|
|
5811e0c6eb | ||
|
|
dbd8588148 | ||
|
|
97cb07b910 | ||
|
|
8aa7e750be |
@@ -7,7 +7,6 @@ from langchain.schema import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
_message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
@@ -160,7 +159,7 @@ class CassandraChatMessageHistory(BaseChatMessageHistory):
|
||||
self.session.execute(
|
||||
"""INSERT INTO message_store
|
||||
(id, session_id, history) VALUES (%s, %s, %s);""",
|
||||
(uuid.uuid4(), self.session_id, json.dumps(_message_to_dict(message))),
|
||||
(uuid.uuid4(), self.session_id, json.dumps(message.dict())),
|
||||
)
|
||||
except (Unavailable, WriteTimeout, WriteFailure) as error:
|
||||
logger.error("Unable to write chat history messages to cassandra")
|
||||
|
||||
@@ -6,7 +6,6 @@ from langchain.schema import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
_message_to_dict,
|
||||
messages_from_dict,
|
||||
messages_to_dict,
|
||||
)
|
||||
@@ -64,7 +63,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
messages = messages_to_dict(self.messages)
|
||||
_message = _message_to_dict(message)
|
||||
_message = message.dict()
|
||||
messages.append(_message)
|
||||
|
||||
try:
|
||||
|
||||
@@ -7,7 +7,6 @@ from langchain.schema import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
_message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
@@ -82,7 +81,7 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
self.collection.insert_one(
|
||||
{
|
||||
"SessionId": self.session_id,
|
||||
"History": json.dumps(_message_to_dict(message)),
|
||||
"History": json.dumps(message.dict()),
|
||||
}
|
||||
)
|
||||
except errors.WriteError as err:
|
||||
|
||||
@@ -7,7 +7,6 @@ from langchain.schema import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
_message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
@@ -68,9 +67,7 @@ class PostgresChatMessageHistory(BaseChatMessageHistory):
|
||||
query = sql.SQL("INSERT INTO {} (session_id, message) VALUES (%s, %s);").format(
|
||||
sql.Identifier(self.table_name)
|
||||
)
|
||||
self.cursor.execute(
|
||||
query, (self.session_id, json.dumps(_message_to_dict(message)))
|
||||
)
|
||||
self.cursor.execute(query, (self.session_id, json.dumps(message.dict())))
|
||||
self.connection.commit()
|
||||
|
||||
def clear(self) -> None:
|
||||
|
||||
@@ -7,7 +7,6 @@ from langchain.schema import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
_message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
@@ -60,7 +59,7 @@ class RedisChatMessageHistory(BaseChatMessageHistory):
|
||||
|
||||
def append(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in Redis"""
|
||||
self.redis_client.lpush(self.key, json.dumps(_message_to_dict(message)))
|
||||
self.redis_client.lpush(self.key, json.dumps(message.dict()))
|
||||
if self.ttl:
|
||||
self.redis_client.expire(self.key, self.ttl)
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from langchain.schema import (
|
||||
BaseChatMessageHistory,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
_message_to_dict,
|
||||
messages_from_dict,
|
||||
)
|
||||
|
||||
@@ -70,7 +69,7 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
|
||||
def append(self, message: BaseMessage) -> None:
|
||||
"""Append the message to the record in db"""
|
||||
with self.Session() as session:
|
||||
jsonstr = json.dumps(_message_to_dict(message))
|
||||
jsonstr = json.dumps(message.dict())
|
||||
session.add(self.Message(session_id=self.session_id, message=jsonstr))
|
||||
session.commit()
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
"""Chat prompt template."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
|
||||
from typing import Any, Callable, Dict, List, Sequence, Tuple, Type, TypeVar, Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.memory.buffer import get_buffer_string
|
||||
@@ -95,6 +97,16 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
|
||||
def input_variables(self) -> List[str]:
|
||||
return self.prompt.input_variables
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def _type(self) -> str:
|
||||
"""The type of MessagePromptTemplate."""
|
||||
|
||||
def dict(self, *args: Any, **kwargs: Any) -> dict:
|
||||
result = super().dict(*args, **kwargs)
|
||||
result["_type"] = self._type
|
||||
return result
|
||||
|
||||
|
||||
class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
role: str
|
||||
@@ -105,24 +117,48 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
content=text, role=self.role, additional_kwargs=self.additional_kwargs
|
||||
)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""The type of MessagePromptTemplate."""
|
||||
return "chat-message-prompt-template"
|
||||
|
||||
|
||||
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
text = self.prompt.format(**kwargs)
|
||||
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""The type of MessagePromptTemplate."""
|
||||
return "human-message-prompt-template"
|
||||
|
||||
|
||||
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
role: str = "ai"
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
text = self.prompt.format(**kwargs)
|
||||
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""The type of MessagePromptTemplate."""
|
||||
return "ai-message-prompt-template"
|
||||
|
||||
|
||||
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
|
||||
role: str = "system"
|
||||
|
||||
def format(self, **kwargs: Any) -> BaseMessage:
|
||||
text = self.prompt.format(**kwargs)
|
||||
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
"""The type of MessagePromptTemplate."""
|
||||
return "system-message-prompt-template"
|
||||
|
||||
|
||||
class ChatPromptValue(PromptValue):
|
||||
messages: List[BaseMessage]
|
||||
@@ -217,7 +253,26 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
|
||||
@property
|
||||
def _prompt_type(self) -> str:
|
||||
raise NotImplementedError
|
||||
"""Return the prompt type key."""
|
||||
return "chat_prompt"
|
||||
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
raise NotImplementedError
|
||||
if isinstance(file_path, str):
|
||||
save_path = Path(file_path)
|
||||
else:
|
||||
save_path = file_path
|
||||
|
||||
directory_path = save_path.parent
|
||||
directory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Fetch dictionary to save
|
||||
prompt_dict = self.dict()
|
||||
|
||||
if save_path.suffix == ".json":
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(prompt_dict, f, indent=4)
|
||||
elif save_path.suffix == ".yaml":
|
||||
with open(file_path, "w") as f:
|
||||
yaml.dump(prompt_dict, f, default_flow_style=False)
|
||||
else:
|
||||
raise ValueError(f"{save_path} must be json or yaml")
|
||||
|
||||
@@ -9,8 +9,16 @@ import yaml
|
||||
|
||||
from langchain.output_parsers.regex import RegexParser
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.chat import (
|
||||
AIMessagePromptTemplate,
|
||||
BaseMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import BaseMessage, message_from_dict
|
||||
from langchain.utilities.loading import try_load_from_hub
|
||||
|
||||
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
|
||||
@@ -114,6 +122,35 @@ def _load_prompt(config: dict) -> PromptTemplate:
|
||||
return PromptTemplate(**config)
|
||||
|
||||
|
||||
def _load_chat_prompt(config: dict) -> ChatPromptTemplate:
|
||||
"""Load the prompt template from config."""
|
||||
# Load the template from disk if necessary.
|
||||
config = _load_template("template", config)
|
||||
config = _load_output_parser(config)
|
||||
|
||||
messages = []
|
||||
for message in config["messages"]:
|
||||
_type = message.pop("_type")
|
||||
if _type == "human-message-prompt-template":
|
||||
prompt = load_prompt_from_config(message.pop("prompt"))
|
||||
_message: Union[
|
||||
BaseMessagePromptTemplate, BaseMessage
|
||||
] = HumanMessagePromptTemplate(**{"prompt": prompt, **message})
|
||||
elif _type == "ai-message-prompt-template":
|
||||
prompt = load_prompt_from_config(message.pop("prompt"))
|
||||
_message = AIMessagePromptTemplate(**{"prompt": prompt, **message})
|
||||
elif _type == "system-message-prompt-template":
|
||||
prompt = load_prompt_from_config(message.pop("prompt"))
|
||||
_message = SystemMessagePromptTemplate(**{"prompt": prompt, **message})
|
||||
elif _type == "base-message":
|
||||
_message = message_from_dict(message)
|
||||
else: # role == system
|
||||
raise ValueError
|
||||
messages.append(_message)
|
||||
|
||||
return ChatPromptTemplate.from_messages(messages)
|
||||
|
||||
|
||||
def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
|
||||
"""Unified method for loading a prompt from LangChainHub or local fs."""
|
||||
if hub_result := try_load_from_hub(
|
||||
@@ -158,6 +195,7 @@ def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
|
||||
|
||||
type_to_loader_dict = {
|
||||
"prompt": _load_prompt,
|
||||
"chat_prompt": _load_chat_prompt,
|
||||
"few_shot": _load_few_shot_prompt,
|
||||
# "few_shot_with_templates": _load_few_shot_with_templates_prompt,
|
||||
}
|
||||
|
||||
@@ -75,6 +75,13 @@ class BaseMessage(BaseModel):
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
|
||||
def dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": self.type,
|
||||
"data": super().dict(*args, **kwargs),
|
||||
"_type": "base-message",
|
||||
}
|
||||
|
||||
|
||||
class HumanMessage(BaseMessage):
|
||||
"""Type of message that is spoken by the human."""
|
||||
@@ -118,15 +125,11 @@ class ChatMessage(BaseMessage):
|
||||
return "chat"
|
||||
|
||||
|
||||
def _message_to_dict(message: BaseMessage) -> dict:
|
||||
return {"type": message.type, "data": message.dict()}
|
||||
|
||||
|
||||
def messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
|
||||
return [_message_to_dict(m) for m in messages]
|
||||
return [m.dict() for m in messages]
|
||||
|
||||
|
||||
def _message_from_dict(message: dict) -> BaseMessage:
|
||||
def message_from_dict(message: dict) -> BaseMessage:
|
||||
_type = message["type"]
|
||||
if _type == "human":
|
||||
return HumanMessage(**message["data"])
|
||||
@@ -141,7 +144,7 @@ def _message_from_dict(message: dict) -> BaseMessage:
|
||||
|
||||
|
||||
def messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
|
||||
return [_message_from_dict(m) for m in messages]
|
||||
return [message_from_dict(m) for m in messages]
|
||||
|
||||
|
||||
class ChatGeneration(Generation):
|
||||
|
||||
@@ -5,7 +5,6 @@ from langchain.memory import ConversationBufferMemory
|
||||
from langchain.memory.chat_message_histories.cassandra import (
|
||||
CassandraChatMessageHistory,
|
||||
)
|
||||
from langchain.schema import _message_to_dict
|
||||
|
||||
# Replace these with your cassandra contact points
|
||||
contact_points = (
|
||||
@@ -31,7 +30,7 @@ def test_memory_with_message_store() -> None:
|
||||
|
||||
# get the message history from the memory store and turn it into a json
|
||||
messages = memory.chat_memory.messages
|
||||
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
||||
messages_json = json.dumps([msg.dict() for msg in messages])
|
||||
|
||||
assert "This is me, the AI" in messages_json
|
||||
assert "This is me, the human" in messages_json
|
||||
|
||||
@@ -3,7 +3,6 @@ import os
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.memory.chat_message_histories import CosmosDBChatMessageHistory
|
||||
from langchain.schema import _message_to_dict
|
||||
|
||||
# Replace these with your Azure Cosmos DB endpoint and key
|
||||
endpoint = os.environ["COSMOS_DB_ENDPOINT"]
|
||||
@@ -33,7 +32,7 @@ def test_memory_with_message_store() -> None:
|
||||
|
||||
# get the message history from the memory store and turn it into a json
|
||||
messages = memory.chat_memory.messages
|
||||
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
||||
messages_json = json.dumps([msg.dict() for msg in messages])
|
||||
|
||||
assert "This is me, the AI" in messages_json
|
||||
assert "This is me, the human" in messages_json
|
||||
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.memory.chat_message_histories import FirestoreChatMessageHistory
|
||||
from langchain.schema import _message_to_dict
|
||||
|
||||
|
||||
def test_memory_with_message_store() -> None:
|
||||
@@ -32,7 +31,7 @@ def test_memory_with_message_store() -> None:
|
||||
memory_key="baz", chat_memory=message_history, return_messages=True
|
||||
)
|
||||
messages = memory.chat_memory.messages
|
||||
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
||||
messages_json = json.dumps([msg.dict() for msg in messages])
|
||||
|
||||
assert "This is me, the AI" in messages_json
|
||||
assert "This is me, the human" in messages_json
|
||||
|
||||
@@ -3,7 +3,6 @@ import os
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.memory.chat_message_histories import MongoDBChatMessageHistory
|
||||
from langchain.schema import _message_to_dict
|
||||
|
||||
# Replace these with your mongodb connection string
|
||||
connection_string = os.environ["MONGODB_CONNECTION_STRING"]
|
||||
@@ -25,7 +24,7 @@ def test_memory_with_message_store() -> None:
|
||||
|
||||
# get the message history from the memory store and turn it into a json
|
||||
messages = memory.chat_memory.messages
|
||||
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
||||
messages_json = json.dumps([msg.dict() for msg in messages])
|
||||
|
||||
assert "This is me, the AI" in messages_json
|
||||
assert "This is me, the human" in messages_json
|
||||
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.memory.chat_message_histories import RedisChatMessageHistory
|
||||
from langchain.schema import _message_to_dict
|
||||
|
||||
|
||||
def test_memory_with_message_store() -> None:
|
||||
@@ -21,7 +20,7 @@ def test_memory_with_message_store() -> None:
|
||||
|
||||
# get the message history from the memory store and turn it into a json
|
||||
messages = memory.chat_memory.messages
|
||||
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])
|
||||
messages_json = json.dumps([msg.dict() for msg in messages])
|
||||
|
||||
assert "This is me, the AI" in messages_json
|
||||
assert "This is me, the human" in messages_json
|
||||
|
||||
@@ -3,12 +3,19 @@
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
from typing import Iterator, List, Sequence
|
||||
|
||||
from langchain.output_parsers import RegexParser
|
||||
from langchain.prompts.chat import (
|
||||
AIMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
)
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.loading import load_prompt
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -72,6 +79,23 @@ def test_saving_loading_round_trip(tmp_path: Path) -> None:
|
||||
assert loaded_prompt == few_shot_prompt
|
||||
|
||||
|
||||
def test_saving_chat_loading_round_trip(tmp_path: Path) -> None:
|
||||
"""Test equality when saving and loading a chat prompt."""
|
||||
message_list: List[Sequence] = [
|
||||
[HumanMessage(content="hi"), HumanMessagePromptTemplate.from_template("{foo}")],
|
||||
[AIMessage(content="hi"), AIMessagePromptTemplate.from_template("{foo}")],
|
||||
[
|
||||
SystemMessage(content="hi"),
|
||||
SystemMessagePromptTemplate.from_template("{foo}"),
|
||||
],
|
||||
]
|
||||
for messages in message_list:
|
||||
simple_prompt = ChatPromptTemplate.from_messages(messages)
|
||||
simple_prompt.save(file_path=tmp_path / "prompt.yaml")
|
||||
loaded_prompt = load_prompt(tmp_path / "prompt.yaml")
|
||||
assert loaded_prompt == simple_prompt
|
||||
|
||||
|
||||
def test_loading_with_template_as_file() -> None:
|
||||
"""Test loading when the template is a file."""
|
||||
with change_directory():
|
||||
|
||||
Reference in New Issue
Block a user