fix: chromadb max batch size (#1087)

This commit is contained in:
Iván Martínez 2023-10-20 18:24:56 +02:00 committed by GitHub
parent b46c1087e2
commit f5a9bf4e37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 142 additions and 68 deletions

82
poetry.lock generated
View File

@ -261,24 +261,6 @@ files = [
tests = ["pytest (>=3.2.1,!=3.3.0)"]
typecheck = ["mypy"]
[[package]]
name = "beautifulsoup4"
version = "4.12.2"
description = "Screen-scraping library"
optional = false
python-versions = ">=3.6.0"
files = [
{file = "beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a"},
{file = "beautifulsoup4-4.12.2.tar.gz", hash = "sha256:492bbc69dca35d12daac71c4db1bfff0c876c00ef4a2ffacce226d4638eb72da"},
]
[package.dependencies]
soupsieve = ">1.2"
[package.extras]
html5lib = ["html5lib"]
lxml = ["lxml"]
[[package]]
name = "black"
version = "22.12.0"
@ -643,10 +625,7 @@ files = [
]
[package.dependencies]
numpy = [
{version = ">=1.16,<2.0", markers = "python_version <= \"3.11\""},
{version = ">=1.26.0rc1,<2.0", markers = "python_version >= \"3.12\""},
]
numpy = {version = ">=1.16,<2.0", markers = "python_version <= \"3.11\""}
[package.extras]
bokeh = ["bokeh", "selenium"]
@ -736,13 +715,13 @@ tests = ["pytest", "pytest-cov", "pytest-xdist"]
[[package]]
name = "dataclasses-json"
version = "0.6.1"
version = "0.5.14"
description = "Easily serialize dataclasses to and from JSON."
optional = false
python-versions = ">=3.7,<4.0"
python-versions = ">=3.7,<3.13"
files = [
{file = "dataclasses_json-0.6.1-py3-none-any.whl", hash = "sha256:1bd8418a61fe3d588bb0079214d7fb71d44937da40742b787256fd53b26b6c80"},
{file = "dataclasses_json-0.6.1.tar.gz", hash = "sha256:a53c220c35134ce08211a1057fd0e5bf76dc5331627c6b241cacbc570a89faae"},
{file = "dataclasses_json-0.5.14-py3-none-any.whl", hash = "sha256:5ec6fed642adb1dbdb4182badb01e0861badfd8fda82e3b67f44b2d1e9d10d21"},
{file = "dataclasses_json-0.5.14.tar.gz", hash = "sha256:d82896a94c992ffaf689cd1fafc180164e2abdd415b8f94a7f78586af5886236"},
]
[package.dependencies]
@ -1130,7 +1109,7 @@ files = [
{file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0b72b802496cccbd9b31acea72b6f87e7771ccfd7f7927437d592e5c92ed703c"},
{file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:527cd90ba3d8d7ae7dceb06fda619895768a46a1b4e423bdb24c1969823b8362"},
{file = "greenlet-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:37f60b3a42d8b5499be910d1267b24355c495064f271cfe74bf28b17b099133c"},
{file = "greenlet-3.0.0-cp311-universal2-macosx_10_9_universal2.whl", hash = "sha256:c3692ecf3fe754c8c0f2c95ff19626584459eab110eaab66413b1e7425cd84e9"},
{file = "greenlet-3.0.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1482fba7fbed96ea7842b5a7fc11d61727e8be75a077e603e8ab49d24e234383"},
{file = "greenlet-3.0.0-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:be557119bf467d37a8099d91fbf11b2de5eb1fd5fc5b91598407574848dc910f"},
{file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73b2f1922a39d5d59cc0e597987300df3396b148a9bd10b76a058a2f2772fc04"},
{file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1e22c22f7826096ad503e9bb681b05b8c1f5a8138469b255eb91f26a76634f2"},
@ -1140,7 +1119,6 @@ files = [
{file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:952256c2bc5b4ee8df8dfc54fc4de330970bf5d79253c863fb5e6761f00dda35"},
{file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:269d06fa0f9624455ce08ae0179430eea61085e3cf6457f05982b37fd2cefe17"},
{file = "greenlet-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9adbd8ecf097e34ada8efde9b6fec4dd2a903b1e98037adf72d12993a1c80b51"},
{file = "greenlet-3.0.0-cp312-universal2-macosx_10_9_universal2.whl", hash = "sha256:553d6fb2324e7f4f0899e5ad2c427a4579ed4873f42124beba763f16032959af"},
{file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6b5ce7f40f0e2f8b88c28e6691ca6806814157ff05e794cdd161be928550f4c"},
{file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecf94aa539e97a8411b5ea52fc6ccd8371be9550c4041011a091eb8b3ca1d810"},
{file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80dcd3c938cbcac986c5c92779db8e8ce51a89a849c135172c88ecbdc8c056b7"},
@ -1760,28 +1738,27 @@ test = ["httpx (>=0.24.1)", "pytest (>=7.4.0)"]
[[package]]
name = "llama-index"
version = "0.8.35"
version = "0.8.47"
description = "Interface between LLMs and your data"
optional = false
python-versions = "*"
python-versions = ">=3.8.1,<3.12"
files = [
{file = "llama_index-0.8.35-py3-none-any.whl", hash = "sha256:f2f1670320e75a9643b6dc96662038f777866ed543994d18f71ab54329e295ae"},
{file = "llama_index-0.8.35.tar.gz", hash = "sha256:a8767be9d36ebd538a37e18b0c7f46bb19d9d7ec490ef7582640f75d5d9b5259"},
{file = "llama_index-0.8.47-py3-none-any.whl", hash = "sha256:7a0e5154637524fb59b30bd3a349fba2ec6092cf2972276da9dfa38bbe82d721"},
{file = "llama_index-0.8.47.tar.gz", hash = "sha256:f824e7bcf9b6cf3fb98de59d722695a8db327c83b6b7d30071d931b56c14904f"},
]
[package.dependencies]
beautifulsoup4 = "*"
dataclasses-json = "*"
dataclasses-json = ">=0.5.7,<0.6.0"
fsspec = ">=2023.5.0"
langchain = ">=0.0.293"
nest-asyncio = "*"
nltk = "*"
langchain = ">=0.0.303"
nest-asyncio = ">=1.5.8,<2.0.0"
nltk = ">=3.8.1,<4.0.0"
numpy = "*"
openai = ">=0.26.4"
pandas = "*"
sqlalchemy = ">=2.0.15"
SQLAlchemy = {version = ">=1.4.49", extras = ["asyncio"]}
tenacity = ">=8.2.0,<9.0.0"
tiktoken = "*"
tiktoken = ">=0.3.3"
typing-extensions = ">=4.5.0"
typing-inspect = ">=0.8.0"
urllib3 = "<2"
@ -2397,10 +2374,7 @@ files = [
]
[package.dependencies]
numpy = [
{version = ">=1.23.2", markers = "python_version == \"3.11\""},
{version = ">=1.26.0", markers = "python_version >= \"3.12\""},
]
numpy = {version = ">=1.23.2", markers = "python_version == \"3.11\""}
python-dateutil = ">=2.8.2"
pytz = ">=2020.1"
tzdata = ">=2022.1"
@ -3469,11 +3443,6 @@ files = [
{file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f66eddfda9d45dd6cadcd706b65669ce1df84b8549875691b1f403730bdef217"},
{file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6448c37741145b241eeac617028ba6ec2119e1339b1385c9720dae31367f2be"},
{file = "scikit_learn-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:c413c2c850241998168bbb3bd1bb59ff03b1195a53864f0b80ab092071af6028"},
{file = "scikit_learn-1.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ef540e09873e31569bc8b02c8a9f745ee04d8e1263255a15c9969f6f5caa627f"},
{file = "scikit_learn-1.3.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:9147a3a4df4d401e618713880be023e36109c85d8569b3bf5377e6cd3fecdeac"},
{file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2cd3634695ad192bf71645702b3df498bd1e246fc2d529effdb45a06ab028b4"},
{file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c275a06c5190c5ce00af0acbb61c06374087949f643ef32d355ece12c4db043"},
{file = "scikit_learn-1.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:0e1aa8f206d0de814b81b41d60c1ce31f7f2c7354597af38fae46d9c47c45122"},
{file = "scikit_learn-1.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:52b77cc08bd555969ec5150788ed50276f5ef83abb72e6f469c5b91a0009bbca"},
{file = "scikit_learn-1.3.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a683394bc3f80b7c312c27f9b14ebea7766b1f0a34faf1a2e9158d80e860ec26"},
{file = "scikit_learn-1.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a15d964d9eb181c79c190d3dbc2fff7338786bf017e9039571418a1d53dab236"},
@ -3690,17 +3659,6 @@ files = [
{file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"},
]
[[package]]
name = "soupsieve"
version = "2.5"
description = "A modern CSS selector implementation for Beautiful Soup."
optional = false
python-versions = ">=3.8"
files = [
{file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"},
{file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"},
]
[[package]]
name = "sqlalchemy"
version = "2.0.22"
@ -3760,7 +3718,7 @@ files = [
]
[package.dependencies]
greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""}
greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""}
typing-extensions = ">=4.2.0"
[package.extras]
@ -4738,5 +4696,5 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = ">=3.11,<3.13"
content-hash = "c1fa5accdcd9cd81430839398e16d2596b43b1c314b5d2a8a76aa05bbb83a39c"
python-versions = ">=3.11,<3.12"
content-hash = "56b78ce6a8a6dfbe42b490bcf4ffbf2820f10d5ce70a28c61ee4e357172dab33"

View File

@ -0,0 +1,87 @@
from typing import Any
from llama_index.schema import BaseNode, MetadataMode
from llama_index.vector_stores import ChromaVectorStore
from llama_index.vector_stores.chroma import chunk_list
from llama_index.vector_stores.utils import node_to_metadata_dict
class BatchedChromaVectorStore(ChromaVectorStore):
"""Chroma vector store, batching additions to avoid reaching the max batch limit.
In this vector store, embeddings are stored within a ChromaDB collection.
During query time, the index uses ChromaDB to query for the top
k most similar nodes.
Args:
chroma_client (from chromadb.api.API):
API instance
chroma_collection (chromadb.api.models.Collection.Collection):
ChromaDB collection instance
"""
chroma_client: Any | None
def __init__(
self,
chroma_client: Any,
chroma_collection: Any,
host: str | None = None,
port: str | None = None,
ssl: bool = False,
headers: dict[str, str] | None = None,
collection_kwargs: dict[Any, Any] | None = None,
) -> None:
super().__init__(
chroma_collection=chroma_collection,
host=host,
port=port,
ssl=ssl,
headers=headers,
collection_kwargs=collection_kwargs or {},
)
self.chroma_client = chroma_client
def add(self, nodes: list[BaseNode]) -> list[str]:
"""Add nodes to index, batching the insertion to avoid issues.
Args:
nodes: List[BaseNode]: list of nodes with embeddings
"""
if not self.chroma_client:
raise ValueError("Client not initialized")
if not self._collection:
raise ValueError("Collection not initialized")
max_chunk_size = self.chroma_client.max_batch_size
node_chunks = chunk_list(nodes, max_chunk_size)
all_ids = []
for node_chunk in node_chunks:
embeddings = []
metadatas = []
ids = []
documents = []
for node in node_chunk:
embeddings.append(node.get_embedding())
metadatas.append(
node_to_metadata_dict(
node, remove_text=True, flat_metadata=self.flat_metadata
)
)
ids.append(node.node_id)
documents.append(node.get_content(metadata_mode=MetadataMode.NONE))
self._collection.add(
embeddings=embeddings,
ids=ids,
metadatas=metadatas,
documents=documents,
)
all_ids.extend(ids)
return all_ids

View File

@ -4,9 +4,9 @@ import chromadb
from injector import inject, singleton
from llama_index import VectorStoreIndex
from llama_index.indices.vector_store import VectorIndexRetriever
from llama_index.vector_stores import ChromaVectorStore
from llama_index.vector_stores.types import VectorStore
from private_gpt.components.vector_store.batched_chroma import BatchedChromaVectorStore
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.paths import local_data_path
@ -36,14 +36,16 @@ class VectorStoreComponent:
@inject
def __init__(self) -> None:
db = chromadb.PersistentClient(
chroma_client = chromadb.PersistentClient(
path=str((local_data_path / "chroma_db").absolute())
)
chroma_collection = db.get_or_create_collection(
chroma_collection = chroma_client.get_or_create_collection(
"make_this_parameterizable_per_api_call"
) # TODO
self.vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
self.vector_store = BatchedChromaVectorStore(
chroma_client=chroma_client, chroma_collection=chroma_collection
)
@staticmethod
def get_retriever(

View File

@ -5,7 +5,7 @@ description = "Private GPT"
authors = ["Zylon <hi@zylon.ai>"]
[tool.poetry.dependencies]
python = ">=3.11,<3.13"
python = ">=3.11,<3.12"
fastapi = { extras = ["all"], version = "^0.103.1" }
loguru = "^0.7.2"
boto3 = "^1.28.56"
@ -13,7 +13,7 @@ injector = "^0.21.0"
pyyaml = "^6.0.1"
python-multipart = "^0.0.6"
pypdf = "^3.16.2"
llama-index = "v0.8.35"
llama-index = "0.8.47"
chromadb = "^0.4.13"
watchdog = "^3.0.0"
transformers = "^4.34.0"

View File

@ -0,0 +1,27 @@
from unittest.mock import PropertyMock, patch
from llama_index import Document
from private_gpt.server.ingest.ingest_service import IngestService
from tests.fixtures.mock_injector import MockInjector
def test_save_many_nodes(injector: MockInjector) -> None:
"""This is a specific test for a local Chromadb Vector Database setup.
Extend it when we add support for other vector databases in VectorStoreComponent.
"""
with patch(
"chromadb.api.segment.SegmentAPI.max_batch_size", new_callable=PropertyMock
) as max_batch_size:
# Make max batch size of Chromadb very small
max_batch_size.return_value = 10
ingest_service = injector.get(IngestService)
documents = []
for _i in range(100):
documents.append(Document(text="This is a sentence."))
ingested_docs = ingest_service._save_docs(documents)
assert len(ingested_docs) == len(documents)