mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
Add async methods to BaseStore (#16669)
- **Description:** The BaseStore methods are currently blocking. Some implementations (AstraDBStore, RedisStore) would benefit from having async methods. Also once we have async methods for BaseStore, we can implement the async `aembed_documents` in CacheBackedEmbeddings to cache the embeddings asynchronously. * adds async methods amget, amset, amedelete and ayield_keys to BaseStore * implements the async methods for InMemoryStore * adds tests for InMemoryStore async methods - **Twitter handle:** cbornet_
This commit is contained in:
parent
17e886388b
commit
a0ec045495
@ -1,5 +1,17 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Generic, Iterator, List, Optional, Sequence, Tuple, TypeVar, Union
|
from typing import (
|
||||||
|
AsyncIterator,
|
||||||
|
Generic,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain_core.runnables import run_in_executor
|
||||||
|
|
||||||
K = TypeVar("K")
|
K = TypeVar("K")
|
||||||
V = TypeVar("V")
|
V = TypeVar("V")
|
||||||
@ -20,6 +32,18 @@ class BaseStore(Generic[K, V], ABC):
|
|||||||
If a key is not found, the corresponding value will be None.
|
If a key is not found, the corresponding value will be None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
async def amget(self, keys: Sequence[K]) -> List[Optional[V]]:
|
||||||
|
"""Get the values associated with the given keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keys (Sequence[K]): A sequence of keys.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A sequence of optional values associated with the keys.
|
||||||
|
If a key is not found, the corresponding value will be None.
|
||||||
|
"""
|
||||||
|
return await run_in_executor(None, self.mget, keys)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
|
def mset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
|
||||||
"""Set the values for the given keys.
|
"""Set the values for the given keys.
|
||||||
@ -28,6 +52,14 @@ class BaseStore(Generic[K, V], ABC):
|
|||||||
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
|
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
async def amset(self, key_value_pairs: Sequence[Tuple[K, V]]) -> None:
|
||||||
|
"""Set the values for the given keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_value_pairs (Sequence[Tuple[K, V]]): A sequence of key-value pairs.
|
||||||
|
"""
|
||||||
|
return await run_in_executor(None, self.mset, key_value_pairs)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def mdelete(self, keys: Sequence[K]) -> None:
|
def mdelete(self, keys: Sequence[K]) -> None:
|
||||||
"""Delete the given keys and their associated values.
|
"""Delete the given keys and their associated values.
|
||||||
@ -36,6 +68,14 @@ class BaseStore(Generic[K, V], ABC):
|
|||||||
keys (Sequence[K]): A sequence of keys to delete.
|
keys (Sequence[K]): A sequence of keys to delete.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
async def amdelete(self, keys: Sequence[K]) -> None:
|
||||||
|
"""Delete the given keys and their associated values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keys (Sequence[K]): A sequence of keys to delete.
|
||||||
|
"""
|
||||||
|
return await run_in_executor(None, self.mdelete, keys)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def yield_keys(
|
def yield_keys(
|
||||||
self, *, prefix: Optional[str] = None
|
self, *, prefix: Optional[str] = None
|
||||||
@ -52,5 +92,27 @@ class BaseStore(Generic[K, V], ABC):
|
|||||||
depending on what makes more sense for the given store.
|
depending on what makes more sense for the given store.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
async def ayield_keys(
|
||||||
|
self, *, prefix: Optional[str] = None
|
||||||
|
) -> Union[AsyncIterator[K], AsyncIterator[str]]:
|
||||||
|
"""Get an iterator over keys that match the given prefix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix (str): The prefix to match.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Iterator[K | str]: An iterator over keys that match the given prefix.
|
||||||
|
|
||||||
|
This method is allowed to return an iterator over either K or str
|
||||||
|
depending on what makes more sense for the given store.
|
||||||
|
"""
|
||||||
|
iterator = await run_in_executor(None, self.yield_keys, prefix=prefix)
|
||||||
|
done = object()
|
||||||
|
while True:
|
||||||
|
item = await run_in_executor(None, lambda it: next(it, done), iterator)
|
||||||
|
if item is done:
|
||||||
|
break
|
||||||
|
yield item
|
||||||
|
|
||||||
|
|
||||||
ByteStore = BaseStore[str, bytes]
|
ByteStore = BaseStore[str, bytes]
|
||||||
|
@ -5,6 +5,7 @@ primarily for unit testing purposes.
|
|||||||
"""
|
"""
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
Dict,
|
Dict,
|
||||||
Generic,
|
Generic,
|
||||||
Iterator,
|
Iterator,
|
||||||
@ -60,6 +61,18 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
|
|||||||
"""
|
"""
|
||||||
return [self.store.get(key) for key in keys]
|
return [self.store.get(key) for key in keys]
|
||||||
|
|
||||||
|
async def amget(self, keys: Sequence[str]) -> List[Optional[V]]:
|
||||||
|
"""Get the values associated with the given keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keys (Sequence[str]): A sequence of keys.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A sequence of optional values associated with the keys.
|
||||||
|
If a key is not found, the corresponding value will be None.
|
||||||
|
"""
|
||||||
|
return self.mget(keys)
|
||||||
|
|
||||||
def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
|
def mset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
|
||||||
"""Set the values for the given keys.
|
"""Set the values for the given keys.
|
||||||
|
|
||||||
@ -72,6 +85,17 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
|
|||||||
for key, value in key_value_pairs:
|
for key, value in key_value_pairs:
|
||||||
self.store[key] = value
|
self.store[key] = value
|
||||||
|
|
||||||
|
async def amset(self, key_value_pairs: Sequence[Tuple[str, V]]) -> None:
|
||||||
|
"""Set the values for the given keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_value_pairs (Sequence[Tuple[str, V]]): A sequence of key-value pairs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
return self.mset(key_value_pairs)
|
||||||
|
|
||||||
def mdelete(self, keys: Sequence[str]) -> None:
|
def mdelete(self, keys: Sequence[str]) -> None:
|
||||||
"""Delete the given keys and their associated values.
|
"""Delete the given keys and their associated values.
|
||||||
|
|
||||||
@ -82,6 +106,14 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
|
|||||||
if key in self.store:
|
if key in self.store:
|
||||||
del self.store[key]
|
del self.store[key]
|
||||||
|
|
||||||
|
async def amdelete(self, keys: Sequence[str]) -> None:
|
||||||
|
"""Delete the given keys and their associated values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
keys (Sequence[str]): A sequence of keys to delete.
|
||||||
|
"""
|
||||||
|
self.mdelete(keys)
|
||||||
|
|
||||||
def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
|
def yield_keys(self, prefix: Optional[str] = None) -> Iterator[str]:
|
||||||
"""Get an iterator over keys that match the given prefix.
|
"""Get an iterator over keys that match the given prefix.
|
||||||
|
|
||||||
@ -98,6 +130,23 @@ class InMemoryBaseStore(BaseStore[str, V], Generic[V]):
|
|||||||
if key.startswith(prefix):
|
if key.startswith(prefix):
|
||||||
yield key
|
yield key
|
||||||
|
|
||||||
|
async def ayield_keys(self, prefix: Optional[str] = None) -> AsyncIterator[str]:
|
||||||
|
"""Get an async iterator over keys that match the given prefix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix (str, optional): The prefix to match. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncIterator[str]: An async iterator over keys that match the given prefix.
|
||||||
|
"""
|
||||||
|
if prefix is None:
|
||||||
|
for key in self.store.keys():
|
||||||
|
yield key
|
||||||
|
else:
|
||||||
|
for key in self.store.keys():
|
||||||
|
if key.startswith(prefix):
|
||||||
|
yield key
|
||||||
|
|
||||||
|
|
||||||
InMemoryStore = InMemoryBaseStore[Any]
|
InMemoryStore = InMemoryBaseStore[Any]
|
||||||
InMemoryByteStore = InMemoryBaseStore[bytes]
|
InMemoryByteStore = InMemoryBaseStore[bytes]
|
||||||
|
@ -13,6 +13,18 @@ def test_mget() -> None:
|
|||||||
assert non_existent_value == [None]
|
assert non_existent_value == [None]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_amget() -> None:
|
||||||
|
store = InMemoryStore()
|
||||||
|
await store.amset([("key1", "value1"), ("key2", "value2")])
|
||||||
|
|
||||||
|
values = await store.amget(["key1", "key2"])
|
||||||
|
assert values == ["value1", "value2"]
|
||||||
|
|
||||||
|
# Test non-existent key
|
||||||
|
non_existent_value = await store.amget(["key3"])
|
||||||
|
assert non_existent_value == [None]
|
||||||
|
|
||||||
|
|
||||||
def test_mset() -> None:
|
def test_mset() -> None:
|
||||||
store = InMemoryStore()
|
store = InMemoryStore()
|
||||||
store.mset([("key1", "value1"), ("key2", "value2")])
|
store.mset([("key1", "value1"), ("key2", "value2")])
|
||||||
@ -21,6 +33,14 @@ def test_mset() -> None:
|
|||||||
assert values == ["value1", "value2"]
|
assert values == ["value1", "value2"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_amset() -> None:
|
||||||
|
store = InMemoryStore()
|
||||||
|
await store.amset([("key1", "value1"), ("key2", "value2")])
|
||||||
|
|
||||||
|
values = await store.amget(["key1", "key2"])
|
||||||
|
assert values == ["value1", "value2"]
|
||||||
|
|
||||||
|
|
||||||
def test_mdelete() -> None:
|
def test_mdelete() -> None:
|
||||||
store = InMemoryStore()
|
store = InMemoryStore()
|
||||||
store.mset([("key1", "value1"), ("key2", "value2")])
|
store.mset([("key1", "value1"), ("key2", "value2")])
|
||||||
@ -34,6 +54,19 @@ def test_mdelete() -> None:
|
|||||||
store.mdelete(["key3"]) # No error should be raised
|
store.mdelete(["key3"]) # No error should be raised
|
||||||
|
|
||||||
|
|
||||||
|
async def test_amdelete() -> None:
|
||||||
|
store = InMemoryStore()
|
||||||
|
await store.amset([("key1", "value1"), ("key2", "value2")])
|
||||||
|
|
||||||
|
await store.amdelete(["key1"])
|
||||||
|
|
||||||
|
values = await store.amget(["key1", "key2"])
|
||||||
|
assert values == [None, "value2"]
|
||||||
|
|
||||||
|
# Test deleting non-existent key
|
||||||
|
await store.amdelete(["key3"]) # No error should be raised
|
||||||
|
|
||||||
|
|
||||||
def test_yield_keys() -> None:
|
def test_yield_keys() -> None:
|
||||||
store = InMemoryStore()
|
store = InMemoryStore()
|
||||||
store.mset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")])
|
store.mset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")])
|
||||||
@ -46,3 +79,17 @@ def test_yield_keys() -> None:
|
|||||||
|
|
||||||
keys_with_invalid_prefix = list(store.yield_keys(prefix="x"))
|
keys_with_invalid_prefix = list(store.yield_keys(prefix="x"))
|
||||||
assert keys_with_invalid_prefix == []
|
assert keys_with_invalid_prefix == []
|
||||||
|
|
||||||
|
|
||||||
|
async def test_ayield_keys() -> None:
|
||||||
|
store = InMemoryStore()
|
||||||
|
await store.amset([("key1", "value1"), ("key2", "value2"), ("key3", "value3")])
|
||||||
|
|
||||||
|
keys = [key async for key in store.ayield_keys()]
|
||||||
|
assert set(keys) == {"key1", "key2", "key3"}
|
||||||
|
|
||||||
|
keys_with_prefix = [key async for key in store.ayield_keys(prefix="key")]
|
||||||
|
assert set(keys_with_prefix) == {"key1", "key2", "key3"}
|
||||||
|
|
||||||
|
keys_with_invalid_prefix = [key async for key in store.ayield_keys(prefix="x")]
|
||||||
|
assert keys_with_invalid_prefix == []
|
||||||
|
Loading…
Reference in New Issue
Block a user