From f5a9bf4e374b2d4c76438cf8a97cccf222ec8e6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Mart=C3=ADnez?= Date: Fri, 20 Oct 2023 18:24:56 +0200 Subject: [PATCH] fix: chromadb max batch size (#1087) --- poetry.lock | 82 +++++------------ .../components/vector_store/batched_chroma.py | 87 +++++++++++++++++++ .../vector_store/vector_store_component.py | 10 ++- pyproject.toml | 4 +- tests/server/ingest/test_ingest_service.py | 27 ++++++ 5 files changed, 142 insertions(+), 68 deletions(-) create mode 100644 private_gpt/components/vector_store/batched_chroma.py create mode 100644 tests/server/ingest/test_ingest_service.py diff --git a/poetry.lock b/poetry.lock index 4ca1626f..f1e3ac22 100644 --- a/poetry.lock +++ b/poetry.lock @@ -261,24 +261,6 @@ files = [ tests = ["pytest (>=3.2.1,!=3.3.0)"] typecheck = ["mypy"] -[[package]] -name = "beautifulsoup4" -version = "4.12.2" -description = "Screen-scraping library" -optional = false -python-versions = ">=3.6.0" -files = [ - {file = "beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a"}, - {file = "beautifulsoup4-4.12.2.tar.gz", hash = "sha256:492bbc69dca35d12daac71c4db1bfff0c876c00ef4a2ffacce226d4638eb72da"}, -] - -[package.dependencies] -soupsieve = ">1.2" - -[package.extras] -html5lib = ["html5lib"] -lxml = ["lxml"] - [[package]] name = "black" version = "22.12.0" @@ -643,10 +625,7 @@ files = [ ] [package.dependencies] -numpy = [ - {version = ">=1.16,<2.0", markers = "python_version <= \"3.11\""}, - {version = ">=1.26.0rc1,<2.0", markers = "python_version >= \"3.12\""}, -] +numpy = {version = ">=1.16,<2.0", markers = "python_version <= \"3.11\""} [package.extras] bokeh = ["bokeh", "selenium"] @@ -736,13 +715,13 @@ tests = ["pytest", "pytest-cov", "pytest-xdist"] [[package]] name = "dataclasses-json" -version = "0.6.1" +version = "0.5.14" description = "Easily serialize dataclasses to and from JSON." 
optional = false -python-versions = ">=3.7,<4.0" +python-versions = ">=3.7,<3.13" files = [ - {file = "dataclasses_json-0.6.1-py3-none-any.whl", hash = "sha256:1bd8418a61fe3d588bb0079214d7fb71d44937da40742b787256fd53b26b6c80"}, - {file = "dataclasses_json-0.6.1.tar.gz", hash = "sha256:a53c220c35134ce08211a1057fd0e5bf76dc5331627c6b241cacbc570a89faae"}, + {file = "dataclasses_json-0.5.14-py3-none-any.whl", hash = "sha256:5ec6fed642adb1dbdb4182badb01e0861badfd8fda82e3b67f44b2d1e9d10d21"}, + {file = "dataclasses_json-0.5.14.tar.gz", hash = "sha256:d82896a94c992ffaf689cd1fafc180164e2abdd415b8f94a7f78586af5886236"}, ] [package.dependencies] @@ -1130,7 +1109,7 @@ files = [ {file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0b72b802496cccbd9b31acea72b6f87e7771ccfd7f7927437d592e5c92ed703c"}, {file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:527cd90ba3d8d7ae7dceb06fda619895768a46a1b4e423bdb24c1969823b8362"}, {file = "greenlet-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:37f60b3a42d8b5499be910d1267b24355c495064f271cfe74bf28b17b099133c"}, - {file = "greenlet-3.0.0-cp311-universal2-macosx_10_9_universal2.whl", hash = "sha256:c3692ecf3fe754c8c0f2c95ff19626584459eab110eaab66413b1e7425cd84e9"}, + {file = "greenlet-3.0.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1482fba7fbed96ea7842b5a7fc11d61727e8be75a077e603e8ab49d24e234383"}, {file = "greenlet-3.0.0-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:be557119bf467d37a8099d91fbf11b2de5eb1fd5fc5b91598407574848dc910f"}, {file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73b2f1922a39d5d59cc0e597987300df3396b148a9bd10b76a058a2f2772fc04"}, {file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1e22c22f7826096ad503e9bb681b05b8c1f5a8138469b255eb91f26a76634f2"}, @@ -1140,7 +1119,6 @@ files = [ {file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = 
"sha256:952256c2bc5b4ee8df8dfc54fc4de330970bf5d79253c863fb5e6761f00dda35"}, {file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:269d06fa0f9624455ce08ae0179430eea61085e3cf6457f05982b37fd2cefe17"}, {file = "greenlet-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9adbd8ecf097e34ada8efde9b6fec4dd2a903b1e98037adf72d12993a1c80b51"}, - {file = "greenlet-3.0.0-cp312-universal2-macosx_10_9_universal2.whl", hash = "sha256:553d6fb2324e7f4f0899e5ad2c427a4579ed4873f42124beba763f16032959af"}, {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6b5ce7f40f0e2f8b88c28e6691ca6806814157ff05e794cdd161be928550f4c"}, {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecf94aa539e97a8411b5ea52fc6ccd8371be9550c4041011a091eb8b3ca1d810"}, {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80dcd3c938cbcac986c5c92779db8e8ce51a89a849c135172c88ecbdc8c056b7"}, @@ -1760,28 +1738,27 @@ test = ["httpx (>=0.24.1)", "pytest (>=7.4.0)"] [[package]] name = "llama-index" -version = "0.8.35" +version = "0.8.47" description = "Interface between LLMs and your data" optional = false -python-versions = "*" +python-versions = ">=3.8.1,<3.12" files = [ - {file = "llama_index-0.8.35-py3-none-any.whl", hash = "sha256:f2f1670320e75a9643b6dc96662038f777866ed543994d18f71ab54329e295ae"}, - {file = "llama_index-0.8.35.tar.gz", hash = "sha256:a8767be9d36ebd538a37e18b0c7f46bb19d9d7ec490ef7582640f75d5d9b5259"}, + {file = "llama_index-0.8.47-py3-none-any.whl", hash = "sha256:7a0e5154637524fb59b30bd3a349fba2ec6092cf2972276da9dfa38bbe82d721"}, + {file = "llama_index-0.8.47.tar.gz", hash = "sha256:f824e7bcf9b6cf3fb98de59d722695a8db327c83b6b7d30071d931b56c14904f"}, ] [package.dependencies] -beautifulsoup4 = "*" -dataclasses-json = "*" +dataclasses-json = ">=0.5.7,<0.6.0" fsspec = ">=2023.5.0" -langchain = ">=0.0.293" -nest-asyncio = "*" -nltk = "*" 
+langchain = ">=0.0.303" +nest-asyncio = ">=1.5.8,<2.0.0" +nltk = ">=3.8.1,<4.0.0" numpy = "*" openai = ">=0.26.4" pandas = "*" -sqlalchemy = ">=2.0.15" +SQLAlchemy = {version = ">=1.4.49", extras = ["asyncio"]} tenacity = ">=8.2.0,<9.0.0" -tiktoken = "*" +tiktoken = ">=0.3.3" typing-extensions = ">=4.5.0" typing-inspect = ">=0.8.0" urllib3 = "<2" @@ -2397,10 +2374,7 @@ files = [ ] [package.dependencies] -numpy = [ - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, -] +numpy = {version = ">=1.23.2", markers = "python_version == \"3.11\""} python-dateutil = ">=2.8.2" pytz = ">=2020.1" tzdata = ">=2022.1" @@ -3469,11 +3443,6 @@ files = [ {file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f66eddfda9d45dd6cadcd706b65669ce1df84b8549875691b1f403730bdef217"}, {file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6448c37741145b241eeac617028ba6ec2119e1339b1385c9720dae31367f2be"}, {file = "scikit_learn-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:c413c2c850241998168bbb3bd1bb59ff03b1195a53864f0b80ab092071af6028"}, - {file = "scikit_learn-1.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ef540e09873e31569bc8b02c8a9f745ee04d8e1263255a15c9969f6f5caa627f"}, - {file = "scikit_learn-1.3.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:9147a3a4df4d401e618713880be023e36109c85d8569b3bf5377e6cd3fecdeac"}, - {file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2cd3634695ad192bf71645702b3df498bd1e246fc2d529effdb45a06ab028b4"}, - {file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c275a06c5190c5ce00af0acbb61c06374087949f643ef32d355ece12c4db043"}, - {file = "scikit_learn-1.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:0e1aa8f206d0de814b81b41d60c1ce31f7f2c7354597af38fae46d9c47c45122"}, 
{file = "scikit_learn-1.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:52b77cc08bd555969ec5150788ed50276f5ef83abb72e6f469c5b91a0009bbca"}, {file = "scikit_learn-1.3.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a683394bc3f80b7c312c27f9b14ebea7766b1f0a34faf1a2e9158d80e860ec26"}, {file = "scikit_learn-1.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a15d964d9eb181c79c190d3dbc2fff7338786bf017e9039571418a1d53dab236"}, @@ -3690,17 +3659,6 @@ files = [ {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, ] -[[package]] -name = "soupsieve" -version = "2.5" -description = "A modern CSS selector implementation for Beautiful Soup." -optional = false -python-versions = ">=3.8" -files = [ - {file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"}, - {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"}, -] - [[package]] name = "sqlalchemy" version = "2.0.22" @@ -3760,7 +3718,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""} +greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""} typing-extensions = ">=4.2.0" [package.extras] @@ -4738,5 +4696,5 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" -python-versions = ">=3.11,<3.13" -content-hash = "c1fa5accdcd9cd81430839398e16d2596b43b1c314b5d2a8a76aa05bbb83a39c" 
class BatchedChromaVectorStore(ChromaVectorStore):
    """Chroma vector store that inserts nodes in size-limited batches.

    Each call to ``add`` splits the incoming nodes into chunks no larger than
    the client's ``max_batch_size`` and writes the chunks to the collection
    one after another, so a large ingestion does not hit the max batch limit.
    Only ``__init__`` and ``add`` are overridden; everything else comes from
    ``ChromaVectorStore``.

    Args:
        chroma_client (from chromadb.api.API):
            API instance; its ``max_batch_size`` is read on every ``add``
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance the nodes are written to

    """

    chroma_client: Any | None

    def __init__(
        self,
        chroma_client: Any,
        chroma_collection: Any,
        host: str | None = None,
        port: str | None = None,
        ssl: bool = False,
        headers: dict[str, str] | None = None,
        collection_kwargs: dict[Any, Any] | None = None,
    ) -> None:
        """Initialise the parent store, then remember the Chroma client.

        All arguments except ``chroma_client`` are forwarded unchanged to
        ``ChromaVectorStore``; a missing or empty ``collection_kwargs`` is
        forwarded as a fresh empty dict.
        """
        if not collection_kwargs:
            collection_kwargs = {}
        super().__init__(
            chroma_collection=chroma_collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_kwargs=collection_kwargs,
        )
        # Assigned after the parent initialiser has run, as in the original.
        self.chroma_client = chroma_client

    def add(self, nodes: list[BaseNode]) -> list[str]:
        """Insert nodes into the collection, one bounded batch at a time.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        Returns:
            The node ids of every inserted node, in insertion order.

        Raises:
            ValueError: if the client or the collection is not set.

        """
        client = self.chroma_client
        if not client:
            raise ValueError("Client not initialized")

        collection = self._collection
        if not collection:
            raise ValueError("Collection not initialized")

        inserted_ids = []
        for batch in chunk_list(nodes, client.max_batch_size):
            # One tuple per node: (embedding, metadata, id, text). The fields
            # are computed node by node, in this order.
            rows = [
                (
                    node.get_embedding(),
                    node_to_metadata_dict(
                        node, remove_text=True, flat_metadata=self.flat_metadata
                    ),
                    node.node_id,
                    node.get_content(metadata_mode=MetadataMode.NONE),
                )
                for node in batch
            ]
            batch_ids = [row[2] for row in rows]

            collection.add(
                embeddings=[row[0] for row in rows],
                ids=batch_ids,
                metadatas=[row[1] for row in rows],
                documents=[row[3] for row in rows],
            )
            inserted_ids += batch_ids

        return inserted_ids
def test_save_many_nodes(injector: MockInjector) -> None:
    """Ingest more documents than a shrunken Chromadb batch limit allows at once.

    This is specific to the local Chromadb Vector Database setup: the
    ``max_batch_size`` property of Chromadb's ``SegmentAPI`` is patched down to
    10 and 100 documents are saved, so the save has to cope with a document
    count larger than the patched limit.

    Extend it when we add support for other vector databases in
    VectorStoreComponent.
    """
    batch_limit_patch = patch(
        "chromadb.api.segment.SegmentAPI.max_batch_size", new_callable=PropertyMock
    )
    with batch_limit_patch as patched_limit:
        # Force a very small Chromadb batch limit for the duration of the test
        patched_limit.return_value = 10

        service = injector.get(IngestService)

        docs_to_save = [Document(text="This is a sentence.") for _ in range(100)]

        saved_docs = service._save_docs(docs_to_save)
        assert len(saved_docs) == len(docs_to_save)