mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-31 08:32:32 +00:00
chroma[patch]: add get_by_ids
and fix bug (#28516)
- Run standard integration tests in Chroma
- Add `get_by_ids` method
- Fix bug in `add_texts`: if a list of `ids` is passed but any of them is `None`, Chroma will raise an exception. Here we assign a UUID to each missing ID.
This commit is contained in:
parent
12d74d5bef
commit
eec55c2550
@ -16,6 +16,7 @@ from typing import (
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
@ -517,6 +518,11 @@ class Chroma(VectorStore):
|
||||
"""
|
||||
if ids is None:
|
||||
ids = [str(uuid.uuid4()) for _ in texts]
|
||||
else:
|
||||
# Assign strings to any null IDs
|
||||
for idx, _id in enumerate(ids):
|
||||
if _id is None:
|
||||
ids[idx] = str(uuid.uuid4())
|
||||
embeddings = None
|
||||
texts = list(texts)
|
||||
if self._embedding_function is not None:
|
||||
@ -1028,6 +1034,38 @@ class Chroma(VectorStore):
|
||||
|
||||
return self._collection.get(**kwargs) # type: ignore
|
||||
|
||||
def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
    """Get documents by their IDs.

    The returned documents are expected to have the ID field set to the ID of the
    document in the vector store.

    Fewer documents may be returned than requested if some IDs are not found or
    if there are duplicated IDs.

    Users should not assume that the order of the returned documents matches
    the order of the input IDs. Instead, users should rely on the ID field of the
    returned documents.

    This method should **NOT** raise exceptions if no documents are found for
    some IDs.

    Args:
        ids: List of ids to retrieve. An empty sequence yields an empty list.

    Returns:
        List of Documents. A document whose stored metadata is missing is
        returned with an empty metadata dict.

    .. versionadded:: 0.2.1
    """
    # No IDs requested means no document can match. Answer directly rather than
    # forwarding an empty ID filter, whose handling by ``self.get`` is not
    # visible from here.
    if not ids:
        return []

    results = self.get(ids=list(ids))
    return [
        # NOTE(review): assumes the backend may report "no metadata" as None;
        # normalise it to a dict so Document construction does not reject it.
        Document(page_content=doc, metadata=meta or {}, id=doc_id)
        for doc, meta, doc_id in zip(
            results["documents"], results["metadatas"], results["ids"]
        )
    ]
|
||||
|
||||
def update_document(self, document_id: str, document: Document) -> None:
|
||||
"""Update a document in the collection.
|
||||
|
||||
|
23
libs/partners/chroma/poetry.lock
generated
23
libs/partners/chroma/poetry.lock
generated
@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "annotated-types"
|
||||
@ -955,6 +955,25 @@ typing-extensions = ">=4.7"
|
||||
type = "directory"
|
||||
url = "../../core"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-tests"
|
||||
version = "0.3.4"
|
||||
description = "Standard tests for LangChain implementations"
|
||||
optional = false
|
||||
python-versions = ">=3.9,<4.0"
|
||||
files = []
|
||||
develop = true
|
||||
|
||||
[package.dependencies]
|
||||
httpx = "^0.27.0"
|
||||
langchain-core = "^0.3.19"
|
||||
pytest = ">=7,<9"
|
||||
syrupy = "^4"
|
||||
|
||||
[package.source]
|
||||
type = "directory"
|
||||
url = "../../standard-tests"
|
||||
|
||||
[[package]]
|
||||
name = "langsmith"
|
||||
version = "0.1.139"
|
||||
@ -2805,4 +2824,4 @@ type = ["pytest-mypy"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<4"
|
||||
content-hash = "4e3e3152fdc954723a33ffc5cc5c42b763e5aee74f39df8b54f16cdb753b2d13"
|
||||
content-hash = "2d6bc4b9a18a322c326c3f7d5786c4b196a997458e6d2ca4043cb6b7a4a123b3"
|
||||
|
@ -90,6 +90,10 @@ python = ">=3.9"
|
||||
version = ">=0.1.40,<0.3"
|
||||
python = "<3.9"
|
||||
|
||||
[[tool.poetry.group.test.dependencies.langchain-tests]]
|
||||
path = "../../standard-tests"
|
||||
develop = true
|
||||
|
||||
[tool.poetry.group.codespell.dependencies]
|
||||
codespell = "^2.2.0"
|
||||
|
||||
|
@ -0,0 +1,37 @@
|
||||
from typing import AsyncGenerator, Generator
|
||||
|
||||
import pytest
|
||||
from langchain_core.embeddings.fake import DeterministicFakeEmbedding
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from langchain_tests.integration_tests.vectorstores import (
|
||||
AsyncReadWriteTestSuite,
|
||||
ReadWriteTestSuite,
|
||||
)
|
||||
|
||||
from langchain_chroma import Chroma
|
||||
|
||||
|
||||
class TestSync(ReadWriteTestSuite):
    """Exercise Chroma through the inherited ``ReadWriteTestSuite`` tests."""

    @pytest.fixture()
    def vectorstore(self) -> Generator[VectorStore, None, None]:  # type: ignore
        """Yield a fresh Chroma store and delete its collection afterwards."""
        chroma_store = Chroma(
            embedding_function=DeterministicFakeEmbedding(size=10),
        )
        try:
            yield chroma_store
        finally:
            # Cleanup runs whether the consuming test passed or failed.
            chroma_store.delete_collection()
|
||||
|
||||
|
||||
class TestAsync(AsyncReadWriteTestSuite):
    """Exercise Chroma through the inherited ``AsyncReadWriteTestSuite`` tests."""

    @pytest.fixture()
    async def vectorstore(self) -> AsyncGenerator[VectorStore, None]:  # type: ignore
        """Yield a fresh Chroma store and delete its collection afterwards."""
        chroma_store = Chroma(
            embedding_function=DeterministicFakeEmbedding(size=10),
        )
        try:
            yield chroma_store
        finally:
            # Cleanup runs whether the consuming test passed or failed.
            chroma_store.delete_collection()
|
Loading…
Reference in New Issue
Block a user