diff --git a/libs/community/langchain_community/chat_message_histories/dynamodb.py b/libs/community/langchain_community/chat_message_histories/dynamodb.py index ca1586a2b0a..62beb4895e7 100644 --- a/libs/community/langchain_community/chat_message_histories/dynamodb.py +++ b/libs/community/langchain_community/chat_message_histories/dynamodb.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from decimal import Decimal from typing import TYPE_CHECKING, Dict, List, Optional from langchain_core.chat_history import BaseChatMessageHistory @@ -17,6 +18,16 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def convert_messages(item: List) -> List: + if isinstance(item, list): + return [convert_messages(i) for i in item] + elif isinstance(item, dict): + return {k: convert_messages(v) for k, v in item.items()} + elif isinstance(item, float): + return Decimal(str(item)) + return item + + class DynamoDBChatMessageHistory(BaseChatMessageHistory): """Chat message history that stores history in AWS DynamoDB. @@ -47,6 +58,8 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): limit. If not None then only the latest `history_size` messages are stored. history_messages_key: Key for the chat history where the messages are stored and updated + coerce_float_to_decimal: If True, all float values in the messages will be + converted to Decimal. """ def __init__( @@ -62,6 +75,8 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): ttl_key_name: str = "expireAt", history_size: Optional[int] = None, history_messages_key: Optional[str] = "History", + *, + coerce_float_to_decimal: bool = False, ): if boto3_session: client = boto3_session.resource("dynamodb", endpoint_url=endpoint_url) @@ -83,6 +98,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): self.ttl_key_name = ttl_key_name self.history_size = history_size self.history_messages_key = history_messages_key + self.coerce_float_to_decimal = coerce_float_to_decimal if kms_key_id: try: @@ -159,6 +175,9 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory): _message = message_to_dict(message) messages.append(_message) + if self.coerce_float_to_decimal: + messages = convert_messages(messages) + if self.history_size: messages = messages[-self.history_size :]