fix(storage): Fix load big documents error

This commit is contained in:
Fangyin Cheng 2025-04-11 20:46:48 +08:00
parent 12170e2504
commit c04e3c7cb0
11 changed files with 88 additions and 17 deletions

View File

@ -149,7 +149,7 @@ class TimeWeightedEmbeddingRetriever(EmbeddingRetriever):
self._save_memory_stream()
# Add to vector store
return self._index_store.load_document(dup_docs)
return self._index_store.load_document_with_limit(dup_docs)
def _retrieve(
self, query: str, filters: Optional[MetadataFilters] = None

View File

@ -27,9 +27,16 @@ class IndexStoreConfig(BaseParameters):
class IndexStoreBase(ABC):
"""Index store base class."""
def __init__(self, executor: Optional[Executor] = None):
def __init__(
self,
executor: Optional[Executor] = None,
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
):
"""Init index store."""
self._executor = executor or ThreadPoolExecutor()
self._max_chunks_once_load = max_chunks_once_load or 10
self._max_threads = max_threads or 1
@abstractmethod
def get_config(self) -> IndexStoreConfig:
@ -102,7 +109,10 @@ class IndexStoreBase(ABC):
return True
def load_document_with_limit(
self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
self,
chunks: List[Chunk],
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
) -> List[str]:
"""Load document in index database with specified limit.
@ -114,6 +124,8 @@ class IndexStoreBase(ABC):
Return:
List[str]: Chunk ids.
"""
max_chunks_once_load = max_chunks_once_load or self._max_chunks_once_load
max_threads = max_threads or self._max_threads
# Group the chunks into chunks of size max_chunks
chunk_groups = [
chunks[i : i + max_chunks_once_load]
@ -141,7 +153,10 @@ class IndexStoreBase(ABC):
return ids
async def aload_document_with_limit(
self, chunks: List[Chunk], max_chunks_once_load: int = 10, max_threads: int = 1
self,
chunks: List[Chunk],
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
) -> List[str]:
"""Load document in index database with specified limit.
@ -153,6 +168,8 @@ class IndexStoreBase(ABC):
Return:
List[str]: Chunk ids.
"""
max_chunks_once_load = max_chunks_once_load or self._max_chunks_once_load
max_threads = max_threads or self._max_threads
chunk_groups = [
chunks[i : i + max_chunks_once_load]
for i in range(0, len(chunks), max_chunks_once_load)

View File

@ -88,6 +88,24 @@ class VectorStoreConfig(IndexStoreConfig, RegisterParameters):
),
},
)
max_chunks_once_load: Optional[int] = field(
default=None,
metadata={
"help": _(
"The max chunks once load in vector store, "
"if not set, will use the default value 10."
),
},
)
max_threads: Optional[int] = field(
default=None,
metadata={
"help": _(
"The max threads in vector store, "
"if not set, will use the default value 1."
),
},
)
def create_store(self, **kwargs) -> "VectorStoreBase":
"""Create a new index store from the config."""
@ -97,9 +115,16 @@ class VectorStoreConfig(IndexStoreConfig, RegisterParameters):
class VectorStoreBase(IndexStoreBase, ABC):
"""Vector store base class."""
def __init__(self, executor: Optional[ThreadPoolExecutor] = None):
def __init__(
self,
executor: Optional[ThreadPoolExecutor] = None,
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
):
"""Initialize vector store."""
super().__init__(executor)
super().__init__(
executor, max_chunks_once_load=max_chunks_once_load, max_threads=max_threads
)
@abstractmethod
def get_config(self) -> VectorStoreConfig:

View File

@ -131,8 +131,8 @@ class EmbeddingAssembler(BaseAssembler):
Returns:
List[str]: List of chunk ids.
"""
max_chunks_once_load = kwargs.get("max_chunks_once_load", 10)
max_threads = kwargs.get("max_threads", 1)
max_chunks_once_load = kwargs.get("max_chunks_once_load")
max_threads = kwargs.get("max_threads")
return self._index_store.load_document_with_limit(
self._chunks, max_chunks_once_load, max_threads
)
@ -144,8 +144,8 @@ class EmbeddingAssembler(BaseAssembler):
List[str]: List of chunk ids.
"""
# persist chunks into vector store
max_chunks_once_load = kwargs.get("max_chunks_once_load", 10)
max_threads = kwargs.get("max_threads", 1)
max_chunks_once_load = kwargs.get("max_chunks_once_load")
max_threads = kwargs.get("max_threads")
return await self._index_store.aload_document_with_limit(
self._chunks, max_chunks_once_load, max_threads
)

View File

