chroma[patch]: add get_by_ids and fix bug (#28516)

- Run standard integration tests in Chroma
- Add `get_by_ids` method
- Fix bug in `add_texts`: if a list of `ids` is passed but any of them
are None, Chroma will raise an exception. Here we assign a uuid.
This commit is contained in:
ccurme 2024-12-04 14:00:36 -05:00 committed by GitHub
parent 12d74d5bef
commit eec55c2550
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 100 additions and 2 deletions

View File

@ -16,6 +16,7 @@ from typing import (
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
@ -517,6 +518,11 @@ class Chroma(VectorStore):
"""
if ids is None:
ids = [str(uuid.uuid4()) for _ in texts]
else:
# Assign strings to any null IDs
for idx, _id in enumerate(ids):
if _id is None:
ids[idx] = str(uuid.uuid4())
embeddings = None
texts = list(texts)
if self._embedding_function is not None:
@ -1028,6 +1034,38 @@ class Chroma(VectorStore):
return self._collection.get(**kwargs) # type: ignore
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
"""Get documents by their IDs.
The returned documents are expected to have the ID field set to the ID of the
document in the vector store.
Fewer documents may be returned than requested if some IDs are not found or
if there are duplicated IDs.
Users should not assume that the order of the returned documents matches
the order of the input IDs. Instead, users should rely on the ID field of the
returned documents.
This method should **NOT** raise exceptions if no documents are found for
some IDs.
Args:
ids: List of ids to retrieve.
Returns:
List of Documents.
.. versionadded:: 0.2.1
"""
results = self.get(ids=list(ids))
return [
Document(page_content=doc, metadata=meta, id=doc_id)
for doc, meta, doc_id in zip(
results["documents"], results["metadatas"], results["ids"]
)
]
def update_document(self, document_id: str, document: Document) -> None:
"""Update a document in the collection.

View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
[[package]]
name = "annotated-types"
@ -955,6 +955,25 @@ typing-extensions = ">=4.7"
type = "directory"
url = "../../core"
[[package]]
name = "langchain-tests"
version = "0.3.4"
description = "Standard tests for LangChain implementations"
optional = false
python-versions = ">=3.9,<4.0"
files = []
develop = true
[package.dependencies]
httpx = "^0.27.0"
langchain-core = "^0.3.19"
pytest = ">=7,<9"
syrupy = "^4"
[package.source]
type = "directory"
url = "../../standard-tests"
[[package]]
name = "langsmith"
version = "0.1.139"
@ -2805,4 +2824,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<4"
content-hash = "4e3e3152fdc954723a33ffc5cc5c42b763e5aee74f39df8b54f16cdb753b2d13"
content-hash = "2d6bc4b9a18a322c326c3f7d5786c4b196a997458e6d2ca4043cb6b7a4a123b3"

View File

@ -90,6 +90,10 @@ python = ">=3.9"
version = ">=0.1.40,<0.3"
python = "<3.9"
[[tool.poetry.group.test.dependencies.langchain-tests]]
path = "../../standard-tests"
develop = true
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.0"

View File

@ -0,0 +1,37 @@
from typing import AsyncGenerator, Generator
import pytest
from langchain_core.embeddings.fake import DeterministicFakeEmbedding
from langchain_core.vectorstores import VectorStore
from langchain_tests.integration_tests.vectorstores import (
AsyncReadWriteTestSuite,
ReadWriteTestSuite,
)
from langchain_chroma import Chroma
class TestSync(ReadWriteTestSuite):
@pytest.fixture()
def vectorstore(self) -> Generator[VectorStore, None, None]: # type: ignore
"""Get an empty vectorstore for unit tests."""
embeddings = DeterministicFakeEmbedding(size=10)
store = Chroma(embedding_function=embeddings)
try:
yield store
finally:
store.delete_collection()
pass
class TestAsync(AsyncReadWriteTestSuite):
@pytest.fixture()
async def vectorstore(self) -> AsyncGenerator[VectorStore, None]: # type: ignore
"""Get an empty vectorstore for unit tests."""
embeddings = DeterministicFakeEmbedding(size=10)
store = Chroma(embedding_function=embeddings)
try:
yield store
finally:
store.delete_collection()
pass