mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-18 00:51:18 +00:00
ParentChildRetriever support for sqldocstore
This commit is contained in:
parent
40979a0063
commit
3b5019dcdb
@ -15,6 +15,7 @@ from typing import (
|
|||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from langchain_core.documents.base import Document
|
||||||
from langchain_core.stores import BaseStore
|
from langchain_core.stores import BaseStore
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
LargeBinary,
|
LargeBinary,
|
||||||
@ -176,7 +177,6 @@ class SQLStore(BaseStore[str, bytes]):
|
|||||||
|
|
||||||
async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
|
async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
|
||||||
assert isinstance(self.engine, AsyncEngine)
|
assert isinstance(self.engine, AsyncEngine)
|
||||||
result: Dict[str, bytes] = {}
|
|
||||||
async with self._make_async_session() as session:
|
async with self._make_async_session() as session:
|
||||||
stmt = select(LangchainKeyValueStores).filter(
|
stmt = select(LangchainKeyValueStores).filter(
|
||||||
and_(
|
and_(
|
||||||
@ -188,8 +188,6 @@ class SQLStore(BaseStore[str, bytes]):
|
|||||||
return [result[0].value for result in results]
|
return [result[0].value for result in results]
|
||||||
|
|
||||||
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
|
def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
|
||||||
result = {}
|
|
||||||
|
|
||||||
with self._make_sync_session() as session:
|
with self._make_sync_session() as session:
|
||||||
stmt = select(LangchainKeyValueStores).filter(
|
stmt = select(LangchainKeyValueStores).filter(
|
||||||
and_(
|
and_(
|
||||||
@ -200,29 +198,43 @@ class SQLStore(BaseStore[str, bytes]):
|
|||||||
results = session.execute(stmt).all()
|
results = session.execute(stmt).all()
|
||||||
return [result[0].value for result in results]
|
return [result[0].value for result in results]
|
||||||
|
|
||||||
async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]], encoding: str = 'utf8') -> None:
|
async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
|
||||||
async with self._make_async_session() as session:
|
async with self._make_async_session() as session:
|
||||||
await self._amdelete([key for key, _ in key_value_pairs], session)
|
await self._amdelete([key for key, _ in key_value_pairs], session)
|
||||||
session.add_all(
|
session.add_all(
|
||||||
[
|
[
|
||||||
LangchainKeyValueStores(namespace=self.namespace, key=k, value=bytearray(v, encoding))
|
LangchainKeyValueStores(
|
||||||
|
namespace=self.namespace,
|
||||||
|
key=k,
|
||||||
|
value=self._bytes_or_document(v)
|
||||||
|
)
|
||||||
for k, v in key_value_pairs
|
for k, v in key_value_pairs
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]], , encoding: str = 'utf8') -> None:
|
def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
|
||||||
values: Dict[str, bytes] = dict(key_value_pairs)
|
values: Dict[str, bytes] = dict(key_value_pairs)
|
||||||
with self._make_sync_session() as session:
|
with self._make_sync_session() as session:
|
||||||
self._mdelete(list(values.keys()), session)
|
self._mdelete(list(values.keys()), session)
|
||||||
session.add_all(
|
session.add_all(
|
||||||
[
|
[
|
||||||
LangchainKeyValueStores(namespace=self.namespace, key=k, value=bytearray(v, encoding))
|
LangchainKeyValueStores(
|
||||||
|
namespace=self.namespace,
|
||||||
|
key=k,
|
||||||
|
value=self._bytes_or_document(v)
|
||||||
|
)
|
||||||
for k, v in values.items()
|
for k, v in values.items()
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
def _bytes_or_document(self, v: Union[bytes, Document]) -> Union[bytes, bytearray]:
|
||||||
|
if type(v) is bytes:
|
||||||
|
return v
|
||||||
|
elif type(v) is Document:
|
||||||
|
return bytearray(v.page_content, 'utf8')
|
||||||
|
|
||||||
def _mdelete(self, keys: Sequence[str], session: Session) -> None:
|
def _mdelete(self, keys: Sequence[str], session: Session) -> None:
|
||||||
stmt = delete(LangchainKeyValueStores).filter(
|
stmt = delete(LangchainKeyValueStores).filter(
|
||||||
and_(
|
and_(
|
||||||
|
@ -78,6 +78,22 @@ def test_sample_sql_docstore(sql_store: SQLStore) -> None:
|
|||||||
assert [key for key in sql_store.yield_keys()] == ["key2"]
|
assert [key for key in sql_store.yield_keys()] == ["key2"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues")
|
||||||
|
def test_sample_sql_docstore_with_document(sql_store: SQLStore) -> None:
|
||||||
|
# Set values for keys
|
||||||
|
sql_store.mset([("key1", Document("value1")), ("key2", Document("value2"))])
|
||||||
|
|
||||||
|
# Get values for keys
|
||||||
|
values = sql_store.mget(["key1", "key2"]) # Returns [b"value1", b"value2"]
|
||||||
|
assert values == [b"value1", b"value2"]
|
||||||
|
# Delete keys
|
||||||
|
sql_store.mdelete(["key1"])
|
||||||
|
|
||||||
|
# Iterate over keys
|
||||||
|
assert [key for key in sql_store.yield_keys()] == ["key2"]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("aiosqlite")
|
@pytest.mark.requires("aiosqlite")
|
||||||
async def test_async_sample_sql_docstore(async_sql_store: SQLStore) -> None:
|
async def test_async_sample_sql_docstore(async_sql_store: SQLStore) -> None:
|
||||||
# Set values for keys
|
# Set values for keys
|
||||||
|
Loading…
Reference in New Issue
Block a user