@ -91,6 +91,8 @@ class ChromaStore(VectorStoreBase):
embedding_fn: Optional[Embeddings] = None,
chroma_client: Optional["PersistentClient"] = None, # type: ignore # noqa
collection_metadata: Optional[dict] = None,
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
) -> None:
"""Create a ChromaStore instance.
@ -100,8 +102,12 @@ class ChromaStore(VectorStoreBase):
embedding_fn(Embeddings): embedding function.
chroma_client(PersistentClient): chroma client.
collection_metadata(dict): collection metadata.
max_chunks_once_load(int): max chunks once load.
max_threads(int): max threads.
"""
super().__init__()
super().__init__(
max_chunks_once_load=max_chunks_once_load, max_threads=max_threads
)
self._vector_store_config = vector_store_config
try:
from chromadb import PersistentClient, Settings

View File

@ -157,13 +157,17 @@ class ElasticStore(VectorStoreBase):
vector_store_config: ElasticsearchStoreConfig,
name: Optional[str],
embedding_fn: Optional[Embeddings] = None,
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
) -> None:
"""Create a ElasticsearchStore instance.
Args:
vector_store_config (ElasticsearchStoreConfig): ElasticsearchStore config.
"""
super().__init__()
super().__init__(
max_chunks_once_load=max_chunks_once_load, max_threads=max_threads
)
self._vector_store_config = vector_store_config
connect_kwargs = {}

View File

@ -197,6 +197,8 @@ class MilvusStore(VectorStoreBase):
vector_store_config: MilvusVectorConfig,
name: Optional[str],
embedding_fn: Optional[Embeddings] = None,
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
) -> None:
"""Create a MilvusStore instance.
@ -204,7 +206,9 @@ class MilvusStore(VectorStoreBase):
vector_store_config (MilvusVectorConfig): MilvusStore config.
refer to https://milvus.io/docs/v2.0.x/manage_connection.md
"""
super().__init__()
super().__init__(
max_chunks_once_load=max_chunks_once_load, max_threads=max_threads
)
self._vector_store_config = vector_store_config
try:

View File

@ -192,6 +192,8 @@ class OceanBaseStore(VectorStoreBase):
vector_store_config: OceanBaseConfig,
name: Optional[str],
embedding_fn: Optional[Embeddings] = None,
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
) -> None:
"""Create a OceanBaseStore instance."""
try:
@ -205,7 +207,9 @@ class OceanBaseStore(VectorStoreBase):
if vector_store_config.embedding_fn is None:
raise ValueError("embedding_fn is required for OceanBaseStore")
super().__init__()
super().__init__(
max_chunks_once_load=max_chunks_once_load, max_threads=max_threads
)
self._vector_store_config = vector_store_config
self.embedding_function = embedding_fn

View File

@ -85,6 +85,8 @@ class PGVectorStore(VectorStoreBase):
vector_store_config: PGVectorConfig,
name: Optional[str],
embedding_fn: Optional[Embeddings] = None,
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
) -> None:
"""Create a PGVectorStore instance."""
try:
@ -93,7 +95,9 @@ class PGVectorStore(VectorStoreBase):
raise ImportError(
"Please install the `langchain` package to use the PGVector."
)
super().__init__()
super().__init__(
max_chunks_once_load=max_chunks_once_load, max_threads=max_threads
)
self._vector_store_config = vector_store_config
self.connection_string = vector_store_config.connection_string

View File

@ -96,6 +96,8 @@ class WeaviateStore(VectorStoreBase):
vector_store_config: WeaviateVectorConfig,
name: Optional[str],
embedding_fn: Optional[Embeddings] = None,
max_chunks_once_load: Optional[int] = None,
max_threads: Optional[int] = None,
) -> None:
"""Initialize with Weaviate client."""
try:
@ -105,7 +107,9 @@ class WeaviateStore(VectorStoreBase):
"Could not import weaviate python package. "
"Please install it with `pip install weaviate-client`."
)
super().__init__()
super().__init__(
max_chunks_once_load=max_chunks_once_load, max_threads=max_threads
)
self._vector_store_config = vector_store_config
self.weaviate_url = vector_store_config.weaviate_url

View File

@ -68,7 +68,10 @@ class StorageManager(BaseComponent):
embedding_fn = embedding_factory.create()
vector_store_config: VectorStoreConfig = storage_config.vector
return vector_store_config.create_store(
name=index_name, embedding_fn=embedding_fn
name=index_name,
embedding_fn=embedding_fn,
max_chunks_once_load=vector_store_config.max_chunks_once_load,
max_threads=vector_store_config.max_threads,
)
def create_kg_store(