ParentChildRetriever support for sqldocstore

This commit is contained in:
Alex Lee 2025-03-15 18:19:53 -07:00
parent 40979a0063
commit 3b5019dcdb
2 changed files with 35 additions and 7 deletions

View File

@ -15,6 +15,7 @@ from typing import (
cast,
)
from langchain_core.documents.base import Document
from langchain_core.stores import BaseStore
from sqlalchemy import (
LargeBinary,
@ -176,7 +177,6 @@ class SQLStore(BaseStore[str, bytes]):
async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
assert isinstance(self.engine, AsyncEngine)
result: Dict[str, bytes] = {}
async with self._make_async_session() as session:
stmt = select(LangchainKeyValueStores).filter(
and_(
@ -188,8 +188,6 @@ class SQLStore(BaseStore[str, bytes]):
return [result[0].value for result in results]
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
result = {}
with self._make_sync_session() as session:
stmt = select(LangchainKeyValueStores).filter(
and_(
@ -200,29 +198,43 @@ class SQLStore(BaseStore[str, bytes]):
results = session.execute(stmt).all()
return [result[0].value for result in results]
async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]], encoding: str = 'utf8') -> None:
async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
async with self._make_async_session() as session:
await self._amdelete([key for key, _ in key_value_pairs], session)
session.add_all(
[
LangchainKeyValueStores(namespace=self.namespace, key=k, value=bytearray(v, encoding))
LangchainKeyValueStores(
namespace=self.namespace,
key=k,
value=self._bytes_or_document(v)
)
for k, v in key_value_pairs
]
)
await session.commit()
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]], , encoding: str = 'utf8') -> None:
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
values: Dict[str, bytes] = dict(key_value_pairs)
with self._make_sync_session() as session:
self._mdelete(list(values.keys()), session)
session.add_all(
[
LangchainKeyValueStores(namespace=self.namespace, key=k, value=bytearray(v, encoding))
LangchainKeyValueStores(
namespace=self.namespace,
key=k,
value=self._bytes_or_document(v)
)
for k, v in values.items()
]
)
session.commit()
def _bytes_or_document(self, v: Union[bytes, Document]) -> Union[bytes, bytearray]:
if type(v) is bytes:
return v
elif type(v) is Document:
return bytearray(v.page_content, 'utf8')
def _mdelete(self, keys: Sequence[str], session: Session) -> None:
stmt = delete(LangchainKeyValueStores).filter(
and_(

View File

@ -78,6 +78,22 @@ def test_sample_sql_docstore(sql_store: SQLStore) -> None:
assert [key for key in sql_store.yield_keys()] == ["key2"]
@pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues")
def test_sample_sql_docstore_with_document(sql_store: SQLStore) -> None:
    """Store ``Document`` values and read them back as their content bytes."""
    # Write two documents under distinct keys.
    documents = [("key1", Document("value1")), ("key2", Document("value2"))]
    sql_store.mset(documents)

    # Reading the keys back yields the page content as bytes.
    stored = sql_store.mget(["key1", "key2"])
    assert stored == [b"value1", b"value2"]

    # After removing the first key only the second one is still listed.
    sql_store.mdelete(["key1"])
    remaining_keys = list(sql_store.yield_keys())
    assert remaining_keys == ["key2"]
@pytest.mark.requires("aiosqlite")
async def test_async_sample_sql_docstore(async_sql_store: SQLStore) -> None:
# Set values for keys