mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 13:23:35 +00:00
Add async methods to BaseStore (#16669)
- **Description:** The BaseStore methods are currently blocking. Some implementations (AstraDBStore, RedisStore) would benefit from having async methods. Also once we have async methods for BaseStore, we can implement the async `aembed_documents` in CacheBackedEmbeddings to cache the embeddings asynchronously. * adds async methods amget, amset, amedelete and ayield_keys to BaseStore * implements the async methods for InMemoryStore * adds tests for InMemoryStore async methods - **Twitter handle:** cbornet_
This commit is contained in:
parent
17e886388b
commit
a0ec045495
@ -1,5 +1,17 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar, Union
|
||||
from typing import (
|
||||
AsyncIterator,
|
||||
Generic,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.runnables import run_in_executor
|
||||
|
||||
K = TypeVar("K")
|
||||
V = TypeVar("V")
|
||||
@ -20,6 +32,18 @@ class BaseStore(Generic[K, V], ABC):
|
||||
If a key is not found, the corresponding value will be None.
|
||||
"""
|
||||
|
||||
async def amget(self, keys: Sequence[K]) -> List[Optional[V]]:
|
||||
"""Get the values associated with the given keys.
|
||||
|
||||
Args:
|
||||
keys (Sequence[K]): A sequence of keys.
|
||||
|
||||
Returns:
|
||||
A sequence of optional values associated with the keys.
|
||||
If a key is not found, the corresponding value will be None.
|
||||
"""
|
||||
return await run_in_executor(None, self.mget, keys)
|
||||
|
||||
@abstractmethod
|
||||
def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
|
||||
"""Set the values for the given keys.
|
||||
@ -28,6 +52,14 @@ class BaseStore(Generic[K, V], ABC):
|
||||
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
|
||||
"""
|
||||
|
||||
async def amset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
|
||||
"""Set the values for the given keys.
|
||||
|
||||
Args:
|
||||
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
|
||||
"""
|
||||
return await run_in_executor(None, self.mset, key_value_pairs)
|
||||
|
||||
@abstractmethod
|
||||
def mdelete(self, keys: Sequence[K]) -> None:
|
||||
"""Delete the given keys and their associated values.
|
||||
@ -36,6 +68,14 @@ class BaseStore(Generic[K, V], ABC):
|
||||
keys (Sequence[K]): A sequence of keys to delete.
|
||||
"""
|
||||
|
||||
async def amdelete(self, keys: Sequence[K]) -> None:
|
||||
"""Delete the given keys and their associated values.
|
||||
|
||||
Args:
|
||||
keys (Sequence[K]): A sequence of keys to delete.
|
||||
"""
|
||||
return await run_in_executor(None, self.mdelete, keys)
|
||||
|
||||
@abstractmethod
|
||||
def yield_keys(
|
||||
self, *, prefix: Optional[str] = None
|
||||
@ -52,5 +92,27 @@ class BaseStore(Generic[K, V], ABC):
|
||||
depending on what makes more sense for the given store.
|
||||
"""
|
||||
|
||||
async def ayield_keys(
|
||||
self, *, prefix: Optional[str] = None
|
||||
) -> Union[AsyncIterator[K], AsyncIterator[str]]:
|
||||
"""Get an iterator over keys that match the given prefix.
|
||||
|
||||
Args:
|
||||
prefix (str): The prefix to match.
|
||||
|
||||
Returns:
|
||||
Iterator[K | str]: An iterator over keys that match the given prefix.
|
||||
|
||||
This method is allowed to return an iterator over either K or str
|
||||
depending on what makes more sense for the given store.
|
||||
"""
|
||||
iterator = await run_in_executor(None, self.yield_keys, prefix=prefix)
|
||||
done = object()
|
||||
while True:
|
||||
item = await run_in_executor(None, lambda it: next(it, done), iterator)
|
||||
if item is done:
|
||||
break
|
||||
yield item
|
||||
|
||||
|
||||
ByteStore = BaseStore[str, bytes]
|
||||
|
@ -5,6 +5,7 @@ primarily for unit testing purposes.
|
||||
"""
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Generic,
|
||||
Iterator,
|
||||
@ -60,6 +61,18 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
|
||||
"""
|
||||
return [self.store.get(key) for key in keys]
|
||||
|
||||
async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
|
||||
"""Get the values associated with the given keys.
|
||||
|
||||
Args:
|
||||
keys (Sequence[str]): A sequence of keys.
|
||||
|
||||
Returns:
|
||||
A sequence of optional values associated with the keys.
|
||||
If a key is not found, the corresponding value will be None.
|
||||
"""
|
||||
return self.mget(keys)
|
||||
|
||||
def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
|
||||
"""Set the values for the given keys.
|
||||
|
||||
@ -72,6 +85,17 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
|
||||
for key, value in key_value_pairs:
|
||||
self.store[key] = value
|
||||
|
||||
async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
|
||||
"""Set the values for the given keys.
|
||||
|
||||
Args:
|
||||
key_value_pairs (Sequence[Tuple[str, V]]): A sequence of key-value pairs.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
return self.mset(key_value_pairs)
|
||||
|
||||
def mdelete(self, keys: Sequence[str]) -> None:
|
||||
"""Delete the given keys and their associated values.
|
||||
|
||||
@ -82,6 +106,14 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
|
||||
if key in self.store:
|
||||
del self.store[key]
|
||||
|
||||
async def amdelete(self, keys: Sequence[str]) -> None:
|
||||
"""Delete the given keys and their associated values.
|
||||
|
||||
Args:
|
||||
keys (Sequence[str]): A sequence of keys to delete.
|
||||
"""
|
||||
self.mdelete(keys)
|
||||
|
||||
def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
|
||||
"""Get an iterator over keys that match the given prefix.
|
||||
|
||||
@ -98,6 +130,23 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
|
||||
if key.startswith(prefix):
|
||||
yield key
|
||||
|
||||
async def ayield_keys(self, prefix: Optional[str] = None) -> AsyncIterator[str]:
|
||||
"""Get an async iterator over keys that match the given prefix.
|
||||
|
||||
Args:
|
||||
prefix (str, optional): The prefix to match. Defaults to None.
|
||||
|
||||
Returns:
|
||||
AsyncIterator[str]: An async iterator over keys that match the given prefix.
|
||||
"""
|
||||
if prefix is None:
|
||||
for key in self.store.keys():
|
||||
yield key
|
||||
else:
|
||||
for key in self.store.keys():
|
||||
if key.startswith(prefix):
|
||||
yield key
|
||||
|
||||
|
||||
InMemoryStore = InMemoryBaseStore[Any]
|
||||
InMemoryByteStore = InMemoryBaseStore[bytes]
|
||||
|
@ -13,6 +13,18 @@ def test_mget() -> None:
|
||||
assert non_existent_value == [None]
|
||||
|
||||
|
||||
async def test_amget() -> None:
|
||||
store = InMemoryStore()
|
||||
await store.amset([("key1", "value1"), ("key2", "value2")])
|
||||
|
||||
values = await store.amget(["key1", "key2"])
|
||||
assert values == ["value1", "value2"]
|
||||
|
||||
# Test non-existent key
|
||||
non_existent_value = await store.amget(["key3"])
|
||||
assert non_existent_value == [None]
|
||||
|
||||
|
||||
def test_mset() -> None:
|
||||
store = InMemoryStore()
|
||||
store.mset([("key1", "value1"), ("key2", "value2")])
|
||||
@ -21,6 +33,14 @@ def test_mset() -> None:
|
||||
assert values == ["value1", "value2"]
|
||||
|
||||
|
||||
async def test_amset() -> None:
|
||||
store = InMemoryStore()
|
||||
await store.amset([("key1", "value1"), ("key2", "value2")])
|
||||
|
||||
values = await store.amget(["key1", "key2"])
|
||||
assert values == ["value1", "value2"]
|
||||
|
||||
|
||||
def test_mdelete() -> None:
|
||||
store = InMemoryStore()
|
||||
store.mset([("key1", "value1"), ("key2", "value2")])
|
||||
@ -34,6 +54,19 @@ def test_mdelete() -> None:
|
||||
store.mdelete(["key3"]) # No error should be raised
|
||||
|
||||
|
||||
async def test_amdelete() -> None:
|
||||
store = InMemoryStore()
|
||||
await store.amset([("key1", "value1"), ("key2", "value2")])
|
||||
|
||||
await store.amdelete(["key1"])
|
||||
|
||||
values = await store.amget(["key1", "key2"])
|
||||
assert values == [None, "value2"]
|
||||
|
||||
# Test deleting non-existent key
|
||||
await store.amdelete(["key3"]) # No error should be raised
|
||||
|
||||
|
||||
def test_yield_keys() -> None:
|
||||
store = InMemoryStore()
|
||||
store.mset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")])
|
||||
@ -46,3 +79,17 @@ def test_yield_keys() -> None:
|
||||
|
||||
keys_with_invalid_prefix = list(store.yield_keys(prefix="x"))
|
||||
assert keys_with_invalid_prefix == []
|
||||
|
||||
|
||||
async def test_ayield_keys() -> None:
|
||||
store = InMemoryStore()
|
||||
await store.amset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")])
|
||||
|
||||
keys = [key async for key in store.ayield_keys()]
|
||||
assert set(keys) == {"key1", "key2", "key3"}
|
||||
|
||||
keys_with_prefix = [key async for key in store.ayield_keys(prefix="key")]
|
||||
assert set(keys_with_prefix) == {"key1", "key2", "key3"}
|
||||
|
||||
keys_with_invalid_prefix = [key async for key in store.ayield_keys(prefix="x")]
|
||||
assert keys_with_invalid_prefix == []
|
||||
|
Loading…
Reference in New Issue
Block a user