fix(storage): Fix load too many chunks error (#2614)

# Description

Closes #2607

# How Has This Been Tested?

Please describe the tests that you ran to verify your changes. Provide
instructions so we can reproduce. Please also list any relevant details
for your test configuration.

# Snapshots:

Include snapshots for easier review.

# Checklist:

- [x] My code follows the style guidelines of this project
- [x] I have already rebased the commits and made the commit message
conform to the project standard.
- [x] I have performed a self-review of my own code
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have made corresponding changes to the documentation
- [x] Any dependent changes have been merged and published in downstream
modules
This commit is contained in:
magic.chen 2025-04-12 20:48:59 +08:00 committed by GitHub
commit 0245cfdff8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 88 additions and 17 deletions

View File

@ -149,7 +149,7 @@ class TimeWeightedEmbeddingRetriever(EmbeddingRetriever):
self._save_memory_stream()
# Add to vector store
return self._index_store.load_document(dup_docs)
return self._index_store.load_document_with_limit(dup_docs)
def _retrieve(
self, query: str, filters: Optional[MetadataFilters] = None

View File

@ -27,9 +27,16 @@ class IndexStoreConfig(BaseParameters):
class IndexStoreBase(ABC):
"""Index store base class."""
def __init__(self, executor: Optional[Executor] = None):
def __init__(
self,
executor: Optional[Executor] = None,
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
):
"""Init index store."""
self._executor = executor or ThreadPoolExecutor()
self._max_chunks_once_load = max_chunks_once_load or 10
self._max_threads = max_threads or 1
@abstractmethod
def get_config(self) -> IndexStoreConfig:
@ -102,7 +109,10 @@ class IndexStoreBase(ABC):
return True
def load_document_with_limit(
self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
self,
chunks: List[Chunk],
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
) -> List[str]:
"""Load document in index database with specified limit.
@ -114,6 +124,8 @@ class IndexStoreBase(ABC):
Return:
List[str]: Chunk ids.
"""
max_chunks_once_load = max_chunks_once_load or self._max_chunks_once_load
max_threads = max_threads or self._max_threads
# Group the chunks into chunks of size max_chunks
chunk_groups = [
chunks[i : i + max_chunks_once_load]
@ -141,7 +153,10 @@ class IndexStoreBase(ABC):
return ids
async def aload_document_with_limit(
self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
self,
chunks: List[Chunk],
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
) -> List[str]:
"""Load document in index database with specified limit.
@ -153,6 +168,8 @@ class IndexStoreBase(ABC):
Return:
List[str]: Chunk ids.
"""
max_chunks_once_load = max_chunks_once_load or self._max_chunks_once_load
max_threads = max_threads or self._max_threads
chunk_groups = [
chunks[i : i + max_chunks_once_load]
for i in range(0, len(chunks), max_chunks_once_load)

View File

@ -88,6 +88,24 @@ class VectorStoreConfig(IndexStoreConfig, RegisterParameters):
),
},
)
max_chunks_once_load: Optional[int] = field(
default=None,
metadata={
"help": _(
"The max chunks once load in vector store, "
"if not set, will use the default value 10."
),
},
)
max_threads: Optional[int] = field(
default=None,
metadata={
"help": _(
"The max threads in vector store, "
"if not set, will use the default value 1."
),
},
)
def create_store(self, **kwargs) -> "VectorStoreBase":
"""Create a new index store from the config."""
@ -97,9 +115,16 @@ class VectorStoreConfig(IndexStoreConfig, RegisterParameters):
class VectorStoreBase(IndexStoreBase, ABC):
"""Vector store base class."""
def __init__(self, executor: Optional[ThreadPoolExecutor] = None):
def __init__(
self,
executor: Optional[ThreadPoolExecutor] = None,
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
):
"""Initialize vector store."""
super().__init__(executor)
super().__init__(
executor, max_chunks_once_load=max_chunks_once_load, max_threads=max_threads
)
@abstractmethod
def get_config(self) -> VectorStoreConfig:

View File

@ -131,8 +131,8 @@ class EmbeddingAssembler(BaseAssembler):
Returns:
List[str]: List of chunk ids.
"""
max_chunks_once_load = kwargs.get("max_chunks_once_load", 10)
max_threads = kwargs.get("max_threads", 1)
max_chunks_once_load = kwargs.get("max_chunks_once_load")
max_threads = kwargs.get("max_threads")
return self._index_store.load_document_with_limit(
self._chunks, max_chunks_once_load, max_threads
)
@ -144,8 +144,8 @@ class EmbeddingAssembler(BaseAssembler):
List[str]: List of chunk ids.
"""
# persist chunks into vector store
max_chunks_once_load = kwargs.get("max_chunks_once_load", 10)
max_threads = kwargs.get("max_threads", 1)
max_chunks_once_load = kwargs.get("max_chunks_once_load")
max_threads = kwargs.get("max_threads")
return await self._index_store.aload_document_with_limit(
self._chunks, max_chunks_once_load, max_threads
)

View File

@ -91,6 +91,8 @@ class ChromaStore(VectorStoreBase):
embedding_fn: Optional[Embeddings] = None,
chroma_client: Optional["PersistentClient"] = None, # type: ignore # noqa
collection_metadata: Optional[dict] = None,
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
) -> None:
"""Create a ChromaStore instance.
@ -100,8 +102,12 @@ class ChromaStore(VectorStoreBase):
embedding_fn(Embeddings): embedding function.
chroma_client(PersistentClient): chroma client.
collection_metadata(dict): collection metadata.
max_chunks_once_load(int): max chunks once load.
max_threads(int): max threads.
"""
super().__init__()
super().__init__(
max_chunks_once_load=max_chunks_once_load, max_threads=max_threads
)
self._vector_store_config = vector_store_config
try:
from chromadb import PersistentClient, Settings

