mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-15 09:23:57 +00:00
Specify which data to return from chromadb (#4393)
# Improve the Chroma get() method by adding the optional "include" parameter. The Chroma get() method excludes embeddings by default. You can customize the response by specifying the "include" parameter to selectively retrieve the desired data from the collection. --------- Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
00c6ec8a2d
commit
d126276693
@ -313,9 +313,17 @@ class Chroma(VectorStore):
|
||||
"""Delete the collection."""
|
||||
self._client.delete_collection(self._collection.name)
|
||||
|
||||
def get(self) -> Chroma:
|
||||
"""Gets the collection"""
|
||||
return self._collection.get()
|
||||
def get(self, include: Optional[List[str]] = None) -> Dict[str, Any]:
|
||||
"""Gets the collection.
|
||||
|
||||
Args:
|
||||
include (Optional[List[str]]): List of fields to include from db.
|
||||
Defaults to None.
|
||||
"""
|
||||
if include is not None:
|
||||
return self._collection.get(include=include)
|
||||
else:
|
||||
return self._collection.get()
|
||||
|
||||
def persist(self) -> None:
|
||||
"""Persist the collection.
|
||||
|
@ -148,3 +148,15 @@ def test_chroma_mmr_by_vector() -> None:
|
||||
embedded_query = embeddings.embed_query("foo")
|
||||
output = docsearch.max_marginal_relevance_search_by_vector(embedded_query, k=1)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
|
||||
def test_chroma_with_include_parameter() -> None:
|
||||
"""Test end to end construction and include parameter."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = Chroma.from_texts(
|
||||
collection_name="test_collection", texts=texts, embedding=FakeEmbeddings()
|
||||
)
|
||||
output = docsearch.get(include=["embeddings"])
|
||||
assert output["embeddings"] is not None
|
||||
output = docsearch.get()
|
||||
assert output["embeddings"] is None
|
||||
|
Loading…
Reference in New Issue
Block a user