privateGPT/private_gpt/components/vector_store/batched_chroma.py
Iván Martínez 45f05711eb
feat: Upgrade LlamaIndex to 0.10 (#1663)
* Extract optional dependencies

* Separate local mode into llms-llama-cpp and embeddings-huggingface for clarity

* Support Ollama embeddings

* Upgrade to llamaindex 0.10.14. Remove legacy use of ServiceContext in ContextChatEngine

* Fix vector retriever filters
2024-03-06 17:51:30 +01:00

104 lines
3.2 KiB
Python

from collections.abc import Generator
from typing import Any
from llama_index.core.schema import BaseNode, MetadataMode
from llama_index.core.vector_stores.utils import node_to_metadata_dict
from llama_index.vector_stores.chroma import ChromaVectorStore # type: ignore
def chunk_list(
    lst: list[BaseNode], max_chunk_size: int
) -> Generator[list[BaseNode], None, None]:
    """Yield successive max_chunk_size-sized chunks from lst.

    Every chunk except possibly the last holds exactly ``max_chunk_size``
    items; the last one holds the remainder. An empty ``lst`` yields nothing.

    Args:
        lst (List[BaseNode]): list of nodes with embeddings
        max_chunk_size (int): max chunk size; must be at least 1

    Yields:
        Generator[List[BaseNode], None, None]: list of nodes with embeddings

    Raises:
        ValueError: if ``max_chunk_size`` is smaller than 1. As this is a
            generator, the error surfaces when iteration starts.
    """
    # A negative step would make range() empty and silently drop every node,
    # and a zero step fails with an unrelated range() message, so reject both
    # explicitly.
    if max_chunk_size < 1:
        raise ValueError(
            f"max_chunk_size must be a positive integer, got {max_chunk_size}"
        )
    for start in range(0, len(lst), max_chunk_size):
        yield lst[start : start + max_chunk_size]
class BatchedChromaVectorStore(ChromaVectorStore):  # type: ignore
    """Chroma vector store that inserts nodes in batches.

    Additions are split into groups no larger than the client's
    ``max_batch_size`` so that a single insertion never reaches the max
    batch limit. Embeddings are stored within a ChromaDB collection, and at
    query time the index uses ChromaDB to query for the top k most similar
    nodes.

    Args:
        chroma_client (from chromadb.api.API):
            API instance
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance
    """

    chroma_client: Any | None

    def __init__(
        self,
        chroma_client: Any,
        chroma_collection: Any,
        host: str | None = None,
        port: str | None = None,
        ssl: bool = False,
        headers: dict[str, str] | None = None,
        collection_kwargs: dict[Any, Any] | None = None,
    ) -> None:
        """Build the store and remember the client used to size batches.

        Everything except ``chroma_client`` is forwarded to the parent
        constructor; a missing ``collection_kwargs`` is passed on as an
        empty dict.
        """
        parent_collection_kwargs = collection_kwargs or {}
        super().__init__(
            chroma_collection=chroma_collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_kwargs=parent_collection_kwargs,
        )
        self.chroma_client = chroma_client

    def add(self, nodes: list[BaseNode], **add_kwargs: Any) -> list[str]:
        """Insert nodes into the collection, one bounded batch at a time.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings
            add_kwargs: accepted for interface compatibility; not used here

        Returns:
            The ids of all inserted nodes, in insertion order.

        Raises:
            ValueError: if the client or the collection is not set.
        """
        if not self.chroma_client:
            raise ValueError("Client not initialized")
        if not self._collection:
            raise ValueError("Collection not initialized")

        batch_limit = self.chroma_client.max_batch_size
        inserted_ids = []
        for batch in chunk_list(nodes, batch_limit):
            # One record per node: (embedding, metadata, id, document text).
            # Building them in a single pass keeps the per-node call order.
            records = [
                (
                    node.get_embedding(),
                    node_to_metadata_dict(
                        node, remove_text=True, flat_metadata=self.flat_metadata
                    ),
                    node.node_id,
                    node.get_content(metadata_mode=MetadataMode.NONE),
                )
                for node in batch
            ]
            batch_ids = [record[2] for record in records]
            self._collection.add(
                embeddings=[record[0] for record in records],
                ids=batch_ids,
                metadatas=[record[1] for record in records],
                documents=[record[3] for record in records],
            )
            inserted_ids.extend(batch_ids)
        return inserted_ids