Enhancement: add parameter boto3_session for AWS DynamoDB cross account use cases (#10326)

- Description: to allow boto3 assume role for AWS cross account use
cases to read and update the chat history,
  - Issue: use case I faced in my company,
  - Dependencies: no
  - Tag maintainer: @baskaryan ,
  - Twitter handle: @tmin97

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Tze Min 2023-09-08 05:58:28 +08:00 committed by GitHub
parent b1d40b8626
commit 20c742d8a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,7 @@
from __future__ import annotations
import logging import logging
from typing import Dict, List, Optional from typing import TYPE_CHECKING, Dict, List, Optional
from langchain.schema import ( from langchain.schema import (
BaseChatMessageHistory, BaseChatMessageHistory,
@ -11,6 +13,9 @@ from langchain.schema.messages import (
messages_to_dict, messages_to_dict,
) )
if TYPE_CHECKING:
from boto3.session import Session
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -42,13 +47,21 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
endpoint_url: Optional[str] = None, endpoint_url: Optional[str] = None,
primary_key_name: str = "SessionId", primary_key_name: str = "SessionId",
key: Optional[Dict[str, str]] = None, key: Optional[Dict[str, str]] = None,
boto3_session: Optional[Session] = None,
): ):
import boto3 if boto3_session:
client = boto3_session.resource("dynamodb")
if endpoint_url:
client = boto3.resource("dynamodb", endpoint_url=endpoint_url)
else: else:
client = boto3.resource("dynamodb") try:
import boto3
except ImportError as e:
raise ImportError(
"Unable to import boto3, please install with `pip install boto3`."
) from e
if endpoint_url:
client = boto3.resource("dynamodb", endpoint_url=endpoint_url)
else:
client = boto3.resource("dynamodb")
self.table = client.Table(table_name) self.table = client.Table(table_name)
self.session_id = session_id self.session_id = session_id
self.key: Dict = key or {primary_key_name: session_id} self.key: Dict = key or {primary_key_name: session_id}
@ -56,7 +69,12 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
@property @property
def messages(self) -> List[BaseMessage]: # type: ignore def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from DynamoDB""" """Retrieve the messages from DynamoDB"""
from botocore.exceptions import ClientError try:
from botocore.exceptions import ClientError
except ImportError as e:
raise ImportError(
"Unable to import botocore, please install with `pip install botocore`."
) from e
response = None response = None
try: try:
@ -77,7 +95,12 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
def add_message(self, message: BaseMessage) -> None: def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in DynamoDB""" """Append the message to the record in DynamoDB"""
from botocore.exceptions import ClientError try:
from botocore.exceptions import ClientError
except ImportError as e:
raise ImportError(
"Unable to import botocore, please install with `pip install botocore`."
) from e
messages = messages_to_dict(self.messages) messages = messages_to_dict(self.messages)
_message = _message_to_dict(message) _message = _message_to_dict(message)
@ -90,7 +113,12 @@ class DynamoDBChatMessageHistory(BaseChatMessageHistory):
def clear(self) -> None: def clear(self) -> None:
"""Clear session memory from DynamoDB""" """Clear session memory from DynamoDB"""
from botocore.exceptions import ClientError try:
from botocore.exceptions import ClientError
except ImportError as e:
raise ImportError(
"Unable to import botocore, please install with `pip install botocore`."
) from e
try: try:
self.table.delete_item(self.key) self.table.delete_item(self.key)