from collections.abc import Generator
from typing import Any

from llama_index.core.schema import BaseNode, MetadataMode
from llama_index.core.vector_stores.utils import node_to_metadata_dict
from llama_index.vector_stores.chroma import ChromaVectorStore  # type: ignore


def chunk_list(
    lst: list[BaseNode], max_chunk_size: int
) -> Generator[list[BaseNode], None, None]:
    """Yield successive max_chunk_size-sized chunks from lst.

    Args:
        lst (list[BaseNode]): list of nodes with embeddings
        max_chunk_size (int): max chunk size

    Yields:
        Generator[list[BaseNode], None, None]: list of nodes with embeddings
    """
    for i in range(0, len(lst), max_chunk_size):
        yield lst[i : i + max_chunk_size]
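
# Illustrative note (not part of the original module): chunk_list is a
# generator, so chunks are produced lazily. With max_chunk_size=2, a list of
# five nodes yields chunks of sizes 2, 2 and 1, i.e.
#     list(chunk_list(nodes, 2)) -> [[n0, n1], [n2, n3], [n4]]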


class BatchedChromaVectorStore(ChromaVectorStore):  # type: ignore
    """Chroma vector store, batching additions to avoid reaching the max batch limit.

    In this vector store, embeddings are stored within a ChromaDB collection.

    During query time, the index uses ChromaDB to query for the top
    k most similar nodes.

    Args:
        chroma_client (from chromadb.api.API):
            API instance
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance

    """

    chroma_client: Any | None

    def __init__(
        self,
        chroma_client: Any,
        chroma_collection: Any,
        host: str | None = None,
        port: str | None = None,
        ssl: bool = False,
        headers: dict[str, str] | None = None,
        collection_kwargs: dict[Any, Any] | None = None,
    ) -> None:
        super().__init__(
            chroma_collection=chroma_collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_kwargs=collection_kwargs or {},
        )
        self.chroma_client = chroma_client

    def add(self, nodes: list[BaseNode], **add_kwargs: Any) -> list[str]:
        """Add nodes to the index, batching the insertion so that no single
        call exceeds Chroma's max batch size.

        Args:
            nodes (list[BaseNode]): list of nodes with embeddings
            add_kwargs: additional keyword arguments (unused)
        """
        if not self.chroma_client:
            raise ValueError("Client not initialized")

        if not self._collection:
            raise ValueError("Collection not initialized")

        # Chroma rejects a single insert larger than the client's
        # max_batch_size, so split the nodes into chunks of at most that size.
        max_chunk_size = self.chroma_client.max_batch_size
        node_chunks = chunk_list(nodes, max_chunk_size)

        all_ids = []
        for node_chunk in node_chunks:
            # Build the parallel arrays that Chroma's collection.add expects.
            embeddings = []
            metadatas = []
            ids = []
            documents = []
            for node in node_chunk:
                embeddings.append(node.get_embedding())
                metadatas.append(
                    node_to_metadata_dict(
                        node, remove_text=True, flat_metadata=self.flat_metadata
                    )
                )
                ids.append(node.node_id)
                documents.append(node.get_content(metadata_mode=MetadataMode.NONE))

            self._collection.add(
                embeddings=embeddings,
                ids=ids,
                metadatas=metadatas,
                documents=documents,
            )
            all_ids.extend(ids)

        return all_ids
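

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). It assumes chromadb is
# installed and that chromadb.EphemeralClient and get_or_create_collection
# behave as in recent chromadb releases; the "demo" collection name and the
# toy 3-dimensional embeddings are made up for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import chromadb  # assumed dependency, not imported by the module above
    from llama_index.core.schema import TextNode

    client = chromadb.EphemeralClient()
    collection = client.get_or_create_collection("demo")
    store = BatchedChromaVectorStore(
        chroma_client=client, chroma_collection=collection
    )

    # Even if len(nodes) exceeded client.max_batch_size, add() would split
    # the insertion into chunks that Chroma accepts.
    nodes = [
        TextNode(text=f"node {i}", embedding=[0.1, 0.2, 0.3]) for i in range(5)
    ]
    print(store.add(nodes))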