Mirror of https://github.com/imartinez/privateGPT.git
fix: chromadb max batch size (#1087)
This commit is contained in:
parent b46c1087e2
commit f5a9bf4e37
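For context: ChromaDB enforces an upper bound on how many records a single collection.add() call may contain, exposed by the client as max_batch_size, and the previous ingestion path could exceed it in one call. A minimal sketch of reading that limit, assuming a local chromadb.PersistentClient (the path below is illustrative, not taken from the commit):

import chromadb

# Open (or create) a local persistent ChromaDB database; the path is an example.
client = chromadb.PersistentClient(path="./local_data/chroma_db")

# The client reports the largest batch it will accept in one add() call.
# Submitting more records than this in a single call fails, which is the
# limit the BatchedChromaVectorStore introduced below works around by chunking.
print(client.max_batch_size)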
poetry.lock (generated): 82 lines changed
@@ -261,24 +261,6 @@ files = [
 tests = ["pytest (>=3.2.1,!=3.3.0)"]
 typecheck = ["mypy"]
-
-[[package]]
-name = "beautifulsoup4"
-version = "4.12.2"
-description = "Screen-scraping library"
-optional = false
-python-versions = ">=3.6.0"
-files = [
-    {file = "beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a"},
-    {file = "beautifulsoup4-4.12.2.tar.gz", hash = "sha256:492bbc69dca35d12daac71c4db1bfff0c876c00ef4a2ffacce226d4638eb72da"},
-]
-
-[package.dependencies]
-soupsieve = ">1.2"
-
-[package.extras]
-html5lib = ["html5lib"]
-lxml = ["lxml"]

 [[package]]
 name = "black"
 version = "22.12.0"
@@ -643,10 +625,7 @@ files = [
 ]

 [package.dependencies]
-numpy = [
-    {version = ">=1.16,<2.0", markers = "python_version <= \"3.11\""},
-    {version = ">=1.26.0rc1,<2.0", markers = "python_version >= \"3.12\""},
-]
+numpy = {version = ">=1.16,<2.0", markers = "python_version <= \"3.11\""}

 [package.extras]
 bokeh = ["bokeh", "selenium"]
@@ -736,13 +715,13 @@ tests = ["pytest", "pytest-cov", "pytest-xdist"]

 [[package]]
 name = "dataclasses-json"
-version = "0.6.1"
+version = "0.5.14"
 description = "Easily serialize dataclasses to and from JSON."
 optional = false
-python-versions = ">=3.7,<4.0"
+python-versions = ">=3.7,<3.13"
 files = [
-    {file = "dataclasses_json-0.6.1-py3-none-any.whl", hash = "sha256:1bd8418a61fe3d588bb0079214d7fb71d44937da40742b787256fd53b26b6c80"},
-    {file = "dataclasses_json-0.6.1.tar.gz", hash = "sha256:a53c220c35134ce08211a1057fd0e5bf76dc5331627c6b241cacbc570a89faae"},
+    {file = "dataclasses_json-0.5.14-py3-none-any.whl", hash = "sha256:5ec6fed642adb1dbdb4182badb01e0861badfd8fda82e3b67f44b2d1e9d10d21"},
+    {file = "dataclasses_json-0.5.14.tar.gz", hash = "sha256:d82896a94c992ffaf689cd1fafc180164e2abdd415b8f94a7f78586af5886236"},
 ]

 [package.dependencies]
@@ -1130,7 +1109,7 @@ files = [
     {file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0b72b802496cccbd9b31acea72b6f87e7771ccfd7f7927437d592e5c92ed703c"},
     {file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:527cd90ba3d8d7ae7dceb06fda619895768a46a1b4e423bdb24c1969823b8362"},
     {file = "greenlet-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:37f60b3a42d8b5499be910d1267b24355c495064f271cfe74bf28b17b099133c"},
-    {file = "greenlet-3.0.0-cp311-universal2-macosx_10_9_universal2.whl", hash = "sha256:c3692ecf3fe754c8c0f2c95ff19626584459eab110eaab66413b1e7425cd84e9"},
     {file = "greenlet-3.0.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1482fba7fbed96ea7842b5a7fc11d61727e8be75a077e603e8ab49d24e234383"},
     {file = "greenlet-3.0.0-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:be557119bf467d37a8099d91fbf11b2de5eb1fd5fc5b91598407574848dc910f"},
     {file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73b2f1922a39d5d59cc0e597987300df3396b148a9bd10b76a058a2f2772fc04"},
     {file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1e22c22f7826096ad503e9bb681b05b8c1f5a8138469b255eb91f26a76634f2"},
@@ -1140,7 +1119,6 @@ files = [
     {file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:952256c2bc5b4ee8df8dfc54fc4de330970bf5d79253c863fb5e6761f00dda35"},
     {file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:269d06fa0f9624455ce08ae0179430eea61085e3cf6457f05982b37fd2cefe17"},
     {file = "greenlet-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9adbd8ecf097e34ada8efde9b6fec4dd2a903b1e98037adf72d12993a1c80b51"},
-    {file = "greenlet-3.0.0-cp312-universal2-macosx_10_9_universal2.whl", hash = "sha256:553d6fb2324e7f4f0899e5ad2c427a4579ed4873f42124beba763f16032959af"},
     {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6b5ce7f40f0e2f8b88c28e6691ca6806814157ff05e794cdd161be928550f4c"},
     {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecf94aa539e97a8411b5ea52fc6ccd8371be9550c4041011a091eb8b3ca1d810"},
     {file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80dcd3c938cbcac986c5c92779db8e8ce51a89a849c135172c88ecbdc8c056b7"},
@@ -1760,28 +1738,27 @@ test = ["httpx (>=0.24.1)", "pytest (>=7.4.0)"]

 [[package]]
 name = "llama-index"
-version = "0.8.35"
+version = "0.8.47"
 description = "Interface between LLMs and your data"
 optional = false
-python-versions = "*"
+python-versions = ">=3.8.1,<3.12"
 files = [
-    {file = "llama_index-0.8.35-py3-none-any.whl", hash = "sha256:f2f1670320e75a9643b6dc96662038f777866ed543994d18f71ab54329e295ae"},
-    {file = "llama_index-0.8.35.tar.gz", hash = "sha256:a8767be9d36ebd538a37e18b0c7f46bb19d9d7ec490ef7582640f75d5d9b5259"},
+    {file = "llama_index-0.8.47-py3-none-any.whl", hash = "sha256:7a0e5154637524fb59b30bd3a349fba2ec6092cf2972276da9dfa38bbe82d721"},
+    {file = "llama_index-0.8.47.tar.gz", hash = "sha256:f824e7bcf9b6cf3fb98de59d722695a8db327c83b6b7d30071d931b56c14904f"},
 ]

 [package.dependencies]
-beautifulsoup4 = "*"
-dataclasses-json = "*"
+dataclasses-json = ">=0.5.7,<0.6.0"
 fsspec = ">=2023.5.0"
-langchain = ">=0.0.293"
-nest-asyncio = "*"
-nltk = "*"
+langchain = ">=0.0.303"
+nest-asyncio = ">=1.5.8,<2.0.0"
+nltk = ">=3.8.1,<4.0.0"
 numpy = "*"
 openai = ">=0.26.4"
 pandas = "*"
-sqlalchemy = ">=2.0.15"
+SQLAlchemy = {version = ">=1.4.49", extras = ["asyncio"]}
 tenacity = ">=8.2.0,<9.0.0"
-tiktoken = "*"
+tiktoken = ">=0.3.3"
 typing-extensions = ">=4.5.0"
 typing-inspect = ">=0.8.0"
 urllib3 = "<2"
@@ -2397,10 +2374,7 @@ files = [
 ]

 [package.dependencies]
-numpy = [
-    {version = ">=1.23.2", markers = "python_version == \"3.11\""},
-    {version = ">=1.26.0", markers = "python_version >= \"3.12\""},
-]
+numpy = {version = ">=1.23.2", markers = "python_version == \"3.11\""}
 python-dateutil = ">=2.8.2"
 pytz = ">=2020.1"
 tzdata = ">=2022.1"
@@ -3469,11 +3443,6 @@ files = [
     {file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f66eddfda9d45dd6cadcd706b65669ce1df84b8549875691b1f403730bdef217"},
     {file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6448c37741145b241eeac617028ba6ec2119e1339b1385c9720dae31367f2be"},
     {file = "scikit_learn-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:c413c2c850241998168bbb3bd1bb59ff03b1195a53864f0b80ab092071af6028"},
-    {file = "scikit_learn-1.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ef540e09873e31569bc8b02c8a9f745ee04d8e1263255a15c9969f6f5caa627f"},
-    {file = "scikit_learn-1.3.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:9147a3a4df4d401e618713880be023e36109c85d8569b3bf5377e6cd3fecdeac"},
-    {file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2cd3634695ad192bf71645702b3df498bd1e246fc2d529effdb45a06ab028b4"},
-    {file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c275a06c5190c5ce00af0acbb61c06374087949f643ef32d355ece12c4db043"},
-    {file = "scikit_learn-1.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:0e1aa8f206d0de814b81b41d60c1ce31f7f2c7354597af38fae46d9c47c45122"},
     {file = "scikit_learn-1.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:52b77cc08bd555969ec5150788ed50276f5ef83abb72e6f469c5b91a0009bbca"},
     {file = "scikit_learn-1.3.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a683394bc3f80b7c312c27f9b14ebea7766b1f0a34faf1a2e9158d80e860ec26"},
     {file = "scikit_learn-1.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a15d964d9eb181c79c190d3dbc2fff7338786bf017e9039571418a1d53dab236"},
@@ -3690,17 +3659,6 @@ files = [
     {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"},
 ]
-
-[[package]]
-name = "soupsieve"
-version = "2.5"
-description = "A modern CSS selector implementation for Beautiful Soup."
-optional = false
-python-versions = ">=3.8"
-files = [
-    {file = "soupsieve-2.5-py3-none-any.whl", hash = "sha256:eaa337ff55a1579b6549dc679565eac1e3d000563bcb1c8ab0d0fefbc0c2cdc7"},
-    {file = "soupsieve-2.5.tar.gz", hash = "sha256:5663d5a7b3bfaeee0bc4372e7fc48f9cff4940b3eec54a6451cc5299f1097690"},
-]

 [[package]]
 name = "sqlalchemy"
 version = "2.0.22"
@@ -3760,7 +3718,7 @@ files = [
 ]

 [package.dependencies]
-greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""}
+greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""}
 typing-extensions = ">=4.2.0"

 [package.extras]
@@ -4738,5 +4696,5 @@ multidict = ">=4.0"

 [metadata]
 lock-version = "2.0"
-python-versions = ">=3.11,<3.13"
-content-hash = "c1fa5accdcd9cd81430839398e16d2596b43b1c314b5d2a8a76aa05bbb83a39c"
+python-versions = ">=3.11,<3.12"
+content-hash = "56b78ce6a8a6dfbe42b490bcf4ffbf2820f10d5ce70a28c61ee4e357172dab33"

private_gpt/components/vector_store/batched_chroma.py (new file): 87 lines
@@ -0,0 +1,87 @@
from typing import Any

from llama_index.schema import BaseNode, MetadataMode
from llama_index.vector_stores import ChromaVectorStore
from llama_index.vector_stores.chroma import chunk_list
from llama_index.vector_stores.utils import node_to_metadata_dict


class BatchedChromaVectorStore(ChromaVectorStore):
    """Chroma vector store, batching additions to avoid reaching the max batch limit.

    In this vector store, embeddings are stored within a ChromaDB collection.

    During query time, the index uses ChromaDB to query for the top
    k most similar nodes.

    Args:
        chroma_client (from chromadb.api.API):
            API instance
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance

    """

    chroma_client: Any | None

    def __init__(
        self,
        chroma_client: Any,
        chroma_collection: Any,
        host: str | None = None,
        port: str | None = None,
        ssl: bool = False,
        headers: dict[str, str] | None = None,
        collection_kwargs: dict[Any, Any] | None = None,
    ) -> None:
        super().__init__(
            chroma_collection=chroma_collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_kwargs=collection_kwargs or {},
        )
        self.chroma_client = chroma_client

    def add(self, nodes: list[BaseNode]) -> list[str]:
        """Add nodes to index, batching the insertion to avoid issues.

        Args:
            nodes: List[BaseNode]: list of nodes with embeddings

        """
        if not self.chroma_client:
            raise ValueError("Client not initialized")

        if not self._collection:
            raise ValueError("Collection not initialized")

        max_chunk_size = self.chroma_client.max_batch_size
        node_chunks = chunk_list(nodes, max_chunk_size)

        all_ids = []
        for node_chunk in node_chunks:
            embeddings = []
            metadatas = []
            ids = []
            documents = []
            for node in node_chunk:
                embeddings.append(node.get_embedding())
                metadatas.append(
                    node_to_metadata_dict(
                        node, remove_text=True, flat_metadata=self.flat_metadata
                    )
                )
                ids.append(node.node_id)
                documents.append(node.get_content(metadata_mode=MetadataMode.NONE))

            self._collection.add(
                embeddings=embeddings,
                ids=ids,
                metadatas=metadatas,
                documents=documents,
            )
            all_ids.extend(ids)

        return all_ids
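A brief usage sketch of the new store, not part of the commit: it assumes a local persistent ChromaDB client; the path, collection name, node texts, and embedding values below are illustrative.

import chromadb
from llama_index.schema import TextNode

from private_gpt.components.vector_store.batched_chroma import BatchedChromaVectorStore

# Local persistent client and collection (names and path are placeholders).
chroma_client = chromadb.PersistentClient(path="./local_data/chroma_db")
chroma_collection = chroma_client.get_or_create_collection("example_collection")

vector_store = BatchedChromaVectorStore(
    chroma_client=chroma_client, chroma_collection=chroma_collection
)

# Nodes must already carry embeddings; the vectors here are made up.
nodes = [TextNode(text=f"sentence {i}", embedding=[0.1, 0.2, 0.3]) for i in range(1000)]

# add() splits the nodes into chunks of at most chroma_client.max_batch_size,
# so the insertion succeeds even when len(nodes) exceeds a single Chroma batch.
ids = vector_store.add(nodes)
assert len(ids) == len(nodes)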
private_gpt/components/vector_store/vector_store_component.py

@@ -4,9 +4,9 @@ import chromadb
 from injector import inject, singleton
 from llama_index import VectorStoreIndex
 from llama_index.indices.vector_store import VectorIndexRetriever
-from llama_index.vector_stores import ChromaVectorStore
 from llama_index.vector_stores.types import VectorStore

+from private_gpt.components.vector_store.batched_chroma import BatchedChromaVectorStore
 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.paths import local_data_path

@@ -36,14 +36,16 @@ class VectorStoreComponent:

     @inject
     def __init__(self) -> None:
-        db = chromadb.PersistentClient(
+        chroma_client = chromadb.PersistentClient(
             path=str((local_data_path / "chroma_db").absolute())
         )
-        chroma_collection = db.get_or_create_collection(
+        chroma_collection = chroma_client.get_or_create_collection(
             "make_this_parameterizable_per_api_call"
         )  # TODO

-        self.vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
+        self.vector_store = BatchedChromaVectorStore(
+            chroma_client=chroma_client, chroma_collection=chroma_collection
+        )

     @staticmethod
     def get_retriever(
pyproject.toml

@@ -5,7 +5,7 @@ description = "Private GPT"
 authors = ["Zylon <hi@zylon.ai>"]

 [tool.poetry.dependencies]
-python = ">=3.11,<3.13"
+python = ">=3.11,<3.12"
 fastapi = { extras = ["all"], version = "^0.103.1" }
 loguru = "^0.7.2"
 boto3 = "^1.28.56"
@@ -13,7 +13,7 @@ injector = "^0.21.0"
 pyyaml = "^6.0.1"
 python-multipart = "^0.0.6"
 pypdf = "^3.16.2"
-llama-index = "v0.8.35"
+llama-index = "0.8.47"
 chromadb = "^0.4.13"
 watchdog = "^3.0.0"
 transformers = "^4.34.0"
tests/server/ingest/test_ingest_service.py (new file): 27 lines

@@ -0,0 +1,27 @@
from unittest.mock import PropertyMock, patch

from llama_index import Document

from private_gpt.server.ingest.ingest_service import IngestService
from tests.fixtures.mock_injector import MockInjector


def test_save_many_nodes(injector: MockInjector) -> None:
    """This is a specific test for a local Chromadb Vector Database setup.

    Extend it when we add support for other vector databases in VectorStoreComponent.
    """
    with patch(
        "chromadb.api.segment.SegmentAPI.max_batch_size", new_callable=PropertyMock
    ) as max_batch_size:
        # Make max batch size of Chromadb very small
        max_batch_size.return_value = 10

        ingest_service = injector.get(IngestService)

        documents = []
        for _i in range(100):
            documents.append(Document(text="This is a sentence."))

        ingested_docs = ingest_service._save_docs(documents)
        assert len(ingested_docs) == len(documents)
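A sketch of the batching arithmetic the test exercises, assuming chunk_list (the helper the new store imports from llama_index) yields consecutive slices of at most the given size, which is how add() uses it; the numbers mirror the test, not the commit itself.

from llama_index.vector_stores.chroma import chunk_list

items = list(range(100))   # stand-in for the 100 ingested documents
max_batch_size = 10        # the patched Chroma limit in the test

# With a limit of 10, 100 items should be split into 10 batches,
# so the mocked collection receives 10 separate add() calls.
chunks = list(chunk_list(items, max_batch_size))
assert len(chunks) == 10
assert all(len(chunk) <= max_batch_size for chunk in chunks)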