ParentChildRetriever support for sqldocstore

This commit is contained in:
Alex Lee 2025-03-15 18:19:53 -07:00
parent 40979a0063
commit 3b5019dcdb
2 changed files with 35 additions and 7 deletions

View File

@ -15,6 +15,7 @@ from typing import (
cast, cast,
) )
from langchain_core.documents.base import Document
from langchain_core.stores import BaseStore from langchain_core.stores import BaseStore
from sqlalchemy import ( from sqlalchemy import (
LargeBinary, LargeBinary,
@ -176,7 +177,6 @@ class SQLStore(BaseStore[str, bytes]):
async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]: async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
assert isinstance(self.engine, AsyncEngine) assert isinstance(self.engine, AsyncEngine)
result: Dict[str, bytes] = {}
async with self._make_async_session() as session: async with self._make_async_session() as session:
stmt = select(LangchainKeyValueStores).filter( stmt = select(LangchainKeyValueStores).filter(
and_( and_(
@ -188,8 +188,6 @@ class SQLStore(BaseStore[str, bytes]):
return [result[0].value for result in results] return [result[0].value for result in results]
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]: def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
result = {}
with self._make_sync_session() as session: with self._make_sync_session() as session:
stmt = select(LangchainKeyValueStores).filter( stmt = select(LangchainKeyValueStores).filter(
and_( and_(
@ -200,29 +198,43 @@ class SQLStore(BaseStore[str, bytes]):
results = session.execute(stmt).all() results = session.execute(stmt).all()
return [result[0].value for result in results] return [result[0].value for result in results]
async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]], encoding: str = 'utf8') -> None: async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
async with self._make_async_session() as session: async with self._make_async_session() as session:
await self._amdelete([key for key, _ in key_value_pairs], session) await self._amdelete([key for key, _ in key_value_pairs], session)
session.add_all( session.add_all(
[ [
LangchainKeyValueStores(namespace=self.namespace, key=k, value=bytearray(v, encoding)) LangchainKeyValueStores(
namespace=self.namespace,
key=k,
value=self._bytes_or_document(v)
)
for k, v in key_value_pairs for k, v in key_value_pairs
] ]
) )
await session.commit() await session.commit()
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]], , encoding: str = 'utf8') -> None: def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
values: Dict[str, bytes] = dict(key_value_pairs) values: Dict[str, bytes] = dict(key_value_pairs)
with self._make_sync_session() as session: with self._make_sync_session() as session:
self._mdelete(list(values.keys()), session) self._mdelete(list(values.keys()), session)
session.add_all( session.add_all(
[ [
LangchainKeyValueStores(namespace=self.namespace, key=k, value=bytearray(v, encoding)) LangchainKeyValueStores(
namespace=self.namespace,
key=k,
value=self._bytes_or_document(v)
)
for k, v in values.items() for k, v in values.items()
] ]
) )
session.commit() session.commit()
def _bytes_or_document(self, v: Union[bytes, Document]) -> Union[bytes, bytearray]:
if type(v) is bytes:
return v
elif type(v) is Document:
return bytearray(v.page_content, 'utf8')
def _mdelete(self, keys: Sequence[str], session: Session) -> None: def _mdelete(self, keys: Sequence[str], session: Session) -> None:
stmt = delete(LangchainKeyValueStores).filter( stmt = delete(LangchainKeyValueStores).filter(
and_( and_(

View File

@ -78,6 +78,22 @@ def test_sample_sql_docstore(sql_store: SQLStore) -> None:
assert [key for key in sql_store.yield_keys()] == ["key2"] assert [key for key in sql_store.yield_keys()] == ["key2"]
@pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues")
def test_sample_sql_docstore_with_document(sql_store: SQLStore) -> None:
# Set values for keys
sql_store.mset([("key1", Document("value1")), ("key2", Document("value2"))])
# Get values for keys
values = sql_store.mget(["key1", "key2"]) # Returns [b"value1", b"value2"]
assert values == [b"value1", b"value2"]
# Delete keys
sql_store.mdelete(["key1"])
# Iterate over keys
assert [key for key in sql_store.yield_keys()] == ["key2"]
@pytest.mark.requires("aiosqlite") @pytest.mark.requires("aiosqlite")
async def test_async_sample_sql_docstore(async_sql_store: SQLStore) -> None: async def test_async_sample_sql_docstore(async_sql_store: SQLStore) -> None:
# Set values for keys # Set values for keys