Extend DynamoDBChatMessageHistory to support composite keys (#9896)

- Description: Adds two optional parameters to the
DynamoDBChatMessageHistory class: `primary_key_name`, which lets users
pass in a custom name for their primary key, and `key`, which lets them
pass in a full key dictionary to enable the use of composite keys, a
common DynamoDB paradigm.
  
[AWS DynamoDB Key
docs](https://aws.amazon.com/blogs/database/choosing-the-right-dynamodb-partition-key/)
  
  - Issue: N/A
  - Dependencies: N/A
  - Twitter handle: N/A

---------

Co-authored-by: Josh White <josh@ctrlstack.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Josh White
2023-09-03 17:05:16 -05:00
committed by GitHub
parent 872d829201
commit bc8cceebf7
2 changed files with 126 additions and 141 deletions

View File

@@ -1,5 +1,5 @@
import logging
from typing import List, Optional
from typing import Dict, List, Optional
from langchain.schema import (
BaseChatMessageHistory,
@@ -17,8 +17,7 @@ logger = logging.getLogger(__name__)
class DynamoDBChatMessageHistory(BaseChatMessageHistory):
"""Chat message history that stores history in AWS DynamoDB.
This class expects that a DynamoDB table with name `table_name`
and a partition Key of `SessionId` is present.
This class expects that a DynamoDB table exists with name `table_name`
Args:
table_name: name of the DynamoDB table
@@ -28,10 +27,21 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
is optional and useful for test purposes, like using Localstack.
If you plan to use AWS cloud service, you normally don't have to
worry about setting the endpoint_url.
primary_key_name: name of the primary key of the DynamoDB table. This argument
is optional, defaulting to "SessionId".
key: an optional dictionary with a custom primary and secondary key.
This argument is useful when using composite DynamoDB keys, or when
isolating records based on application details such as a user id.
This may also contain global and local secondary index keys.
"""
def __init__(
self, table_name: str, session_id: str, endpoint_url: Optional[str] = None
self,
table_name: str,
session_id: str,
endpoint_url: Optional[str] = None,
primary_key_name: str = "SessionId",
key: Optional[Dict[str, str]] = None,
):
import boto3
@@ -41,6 +51,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
client = boto3.resource("dynamodb")
self.table = client.Table(table_name)
self.session_id = session_id
self.key: Dict = key or {primary_key_name: session_id}
@property
def messages(self) -> List[BaseMessage]: # type: ignore
@@ -49,7 +60,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
response = None
try:
response = self.table.get_item(Key={"SessionId": self.session_id})
response = self.table.get_item(Key=self.key)
except ClientError as error:
if error.response["Error"]["Code"] == "ResourceNotFoundException":
logger.warning("No record found with session id: %s", self.session_id)
@@ -73,9 +84,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
messages.append(_message)
try:
self.table.put_item(
Item={"SessionId": self.session_id, "History": messages}
)
self.table.put_item(Item={**self.key, "History": messages})
except ClientError as err:
logger.error(err)
@@ -84,6 +93,6 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
from botocore.exceptions import ClientError
try:
self.table.delete_item(Key={"SessionId": self.session_id})
self.table.delete_item(self.key)
except ClientError as err:
logger.error(err)