From 40979a0063850cf4eac92b389aec397799d89163 Mon Sep 17 00:00:00 2001 From: Alex Lee Date: Sat, 15 Mar 2025 16:46:54 -0700 Subject: [PATCH] resolve sqldocstore postgres compatibility issue --- .../langchain_community/storage/sql.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/libs/community/langchain_community/storage/sql.py b/libs/community/langchain_community/storage/sql.py index c5b4ae978ff..3d821a09c51 100644 --- a/libs/community/langchain_community/storage/sql.py +++ b/libs/community/langchain_community/storage/sql.py @@ -184,9 +184,8 @@ class SQLStore(BaseStore[str, bytes]): LangchainKeyValueStores.namespace == self.namespace, ) ) - for v in await session.scalars(stmt): - result[v.key] = v.value - return [result.get(key) for key in keys] + results = session.execute(stmt).all() + return [result[0].value for result in results] def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]: result = {} @@ -198,28 +197,27 @@ class SQLStore(BaseStore[str, bytes]): LangchainKeyValueStores.namespace == self.namespace, ) ) - for v in session.scalars(stmt): - result[v.key] = v.value - return [result.get(key) for key in keys] + results = session.execute(stmt).all() + return [result[0].value for result in results] - async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: + async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]], encoding: str = 'utf8') -> None: async with self._make_async_session() as session: await self._amdelete([key for key, _ in key_value_pairs], session) session.add_all( [ - LangchainKeyValueStores(namespace=self.namespace, key=k, value=v) + LangchainKeyValueStores(namespace=self.namespace, key=k, value=bytearray(v, encoding)) for k, v in key_value_pairs ] ) await session.commit() - def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None: + def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]], , encoding: str = 'utf8') -> None: values: Dict[str, bytes] = dict(key_value_pairs) with self._make_sync_session() as session: self._mdelete(list(values.keys()), session) session.add_all( [ - LangchainKeyValueStores(namespace=self.namespace, key=k, value=v) + LangchainKeyValueStores(namespace=self.namespace, key=k, value=bytearray(v, encoding)) for k, v in values.items() ] )