View File

@ -157,13 +157,17 @@ class ElasticStore(VectorStoreBase):
vector_store_config: ElasticsearchStoreConfig,
name: Optional[str],
embedding_fn: Optional[Embeddings] = None,
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
) -> None:
"""Create a ElasticsearchStore instance.
Args:
vector_store_config (ElasticsearchStoreConfig): ElasticsearchStore config.
"""
super().__init__()
super().__init__(
max_chunks_once_load=max_chunks_once_load, max_threads=max_threads
)
self._vector_store_config = vector_store_config
connect_kwargs = {}

View File

@ -197,6 +197,8 @@ class MilvusStore(VectorStoreBase):
vector_store_config: MilvusVectorConfig,
name: Optional[str],
embedding_fn: Optional[Embeddings] = None,
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
) -> None:
"""Create a MilvusStore instance.
@ -204,7 +206,9 @@ class MilvusStore(VectorStoreBase):
vector_store_config (MilvusVectorConfig): MilvusStore config.
refer to https://milvus.io/docs/v2.0.x/manage_connection.md
"""
super().__init__()
super().__init__(
max_chunks_once_load=max_chunks_once_load, max_threads=max_threads
)
self._vector_store_config = vector_store_config
try:

View File

@ -192,6 +192,8 @@ class OceanBaseStore(VectorStoreBase):
vector_store_config: OceanBaseConfig,
name: Optional[str],
embedding_fn: Optional[Embeddings] = None,
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
) -> None:
"""Create a OceanBaseStore instance."""
try:
@ -205,7 +207,9 @@ class OceanBaseStore(VectorStoreBase):
if vector_store_config.embedding_fn is None:
raise ValueError("embedding_fn is required for OceanBaseStore")
super().__init__()
super().__init__(
max_chunks_once_load=max_chunks_once_load, max_threads=max_threads
)
self._vector_store_config = vector_store_config
self.embedding_function = embedding_fn

View File

@ -85,6 +85,8 @@ class PGVectorStore(VectorStoreBase):
vector_store_config: PGVectorConfig,
name: Optional[str],
embedding_fn: Optional[Embeddings] = None,
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
) -> None:
"""Create a PGVectorStore instance."""
try:
@ -93,7 +95,9 @@ class PGVectorStore(VectorStoreBase):
raise ImportError(
"Please install the `langchain` package to use the PGVector."
)
super().__init__()
super().__init__(
max_chunks_once_load=max_chunks_once_load, max_threads=max_threads
)
self._vector_store_config = vector_store_config
self.connection_string = vector_store_config.connection_string

View File

@ -96,6 +96,8 @@ class WeaviateStore(VectorStoreBase):
vector_store_config: WeaviateVectorConfig,
name: Optional[str],
embedding_fn: Optional[Embeddings] = None,
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
) -> None:
"""Initialize with Weaviate client."""
try:
@ -105,7 +107,9 @@ class WeaviateStore(VectorStoreBase):
"Could not import weaviate python package. "
"Please install it with `pip install weaviate-client`."
)
super().__init__()
super().__init__(
max_chunks_once_load=max_chunks_once_load, max_threads=max_threads
)
self._vector_store_config = vector_store_config
self.weaviate_url = vector_store_config.weaviate_url

View File

@ -68,7 +68,10 @@ class StorageManager(BaseComponent):
embedding_fn = embedding_factory.create()
vector_store_config: VectorStoreConfig = storage_config.vector
return vector_store_config.create_store(
name=index_name, embedding_fn=embedding_fn
name=index_name,
embedding_fn=embedding_fn,
max_chunks_once_load=vector_store_config.max_chunks_once_load,
max_threads=vector_store_config.max_threads,
)
def create_kg_store(