mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 07:35:18 +00:00
Implemented appending arbitrary messages (#5293)
# Implemented appending arbitrary messages to the base chat message history, the in-memory and cosmos ones. <!-- Thank you for contributing to LangChain! Your PR will appear in our next release under the title you set. Please make sure it highlights your valuable contribution. Replace this with a description of the change, the issue it fixes (if applicable), and relevant context. List any dependencies required for this change. After you're done, someone will review your PR. They may suggest improvements. If no one reviews your PR within a few days, feel free to @-mention the same people again, as notifications can get lost. --> As discussed this is the alternative way instead of #4480, with a add_message method added that takes a BaseMessage as input, so that the user can control what is in the base message like kwargs. <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting <!-- If you're adding a new integration, include an integration test and an example notebook showing its use! --> ## Who can review? Community members can review the PR once tests pass. Tag maintainers/contributors who might be interested: @hwchase17 --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
d6fb25c439
commit
ccb6238de1
@ -3,10 +3,8 @@ import logging
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AIMessage,
|
|
||||||
BaseChatMessageHistory,
|
BaseChatMessageHistory,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
|
||||||
_message_to_dict,
|
_message_to_dict,
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
)
|
)
|
||||||
@ -143,13 +141,7 @@ class CassandraChatMessageHistory(BaseChatMessageHistory):
|
|||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def add_user_message(self, message: str) -> None:
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
self.append(HumanMessage(content=message))
|
|
||||||
|
|
||||||
def add_ai_message(self, message: str) -> None:
|
|
||||||
self.append(AIMessage(content=message))
|
|
||||||
|
|
||||||
def append(self, message: BaseMessage) -> None:
|
|
||||||
"""Append the message to the record in Cassandra"""
|
"""Append the message to the record in Cassandra"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
@ -6,10 +6,8 @@ from types import TracebackType
|
|||||||
from typing import TYPE_CHECKING, Any, List, Optional, Type
|
from typing import TYPE_CHECKING, Any, List, Optional, Type
|
||||||
|
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AIMessage,
|
|
||||||
BaseChatMessageHistory,
|
BaseChatMessageHistory,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
messages_to_dict,
|
messages_to_dict,
|
||||||
)
|
)
|
||||||
@ -145,18 +143,13 @@ class CosmosDBChatMessageHistory(BaseChatMessageHistory):
|
|||||||
if "messages" in item and len(item["messages"]) > 0:
|
if "messages" in item and len(item["messages"]) > 0:
|
||||||
self.messages = messages_from_dict(item["messages"])
|
self.messages = messages_from_dict(item["messages"])
|
||||||
|
|
||||||
def add_user_message(self, message: str) -> None:
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
"""Add a user message to the memory."""
|
"""Add a self-created message to the store"""
|
||||||
self.upsert_messages(HumanMessage(content=message))
|
self.messages.append(message)
|
||||||
|
self.upsert_messages()
|
||||||
|
|
||||||
def add_ai_message(self, message: str) -> None:
|
def upsert_messages(self) -> None:
|
||||||
"""Add a AI message to the memory."""
|
|
||||||
self.upsert_messages(AIMessage(content=message))
|
|
||||||
|
|
||||||
def upsert_messages(self, new_message: Optional[BaseMessage] = None) -> None:
|
|
||||||
"""Update the cosmosdb item."""
|
"""Update the cosmosdb item."""
|
||||||
if new_message:
|
|
||||||
self.messages.append(new_message)
|
|
||||||
if not self._container:
|
if not self._container:
|
||||||
raise ValueError("Container not initialized")
|
raise ValueError("Container not initialized")
|
||||||
self._container.upsert_item(
|
self._container.upsert_item(
|
||||||
|
@ -2,10 +2,8 @@ import logging
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AIMessage,
|
|
||||||
BaseChatMessageHistory,
|
BaseChatMessageHistory,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
|
||||||
_message_to_dict,
|
_message_to_dict,
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
messages_to_dict,
|
messages_to_dict,
|
||||||
@ -53,13 +51,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
|||||||
messages = messages_from_dict(items)
|
messages = messages_from_dict(items)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def add_user_message(self, message: str) -> None:
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
self.append(HumanMessage(content=message))
|
|
||||||
|
|
||||||
def add_ai_message(self, message: str) -> None:
|
|
||||||
self.append(AIMessage(content=message))
|
|
||||||
|
|
||||||
def append(self, message: BaseMessage) -> None:
|
|
||||||
"""Append the message to the record in DynamoDB"""
|
"""Append the message to the record in DynamoDB"""
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
|
|
||||||
|
@ -4,10 +4,8 @@ from pathlib import Path
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AIMessage,
|
|
||||||
BaseChatMessageHistory,
|
BaseChatMessageHistory,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
messages_to_dict,
|
messages_to_dict,
|
||||||
)
|
)
|
||||||
@ -36,13 +34,7 @@ class FileChatMessageHistory(BaseChatMessageHistory):
|
|||||||
messages = messages_from_dict(items)
|
messages = messages_from_dict(items)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def add_user_message(self, message: str) -> None:
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
self.append(HumanMessage(content=message))
|
|
||||||
|
|
||||||
def add_ai_message(self, message: str) -> None:
|
|
||||||
self.append(AIMessage(content=message))
|
|
||||||
|
|
||||||
def append(self, message: BaseMessage) -> None:
|
|
||||||
"""Append the message to the record in the local file"""
|
"""Append the message to the record in the local file"""
|
||||||
messages = messages_to_dict(self.messages)
|
messages = messages_to_dict(self.messages)
|
||||||
messages.append(messages_to_dict([message])[0])
|
messages.append(messages_to_dict([message])[0])
|
||||||
|
@ -5,10 +5,8 @@ import logging
|
|||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AIMessage,
|
|
||||||
BaseChatMessageHistory,
|
BaseChatMessageHistory,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
messages_to_dict,
|
messages_to_dict,
|
||||||
)
|
)
|
||||||
@ -81,18 +79,12 @@ class FirestoreChatMessageHistory(BaseChatMessageHistory):
|
|||||||
if "messages" in data and len(data["messages"]) > 0:
|
if "messages" in data and len(data["messages"]) > 0:
|
||||||
self.messages = messages_from_dict(data["messages"])
|
self.messages = messages_from_dict(data["messages"])
|
||||||
|
|
||||||
def add_user_message(self, message: str) -> None:
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
"""Add a user message to the memory."""
|
self.messages.append(message)
|
||||||
self.upsert_messages(HumanMessage(content=message))
|
self.upsert_messages()
|
||||||
|
|
||||||
def add_ai_message(self, message: str) -> None:
|
|
||||||
"""Add a AI message to the memory."""
|
|
||||||
self.upsert_messages(AIMessage(content=message))
|
|
||||||
|
|
||||||
def upsert_messages(self, new_message: Optional[BaseMessage] = None) -> None:
|
def upsert_messages(self, new_message: Optional[BaseMessage] = None) -> None:
|
||||||
"""Update the Firestore document."""
|
"""Update the Firestore document."""
|
||||||
if new_message:
|
|
||||||
self.messages.append(new_message)
|
|
||||||
if not self._document:
|
if not self._document:
|
||||||
raise ValueError("Document not initialized")
|
raise ValueError("Document not initialized")
|
||||||
self._document.set(
|
self._document.set(
|
||||||
|
@ -3,21 +3,17 @@ from typing import List
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AIMessage,
|
|
||||||
BaseChatMessageHistory,
|
BaseChatMessageHistory,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
|
class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
|
||||||
messages: List[BaseMessage] = []
|
messages: List[BaseMessage] = []
|
||||||
|
|
||||||
def add_user_message(self, message: str) -> None:
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
self.messages.append(HumanMessage(content=message))
|
"""Add a self-created message to the store"""
|
||||||
|
self.messages.append(message)
|
||||||
def add_ai_message(self, message: str) -> None:
|
|
||||||
self.messages.append(AIMessage(content=message))
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
self.messages = []
|
self.messages = []
|
||||||
|
@ -5,10 +5,8 @@ from datetime import timedelta
|
|||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AIMessage,
|
|
||||||
BaseChatMessageHistory,
|
BaseChatMessageHistory,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
|
||||||
_message_to_dict,
|
_message_to_dict,
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
)
|
)
|
||||||
@ -143,23 +141,7 @@ class MomentoChatMessageHistory(BaseChatMessageHistory):
|
|||||||
else:
|
else:
|
||||||
raise Exception(f"Unexpected response: {fetch_response}")
|
raise Exception(f"Unexpected response: {fetch_response}")
|
||||||
|
|
||||||
def add_user_message(self, message: str) -> None:
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
"""Store a user message in the cache.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message (str): The message to store.
|
|
||||||
"""
|
|
||||||
self.__add_message(HumanMessage(content=message))
|
|
||||||
|
|
||||||
def add_ai_message(self, message: str) -> None:
|
|
||||||
"""Store an AI message in the cache.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message (str): The message to store.
|
|
||||||
"""
|
|
||||||
self.__add_message(AIMessage(content=message))
|
|
||||||
|
|
||||||
def __add_message(self, message: BaseMessage) -> None:
|
|
||||||
"""Store a message in the cache.
|
"""Store a message in the cache.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -3,10 +3,8 @@ import logging
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AIMessage,
|
|
||||||
BaseChatMessageHistory,
|
BaseChatMessageHistory,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
|
||||||
_message_to_dict,
|
_message_to_dict,
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
)
|
)
|
||||||
@ -68,13 +66,7 @@ class MongoDBChatMessageHistory(BaseChatMessageHistory):
|
|||||||
messages = messages_from_dict(items)
|
messages = messages_from_dict(items)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def add_user_message(self, message: str) -> None:
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
self.append(HumanMessage(content=message))
|
|
||||||
|
|
||||||
def add_ai_message(self, message: str) -> None:
|
|
||||||
self.append(AIMessage(content=message))
|
|
||||||
|
|
||||||
def append(self, message: BaseMessage) -> None:
|
|
||||||
"""Append the message to the record in MongoDB"""
|
"""Append the message to the record in MongoDB"""
|
||||||
from pymongo import errors
|
from pymongo import errors
|
||||||
|
|
||||||
|
@ -3,10 +3,8 @@ import logging
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AIMessage,
|
|
||||||
BaseChatMessageHistory,
|
BaseChatMessageHistory,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
|
||||||
_message_to_dict,
|
_message_to_dict,
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
)
|
)
|
||||||
@ -55,13 +53,7 @@ class PostgresChatMessageHistory(BaseChatMessageHistory):
|
|||||||
messages = messages_from_dict(items)
|
messages = messages_from_dict(items)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def add_user_message(self, message: str) -> None:
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
self.append(HumanMessage(content=message))
|
|
||||||
|
|
||||||
def add_ai_message(self, message: str) -> None:
|
|
||||||
self.append(AIMessage(content=message))
|
|
||||||
|
|
||||||
def append(self, message: BaseMessage) -> None:
|
|
||||||
"""Append the message to the record in PostgreSQL"""
|
"""Append the message to the record in PostgreSQL"""
|
||||||
from psycopg import sql
|
from psycopg import sql
|
||||||
|
|
||||||
|
@ -3,10 +3,8 @@ import logging
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AIMessage,
|
|
||||||
BaseChatMessageHistory,
|
BaseChatMessageHistory,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
|
||||||
_message_to_dict,
|
_message_to_dict,
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
)
|
)
|
||||||
@ -52,13 +50,7 @@ class RedisChatMessageHistory(BaseChatMessageHistory):
|
|||||||
messages = messages_from_dict(items)
|
messages = messages_from_dict(items)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def add_user_message(self, message: str) -> None:
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
self.append(HumanMessage(content=message))
|
|
||||||
|
|
||||||
def add_ai_message(self, message: str) -> None:
|
|
||||||
self.append(AIMessage(content=message))
|
|
||||||
|
|
||||||
def append(self, message: BaseMessage) -> None:
|
|
||||||
"""Append the message to the record in Redis"""
|
"""Append the message to the record in Redis"""
|
||||||
self.redis_client.lpush(self.key, json.dumps(_message_to_dict(message)))
|
self.redis_client.lpush(self.key, json.dumps(_message_to_dict(message)))
|
||||||
if self.ttl:
|
if self.ttl:
|
||||||
|
@ -7,10 +7,8 @@ from sqlalchemy.ext.declarative import declarative_base
|
|||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AIMessage,
|
|
||||||
BaseChatMessageHistory,
|
BaseChatMessageHistory,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
HumanMessage,
|
|
||||||
_message_to_dict,
|
_message_to_dict,
|
||||||
messages_from_dict,
|
messages_from_dict,
|
||||||
)
|
)
|
||||||
@ -61,13 +59,7 @@ class SQLChatMessageHistory(BaseChatMessageHistory):
|
|||||||
messages = messages_from_dict(items)
|
messages = messages_from_dict(items)
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
def add_user_message(self, message: str) -> None:
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
self.append(HumanMessage(content=message))
|
|
||||||
|
|
||||||
def add_ai_message(self, message: str) -> None:
|
|
||||||
self.append(AIMessage(content=message))
|
|
||||||
|
|
||||||
def append(self, message: BaseMessage) -> None:
|
|
||||||
"""Append the message to the record in db"""
|
"""Append the message to the record in db"""
|
||||||
with self.Session() as session:
|
with self.Session() as session:
|
||||||
jsonstr = json.dumps(_message_to_dict(message))
|
jsonstr = json.dumps(_message_to_dict(message))
|
||||||
|
@ -116,13 +116,7 @@ class ZepChatMessageHistory(BaseChatMessageHistory):
|
|||||||
return None
|
return None
|
||||||
return zep_memory
|
return zep_memory
|
||||||
|
|
||||||
def add_user_message(self, message: str) -> None:
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
self.append(HumanMessage(content=message))
|
|
||||||
|
|
||||||
def add_ai_message(self, message: str) -> None:
|
|
||||||
self.append(AIMessage(content=message))
|
|
||||||
|
|
||||||
def append(self, message: BaseMessage) -> None:
|
|
||||||
"""Append the message to the Zep memory history"""
|
"""Append the message to the Zep memory history"""
|
||||||
from zep_python import Memory, Message
|
from zep_python import Memory, Message
|
||||||
|
|
||||||
|
@ -234,18 +234,11 @@ class BaseChatMessageHistory(ABC):
|
|||||||
messages = json.loads(f.read())
|
messages = json.loads(f.read())
|
||||||
return messages_from_dict(messages)
|
return messages_from_dict(messages)
|
||||||
|
|
||||||
def add_user_message(self, message: str):
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
message_ = HumanMessage(content=message)
|
messages = self.messages.append(_message_to_dict(message))
|
||||||
messages = self.messages.append(_message_to_dict(_message))
|
|
||||||
with open(os.path.join(storage_path, session_id), 'w') as f:
|
with open(os.path.join(storage_path, session_id), 'w') as f:
|
||||||
json.dump(f, messages)
|
json.dump(f, messages)
|
||||||
|
|
||||||
def add_ai_message(self, message: str):
|
|
||||||
message_ = AIMessage(content=message)
|
|
||||||
messages = self.messages.append(_message_to_dict(_message))
|
|
||||||
with open(os.path.join(storage_path, session_id), 'w') as f:
|
|
||||||
json.dump(f, messages)
|
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
with open(os.path.join(storage_path, session_id), 'w') as f:
|
with open(os.path.join(storage_path, session_id), 'w') as f:
|
||||||
f.write("[]")
|
f.write("[]")
|
||||||
@ -253,13 +246,17 @@ class BaseChatMessageHistory(ABC):
|
|||||||
|
|
||||||
messages: List[BaseMessage]
|
messages: List[BaseMessage]
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_user_message(self, message: str) -> None:
|
def add_user_message(self, message: str) -> None:
|
||||||
"""Add a user message to the store"""
|
"""Add a user message to the store"""
|
||||||
|
self.add_message(HumanMessage(content=message))
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def add_ai_message(self, message: str) -> None:
|
def add_ai_message(self, message: str) -> None:
|
||||||
"""Add an AI message to the store"""
|
"""Add an AI message to the store"""
|
||||||
|
self.add_message(AIMessage(content=message))
|
||||||
|
|
||||||
|
def add_message(self, message: BaseMessage) -> None:
|
||||||
|
"""Add a self-created message to the store"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
|
@ -60,7 +60,7 @@ def test_add_ai_message(mocker: MockerFixture, zep_chat: ZepChatMessageHistory)
|
|||||||
|
|
||||||
@pytest.mark.requires("zep_python")
|
@pytest.mark.requires("zep_python")
|
||||||
def test_append(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
|
def test_append(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
|
||||||
zep_chat.append(AIMessage(content="test message"))
|
zep_chat.add_message(AIMessage(content="test message"))
|
||||||
zep_chat.zep_client.add_memory.assert_called_once() # type: ignore
|
zep_chat.zep_client.add_memory.assert_called_once() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user