fix: chromadb max batch size (#1087)
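Background for the fix: a ChromaDB client rejects a single collection.add() call containing more records than its max_batch_size property allows, so ingesting a large set of embeddings in one call fails. The new store added below splits insertions into chunks of at most max_batch_size nodes before handing them to the collection. As a rough illustration of that chunking idea only (the actual change reuses chunk_list from llama_index), a minimal sketch with made-up names:

# Minimal sketch of the batching idea; `split_into_batches` is an illustrative
# stand-in for llama_index's chunk_list and is not part of the commit.
from collections.abc import Iterator, Sequence
from typing import TypeVar

T = TypeVar("T")


def split_into_batches(items: Sequence[T], batch_size: int) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of `items`, each at most `batch_size` items long."""
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]


# Hypothetical usage: one collection.add() per batch, so no single call
# exceeds the client's limit.
# for batch in split_into_batches(records, chroma_client.max_batch_size):
#     collection.add(ids=[r["id"] for r in batch], ...)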
private_gpt/components/vector_store/batched_chroma.py (new file, 87 lines)
@@ -0,0 +1,87 @@
from typing import Any

from llama_index.schema import BaseNode, MetadataMode
from llama_index.vector_stores import ChromaVectorStore
from llama_index.vector_stores.chroma import chunk_list
from llama_index.vector_stores.utils import node_to_metadata_dict


class BatchedChromaVectorStore(ChromaVectorStore):
    """Chroma vector store, batching additions to avoid reaching the max batch limit.

    In this vector store, embeddings are stored within a ChromaDB collection.

    During query time, the index uses ChromaDB to query for the top
    k most similar nodes.

    Args:
        chroma_client (from chromadb.api.API):
            API instance
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance

    """

    chroma_client: Any | None

    def __init__(
        self,
        chroma_client: Any,
        chroma_collection: Any,
        host: str | None = None,
        port: str | None = None,
        ssl: bool = False,
        headers: dict[str, str] | None = None,
        collection_kwargs: dict[Any, Any] | None = None,
    ) -> None:
        super().__init__(
            chroma_collection=chroma_collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_kwargs=collection_kwargs or {},
        )
        self.chroma_client = chroma_client

    def add(self, nodes: list[BaseNode]) -> list[str]:
        """Add nodes to index, batching the insertion to avoid issues.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        """
        if not self.chroma_client:
            raise ValueError("Client not initialized")

        if not self._collection:
            raise ValueError("Collection not initialized")

        max_chunk_size = self.chroma_client.max_batch_size
        node_chunks = chunk_list(nodes, max_chunk_size)

        all_ids = []
        for node_chunk in node_chunks:
            embeddings = []
            metadatas = []
            ids = []
            documents = []
            for node in node_chunk:
                embeddings.append(node.get_embedding())
                metadatas.append(
                    node_to_metadata_dict(
                        node, remove_text=True, flat_metadata=self.flat_metadata
                    )
                )
                ids.append(node.node_id)
                documents.append(node.get_content(metadata_mode=MetadataMode.NONE))

            self._collection.add(
                embeddings=embeddings,
                ids=ids,
                metadatas=metadatas,
                documents=documents,
            )
            all_ids.extend(ids)

        return all_ids
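The class above can be exercised on its own. A minimal, hypothetical usage sketch follows; the path, collection name, node texts, and embedding values are made up for illustration, and it assumes chromadb's PersistentClient and llama_index's TextNode (the same libraries this commit already imports):

# Hypothetical standalone usage of BatchedChromaVectorStore; every literal
# below (path, collection name, embeddings) is illustrative only.
import chromadb
from llama_index.schema import TextNode

from private_gpt.components.vector_store.batched_chroma import BatchedChromaVectorStore

client = chromadb.PersistentClient(path="/tmp/chroma_demo")
collection = client.get_or_create_collection("demo_collection")
store = BatchedChromaVectorStore(chroma_client=client, chroma_collection=collection)

# Even when len(nodes) exceeds client.max_batch_size, add() now issues several
# smaller collection.add() calls instead of one oversized call.
nodes = [
    TextNode(text=f"document {i}", embedding=[0.1, 0.2, 0.3]) for i in range(20_000)
]
ids = store.add(nodes)

The hunks below patch the existing VectorStoreComponent so that it constructs this batched store instead of the plain ChromaVectorStore.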
@@ -4,9 +4,9 @@ import chromadb
 from injector import inject, singleton
 from llama_index import VectorStoreIndex
 from llama_index.indices.vector_store import VectorIndexRetriever
-from llama_index.vector_stores import ChromaVectorStore
 from llama_index.vector_stores.types import VectorStore
 
+from private_gpt.components.vector_store.batched_chroma import BatchedChromaVectorStore
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.paths import local_data_path
 
@@ -36,14 +36,16 @@ class VectorStoreComponent:
 
     @inject
     def __init__(self) -> None:
-        db = chromadb.PersistentClient(
+        chroma_client = chromadb.PersistentClient(
             path=str((local_data_path / "chroma_db").absolute())
         )
-        chroma_collection = db.get_or_create_collection(
+        chroma_collection = chroma_client.get_or_create_collection(
             "make_this_parameterizable_per_api_call"
         )  # TODO
 
-        self.vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
+        self.vector_store = BatchedChromaVectorStore(
+            chroma_client=chroma_client, chroma_collection=chroma_collection
+        )
 
     @staticmethod
     def get_retriever(
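With this wiring change, any caller that resolves VectorStoreComponent through the injector gets the batched store transparently. A minimal, hypothetical sketch of that lookup; the component's module path is not visible in this excerpt and is assumed here:

# Hypothetical sketch: resolving the modified component via the `injector`
# package this file already uses. The import path below is an assumption,
# since the second file's name is not shown in this diff excerpt.
from injector import Injector

from private_gpt.components.vector_store.vector_store_component import (
    VectorStoreComponent,
)

injector = Injector()
component = injector.get(VectorStoreComponent)  # @singleton: built once, then reused
print(type(component.vector_store).__name__)  # expected: "BatchedChromaVectorStore"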