diff --git a/libs/community/langchain_community/storage/sql.py b/libs/community/langchain_community/storage/sql.py
index 3d821a09c51..6709a24a8dc 100644
--- a/libs/community/langchain_community/storage/sql.py
+++ b/libs/community/langchain_community/storage/sql.py
@@ -15,6 +15,7 @@ from typing import (
     cast,
 )
 
+from langchain_core.documents.base import Document
 from langchain_core.stores import BaseStore
 from sqlalchemy import (
     LargeBinary,
@@ -176,7 +177,6 @@ class SQLStore(BaseStore[str, bytes]):
 
     async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
         assert isinstance(self.engine, AsyncEngine)
-        result: Dict[str, bytes] = {}
         async with self._make_async_session() as session:
             stmt = select(LangchainKeyValueStores).filter(
                 and_(
@@ -188,8 +188,6 @@ class SQLStore(BaseStore[str, bytes]):
         return [result[0].value for result in results]
 
     def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
-        result = {}
-
         with self._make_sync_session() as session:
             stmt = select(LangchainKeyValueStores).filter(
                 and_(
@@ -200,29 +198,56 @@ class SQLStore(BaseStore[str, bytes]):
             results = session.execute(stmt).all()
         return [result[0].value for result in results]
 
-    async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]], encoding: str = 'utf8') -> None:
+    async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
         async with self._make_async_session() as session:
             await self._amdelete([key for key, _ in key_value_pairs], session)
             session.add_all(
                 [
-                    LangchainKeyValueStores(namespace=self.namespace, key=k, value=bytearray(v, encoding))
+                    LangchainKeyValueStores(
+                        namespace=self.namespace,
+                        key=k,
+                        value=self._bytes_or_document(v),
+                    )
                     for k, v in key_value_pairs
                 ]
             )
             await session.commit()
 
-    def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]], , encoding: str = 'utf8') -> None:
+    def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
         values: Dict[str, bytes] = dict(key_value_pairs)
         with self._make_sync_session() as session:
             self._mdelete(list(values.keys()), session)
             session.add_all(
                 [
-                    LangchainKeyValueStores(namespace=self.namespace, key=k, value=bytearray(v, encoding))
+                    LangchainKeyValueStores(
+                        namespace=self.namespace,
+                        key=k,
+                        value=self._bytes_or_document(v),
+                    )
                     for k, v in values.items()
                 ]
             )
             session.commit()
 
+    def _bytes_or_document(self, v: Union[bytes, Document]) -> Union[bytes, bytearray]:
+        """Convert a value given to ``mset``/``amset`` into a binary payload.
+
+        ``bytes`` values are stored as-is; a ``Document`` is stored as its
+        ``page_content`` encoded as UTF-8.
+
+        Raises:
+            TypeError: If ``v`` is neither ``bytes`` nor a ``Document``.
+        """
+        if isinstance(v, bytes):
+            return v
+        if isinstance(v, Document):
+            return bytearray(v.page_content, "utf8")
+        # Fail loudly: falling through would hand ``None`` to the value column.
+        raise TypeError(
+            f"Unsupported value type {type(v).__name__!r}; "
+            "expected bytes or Document"
+        )
+
     def _mdelete(self, keys: Sequence[str], session: Session) -> None:
         stmt = delete(LangchainKeyValueStores).filter(
             and_(
diff --git a/libs/community/tests/unit_tests/storage/test_sql.py b/libs/community/tests/unit_tests/storage/test_sql.py
index 4744155b17a..c5e12c3aa4d 100644
--- a/libs/community/tests/unit_tests/storage/test_sql.py
+++ b/libs/community/tests/unit_tests/storage/test_sql.py
@@ -78,6 +78,21 @@ def test_sample_sql_docstore(sql_store: SQLStore) -> None:
     assert [key for key in sql_store.yield_keys()] == ["key2"]
 
 
+@pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues")
+def test_sample_sql_docstore_with_document(sql_store: SQLStore) -> None:
+    # Set values for keys
+    sql_store.mset([("key1", Document("value1")), ("key2", Document("value2"))])
+
+    # Get values for keys
+    values = sql_store.mget(["key1", "key2"])  # Returns [b"value1", b"value2"]
+    assert values == [b"value1", b"value2"]
+    # Delete keys
+    sql_store.mdelete(["key1"])
+
+    # Iterate over keys
+    assert [key for key in sql_store.yield_keys()] == ["key2"]
+
+
 @pytest.mark.requires("aiosqlite")
 async def test_async_sample_sql_docstore(async_sql_store: SQLStore) -> None:
     # Set values for keys