mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-07 05:52:15 +00:00
Extend DynamoDBChatMessageHistory to support composite keys (#9896)
- Description: Adds two optional parameters to the DynamoDBChatMessageHistory class to enable users to pass in a name for their PrimaryKey, or a Key object itself to enable the use of composite keys, a common DynamoDB paradigm. [AWS DynamoDB Key docs](https://aws.amazon.com/blogs/database/choosing-the-right-dynamodb-partition-key/) - Issue: N/A - Dependencies: N/A - Twitter handle: N/A --------- Co-authored-by: Josh White <josh@ctrlstack.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain.schema import (
|
||||
BaseChatMessageHistory,
|
||||
@@ -17,8 +17,7 @@ logger = logging.getLogger(__name__)
|
||||
class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
"""Chat message history that stores history in AWS DynamoDB.
|
||||
|
||||
This class expects that a DynamoDB table with name `table_name`
|
||||
and a partition Key of `SessionId` is present.
|
||||
This class expects that a DynamoDB table exists with name `table_name`
|
||||
|
||||
Args:
|
||||
table_name: name of the DynamoDB table
|
||||
@@ -28,10 +27,21 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
is optional and useful for test purposes, like using Localstack.
|
||||
If you plan to use AWS cloud service, you normally don't have to
|
||||
worry about setting the endpoint_url.
|
||||
primary_key_name: name of the primary key of the DynamoDB table. This argument
|
||||
is optional, defaulting to "SessionId".
|
||||
key: an optional dictionary with a custom primary and secondary key.
|
||||
This argument is optional, but useful when using composite dynamodb keys, or
|
||||
isolating records based off of application details such as a user id.
|
||||
This may also contain global and local secondary index keys.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, table_name: str, session_id: str, endpoint_url: Optional[str] = None
|
||||
self,
|
||||
table_name: str,
|
||||
session_id: str,
|
||||
endpoint_url: Optional[str] = None,
|
||||
primary_key_name: str = "SessionId",
|
||||
key: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
import boto3
|
||||
|
||||
@@ -41,6 +51,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
client = boto3.resource("dynamodb")
|
||||
self.table = client.Table(table_name)
|
||||
self.session_id = session_id
|
||||
self.key: Dict = key or {primary_key_name: session_id}
|
||||
|
||||
@property
|
||||
def messages(self) -> List[BaseMessage]: # type: ignore
|
||||
@@ -49,7 +60,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
|
||||
response = None
|
||||
try:
|
||||
response = self.table.get_item(Key={"SessionId": self.session_id})
|
||||
response = self.table.get_item(Key=self.key)
|
||||
except ClientError as error:
|
||||
if error.response["Error"]["Code"] == "ResourceNotFoundException":
|
||||
logger.warning("No record found with session id: %s", self.session_id)
|
||||
@@ -73,9 +84,7 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
messages.append(_message)
|
||||
|
||||
try:
|
||||
self.table.put_item(
|
||||
Item={"SessionId": self.session_id, "History": messages}
|
||||
)
|
||||
self.table.put_item(Item={**self.key, "History": messages})
|
||||
except ClientError as err:
|
||||
logger.error(err)
|
||||
|
||||
@@ -84,6 +93,6 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
try:
|
||||
self.table.delete_item(Key={"SessionId": self.session_id})
|
||||
self.table.delete_item(self.key)
|
||||
except ClientError as err:
|
||||
logger.error(err)
|
||||
|
Reference in New Issue
Block a user