mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-17 16:39:52 +00:00
ParentChildRetriever support for sqldocstore
This commit is contained in:
parent
40979a0063
commit
3b5019dcdb
@ -15,6 +15,7 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain_core.documents.base import Document
|
||||
from langchain_core.stores import BaseStore
|
||||
from sqlalchemy import (
|
||||
LargeBinary,
|
||||
@ -176,7 +177,6 @@ class SQLStore(BaseStore[str, bytes]):
|
||||
|
||||
async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
|
||||
assert isinstance(self.engine, AsyncEngine)
|
||||
result: Dict[str, bytes] = {}
|
||||
async with self._make_async_session() as session:
|
||||
stmt = select(LangchainKeyValueStores).filter(
|
||||
and_(
|
||||
@ -188,8 +188,6 @@ class SQLStore(BaseStore[str, bytes]):
|
||||
return [result[0].value for result in results]
|
||||
|
||||
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
|
||||
result = {}
|
||||
|
||||
with self._make_sync_session() as session:
|
||||
stmt = select(LangchainKeyValueStores).filter(
|
||||
and_(
|
||||
@ -200,29 +198,43 @@ class SQLStore(BaseStore[str, bytes]):
|
||||
results = session.execute(stmt).all()
|
||||
return [result[0].value for result in results]
|
||||
|
||||
async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]], encoding: str = 'utf8') -> None:
|
||||
async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
|
||||
async with self._make_async_session() as session:
|
||||
await self._amdelete([key for key, _ in key_value_pairs], session)
|
||||
session.add_all(
|
||||
[
|
||||
LangchainKeyValueStores(namespace=self.namespace, key=k, value=bytearray(v, encoding))
|
||||
LangchainKeyValueStores(
|
||||
namespace=self.namespace,
|
||||
key=k,
|
||||
value=self._bytes_or_document(v)
|
||||
)
|
||||
for k, v in key_value_pairs
|
||||
]
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]], , encoding: str = 'utf8') -> None:
|
||||
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
|
||||
values: Dict[str, bytes] = dict(key_value_pairs)
|
||||
with self._make_sync_session() as session:
|
||||
self._mdelete(list(values.keys()), session)
|
||||
session.add_all(
|
||||
[
|
||||
LangchainKeyValueStores(namespace=self.namespace, key=k, value=bytearray(v, encoding))
|
||||
LangchainKeyValueStores(
|
||||
namespace=self.namespace,
|
||||
key=k,
|
||||
value=self._bytes_or_document(v)
|
||||
)
|
||||
for k, v in values.items()
|
||||
]
|
||||
)
|
||||
session.commit()
|
||||
|
||||
def _bytes_or_document(self, v: Union[bytes, Document]) -> Union[bytes, bytearray]:
|
||||
if type(v) is bytes:
|
||||
return v
|
||||
elif type(v) is Document:
|
||||
return bytearray(v.page_content, 'utf8')
|
||||
|
||||
def _mdelete(self, keys: Sequence[str], session: Session) -> None:
|
||||
stmt = delete(LangchainKeyValueStores).filter(
|
||||
and_(
|
||||
|
@ -78,6 +78,22 @@ def test_sample_sql_docstore(sql_store: SQLStore) -> None:
|
||||
assert [key for key in sql_store.yield_keys()] == ["key2"]
|
||||
|
||||
|
||||
@pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues")
|
||||
def test_sample_sql_docstore_with_document(sql_store: SQLStore) -> None:
|
||||
# Set values for keys
|
||||
sql_store.mset([("key1", Document("value1")), ("key2", Document("value2"))])
|
||||
|
||||
# Get values for keys
|
||||
values = sql_store.mget(["key1", "key2"]) # Returns [b"value1", b"value2"]
|
||||
assert values == [b"value1", b"value2"]
|
||||
# Delete keys
|
||||
sql_store.mdelete(["key1"])
|
||||
|
||||
# Iterate over keys
|
||||
assert [key for key in sql_store.yield_keys()] == ["key2"]
|
||||
|
||||
|
||||
|
||||
@pytest.mark.requires("aiosqlite")
|
||||
async def test_async_sample_sql_docstore(async_sql_store: SQLStore) -> None:
|
||||
# Set values for keys
|
||||
|
Loading…
Reference in New Issue
Block a user