feat(chroma): Add Chroma Cloud support (#32125)

* Adding support for more Chroma client options (`HttpClient` and
`CloundClient`). This includes adding arguments necessary for
instantiating these clients.
* Adding support for Chroma's new persisted collection configuration (we
moved index configuration into this new construct).
* Delegate `Settings` configuration to Chroma's client constructors.
This commit is contained in:
itaismith
2025-07-22 12:14:15 -07:00
committed by GitHub
parent 3fc27e7a95
commit 09769373b3
2 changed files with 467 additions and 161 deletions

View File

@@ -20,13 +20,15 @@ from typing import (
import chromadb
import chromadb.config
import numpy as np
from chromadb import Settings
from chromadb.api import CreateCollectionConfiguration
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import xor_args
from langchain_core.vectorstores import VectorStore
if TYPE_CHECKING:
from chromadb.api.types import ID, OneOrMany, Where, WhereDocument
from chromadb.api.types import Where, WhereDocument
logger = logging.getLogger()
DEFAULT_K = 4 # Number of Documents to return.
@@ -167,6 +169,20 @@ class Chroma(VectorStore):
Chroma client settings.
persist_directory: Optional[str]
Directory to persist the collection.
host: Optional[str]
Hostname of a deployed Chroma server.
port: Optional[int]
Connection port for a deployed Chroma server. Default is 8000.
ssl: Optional[bool]
Whether to establish an SSL connection with a deployed Chroma server. Default is False.
headers: Optional[dict[str, str]]
HTTP headers to send to a deployed Chroma server.
chroma_cloud_api_key: Optional[str]
Chroma Cloud API key.
tenant: Optional[str]
Tenant ID. Required for Chroma Cloud connections. Default is 'default_tenant' for local Chroma servers.
database: Optional[str]
Database name. Required for Chroma Cloud connections. Default is 'default_database'.
Instantiate:
.. code-block:: python
@@ -284,11 +300,20 @@ class Chroma(VectorStore):
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
embedding_function: Optional[Embeddings] = None,
persist_directory: Optional[str] = None,
host: Optional[str] = None,
port: Optional[int] = None,
headers: Optional[dict[str, str]] = None,
chroma_cloud_api_key: Optional[str] = None,
tenant: Optional[str] = None,
database: Optional[str] = None,
client_settings: Optional[chromadb.config.Settings] = None,
collection_metadata: Optional[dict] = None,
collection_configuration: Optional[CreateCollectionConfiguration] = None,
client: Optional[chromadb.ClientAPI] = None,
relevance_score_fn: Optional[Callable[[float], float]] = None,
create_collection_if_not_exists: Optional[bool] = True, # noqa: FBT001, FBT002
*,
ssl: bool = False,
) -> None:
"""Initialize with a Chroma client.
@@ -296,8 +321,20 @@ class Chroma(VectorStore):
collection_name: Name of the collection to create.
embedding_function: Embedding class object. Used to embed texts.
persist_directory: Directory to persist the collection.
host: Hostname of a deployed Chroma server.
port: Connection port for a deployed Chroma server. Default is 8000.
ssl: Whether to establish an SSL connection with a deployed Chroma server.
Default is False.
headers: HTTP headers to send to a deployed Chroma server.
chroma_cloud_api_key: Chroma Cloud API key.
tenant: Tenant ID. Required for Chroma Cloud connections.
Default is 'default_tenant' for local Chroma servers.
database: Database name. Required for Chroma Cloud connections.
Default is 'default_database'.
client_settings: Chroma client settings
collection_metadata: Collection configurations.
collection_configuration: Index configuration for the collection.
Defaults to None.
client: Chroma client. Documentation:
https://docs.trychroma.com/reference/python/client
relevance_score_fn: Function to calculate relevance score from distance.
@@ -305,37 +342,73 @@ class Chroma(VectorStore):
create_collection_if_not_exists: Whether to create collection
if it doesn't exist. Defaults to True.
"""
if client is not None:
self._client_settings = client_settings
self._client = client
self._persist_directory = persist_directory
else:
if client_settings:
# If client_settings is provided with persist_directory specified,
# then it is "in-memory and persisting to disk" mode.
client_settings.persist_directory = (
persist_directory or client_settings.persist_directory
)
client_settings.is_persistent = (
client_settings.persist_directory is not None
)
_tenant = tenant or chromadb.DEFAULT_TENANT
_database = database or chromadb.DEFAULT_DATABASE
_settings = client_settings or Settings()
_client_settings = client_settings
elif persist_directory:
_client_settings = chromadb.config.Settings(is_persistent=True)
_client_settings.persist_directory = persist_directory
else:
_client_settings = chromadb.config.Settings()
self._client_settings = _client_settings
self._client = chromadb.Client(_client_settings)
self._persist_directory = (
_client_settings.persist_directory or persist_directory
client_args = {
"persist_directory": persist_directory,
"host": host,
"chroma_cloud_api_key": chroma_cloud_api_key,
}
if sum(arg is not None for arg in client_args.values()) > 1:
provided = [
name for name, value in client_args.items() if value is not None
]
msg = (
f"Only one of 'persist_directory', 'host' and 'chroma_cloud_api_key' "
f"is allowed, but got {','.join(provided)}"
)
raise ValueError(msg)
if client is not None:
self._client = client
# PersistentClient
elif persist_directory is not None:
self._client = chromadb.PersistentClient(
path=persist_directory,
settings=_settings,
tenant=_tenant,
database=_database,
)
# HttpClient
elif host is not None:
_port = port or 8000
self._client = chromadb.HttpClient(
host=host,
port=_port,
ssl=ssl,
headers=headers,
settings=_settings,
tenant=_tenant,
database=_database,
)
# CloudClient
elif chroma_cloud_api_key is not None:
if not tenant or not database:
msg = (
"Must provide tenant and database values to connect to Chroma Cloud"
)
raise ValueError(msg)
self._client = chromadb.CloudClient(
tenant=tenant,
database=database,
api_key=chroma_cloud_api_key,
settings=_settings,
)
else:
self._client = chromadb.Client(settings=_settings)
self._embedding_function = embedding_function
self._chroma_collection: Optional[chromadb.Collection] = None
self._collection_name = collection_name
self._collection_metadata = collection_metadata
self._collection_configuration = collection_configuration
if create_collection_if_not_exists:
self.__ensure_collection()
else:
@@ -348,6 +421,7 @@ class Chroma(VectorStore):
name=self._collection_name,
embedding_function=None,
metadata=self._collection_metadata,
configuration=self._collection_configuration,
)
@property
@@ -405,7 +479,8 @@ class Chroma(VectorStore):
**kwargs,
)
def encode_image(self, uri: str) -> str:
@staticmethod
def encode_image(uri: str) -> str:
"""Get base64 string from image URI."""
with open(uri, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
@@ -415,7 +490,6 @@ class Chroma(VectorStore):
uris: list[str],
metadatas: Optional[list[dict]] = None,
ids: Optional[list[str]] = None,
**kwargs: Any,
) -> list[str]:
"""Run more images through the embeddings and add to the vectorstore.
@@ -424,7 +498,6 @@ class Chroma(VectorStore):
metadatas: Optional list of metadatas.
When querying, you can filter on this metadata.
ids: Optional list of IDs. (Items without IDs will be assigned UUIDs)
kwargs: Additional keyword arguments to pass.
Returns:
List of IDs of the added images.
@@ -635,7 +708,7 @@ class Chroma(VectorStore):
List of Documents most similar to the query vector.
"""
results = self.__query_collection(
query_embeddings=embedding,
query_embeddings=[embedding],
n_results=k,
where=filter,
where_document=where_document,
@@ -658,7 +731,7 @@ class Chroma(VectorStore):
k: Number of Documents to return. Defaults to 4.
filter: Filter by metadata. Defaults to None.
where_document: dict used to filter by the documents.
E.g. {"$contains": "hello"}}.
E.g. {"$contains": "hello"}.
kwargs: Additional keyword arguments to pass to Chroma collection query.
Returns:
@@ -666,7 +739,7 @@ class Chroma(VectorStore):
in float for each. Lower score represents more similarity.
"""
results = self.__query_collection(
query_embeddings=embedding,
query_embeddings=[embedding],
n_results=k,
where=filter,
where_document=where_document,
@@ -765,10 +838,10 @@ class Chroma(VectorStore):
"""Select the relevance score function based on collections distance metric.
The most similar documents will have the lowest relevance score. Default
relevance score function is euclidean distance. Distance metric must be
provided in `collection_metadata` during initialization of Chroma object.
Example: collection_metadata={"hnsw:space": "cosine"}. Available distance
metrics are: 'cosine', 'l2' and 'ip'.
relevance score function is Euclidean distance. Distance metric must be
provided in `collection_configuration` during initialization of Chroma object.
Example: collection_configuration={"hnsw": {"space": "cosine"}}.
Available distance metrics are: 'cosine', 'l2' and 'ip'.
Returns:
The relevance score function.
@@ -779,12 +852,15 @@ class Chroma(VectorStore):
if self.override_relevance_score_fn:
return self.override_relevance_score_fn
distance = "l2"
distance_key = "hnsw:space"
metadata = self._collection.metadata
hnsw_config = self._collection.configuration.get("hnsw")
hnsw_distance: Optional[str] = hnsw_config.get("space") if hnsw_config else None
if metadata and distance_key in metadata:
distance = metadata[distance_key]
spann_config = self._collection.configuration.get("spann")
spann_distance: Optional[str] = (
spann_config.get("space") if spann_config else None
)
distance = hnsw_distance or spann_distance
if distance == "cosine":
return self._cosine_relevance_score_fn
@@ -826,24 +902,22 @@ class Chroma(VectorStore):
Raises:
ValueError: If the embedding function does not support image embeddings.
"""
if self._embedding_function is None or not hasattr(
self._embedding_function,
"embed_image",
if self._embedding_function is not None and hasattr(
self._embedding_function, "embed_image"
):
msg = "The embedding function must support image embedding."
raise ValueError(msg)
# Obtain image embedding
# Assuming embed_image returns a single embedding
image_embedding = self._embedding_function.embed_image(uris=[uri])
# Obtain image embedding
# Assuming embed_image returns a single embedding
image_embedding = self._embedding_function.embed_image(uris=[uri])
# Perform similarity search based on the obtained embedding
return self.similarity_search_by_vector(
embedding=image_embedding,
k=k,
filter=filter,
**kwargs,
)
# Perform similarity search based on the obtained embedding
return self.similarity_search_by_vector(
embedding=image_embedding,
k=k,
filter=filter,
**kwargs,
)
msg = "The embedding function must support image embedding."
raise ValueError(msg)
def similarity_search_by_image_with_relevance_score(
self,
@@ -870,24 +944,22 @@ class Chroma(VectorStore):
Raises:
ValueError: If the embedding function does not support image embeddings.
"""
if self._embedding_function is None or not hasattr(
self._embedding_function,
"embed_image",
if self._embedding_function is not None and hasattr(
self._embedding_function, "embed_image"
):
msg = "The embedding function must support image embedding."
raise ValueError(msg)
# Obtain image embedding
# Assuming embed_image returns a single embedding
image_embedding = self._embedding_function.embed_image(uris=[uri])
# Obtain image embedding
# Assuming embed_image returns a single embedding
image_embedding = self._embedding_function.embed_image(uris=[uri])
# Perform similarity search based on the obtained embedding
return self.similarity_search_by_vector_with_relevance_scores(
embedding=image_embedding,
k=k,
filter=filter,
**kwargs,
)
# Perform similarity search based on the obtained embedding
return self.similarity_search_by_vector_with_relevance_scores(
embedding=image_embedding,
k=k,
filter=filter,
**kwargs,
)
msg = "The embedding function must support image embedding."
raise ValueError(msg)
def max_marginal_relevance_search_by_vector(
self,
@@ -922,7 +994,7 @@ class Chroma(VectorStore):
List of Documents selected by maximal marginal relevance.
"""
results = self.__query_collection(
query_embeddings=embedding,
query_embeddings=[embedding],
n_results=fetch_k,
where=filter,
where_document=where_document,
@@ -1005,7 +1077,7 @@ class Chroma(VectorStore):
def get(
self,
ids: Optional[OneOrMany[ID]] = None,
ids: Optional[Union[str, list[str]]] = None,
where: Optional[Where] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
@@ -1066,7 +1138,7 @@ class Chroma(VectorStore):
Returns:
List of Documents.
.. versionadded:: 0.2.1
... versionadded:: 0.2.1
"""
results = self.get(ids=list(ids))
return [
@@ -1107,16 +1179,16 @@ class Chroma(VectorStore):
embeddings = self._embedding_function.embed_documents(text)
if hasattr(
self._collection._client, # noqa: SLF001
self._client,
"get_max_batch_size",
) or hasattr( # for Chroma 0.5.1 and above
self._collection._client, # noqa: SLF001
self._client,
"max_batch_size",
): # for Chroma 0.4.10 and above
from chromadb.utils.batch_utils import create_batches
for batch in create_batches(
api=self._collection._client, # noqa: SLF001
api=self._client,
ids=ids,
metadatas=metadata, # type: ignore[arg-type]
documents=text,
@@ -1145,9 +1217,18 @@ class Chroma(VectorStore):
ids: Optional[list[str]] = None,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
persist_directory: Optional[str] = None,
host: Optional[str] = None,
port: Optional[int] = None,
headers: Optional[dict[str, str]] = None,
chroma_cloud_api_key: Optional[str] = None,
tenant: Optional[str] = None,
database: Optional[str] = None,
client_settings: Optional[chromadb.config.Settings] = None,
client: Optional[chromadb.ClientAPI] = None,
collection_metadata: Optional[dict] = None,
collection_configuration: Optional[CreateCollectionConfiguration] = None,
*,
ssl: bool = False,
**kwargs: Any,
) -> Chroma:
"""Create a Chroma vectorstore from a raw documents.
@@ -1159,14 +1240,26 @@ class Chroma(VectorStore):
texts: List of texts to add to the collection.
collection_name: Name of the collection to create.
persist_directory: Directory to persist the collection.
host: Hostname of a deployed Chroma server.
port: Connection port for a deployed Chroma server.
Default is 8000.
ssl: Whether to establish an SSL connection with a deployed Chroma server.
Default is False.
headers: HTTP headers to send to a deployed Chroma server.
chroma_cloud_api_key: Chroma Cloud API key.
tenant: Tenant ID. Required for Chroma Cloud connections.
Default is 'default_tenant' for local Chroma servers.
database: Database name. Required for Chroma Cloud connections.
Default is 'default_database'.
embedding: Embedding function. Defaults to None.
metadatas: List of metadatas. Defaults to None.
ids: List of document IDs. Defaults to None.
client_settings: Chroma client settings.
client: Chroma client. Documentation:
https://docs.trychroma.com/reference/python/client
collection_metadata: Collection configurations.
Defaults to None.
collection_metadata: Collection configurations. Defaults to None.
collection_configuration: Index configuration for the collection.
Defaults to None.
kwargs: Additional keyword arguments to initialize a Chroma client.
Returns:
@@ -1176,9 +1269,17 @@ class Chroma(VectorStore):
collection_name=collection_name,
embedding_function=embedding,
persist_directory=persist_directory,
host=host,
port=port,
ssl=ssl,
headers=headers,
chroma_cloud_api_key=chroma_cloud_api_key,
tenant=tenant,
database=database,
client_settings=client_settings,
client=client,
collection_metadata=collection_metadata,
collection_configuration=collection_configuration,
**kwargs,
)
if ids is None:
@@ -1217,9 +1318,18 @@ class Chroma(VectorStore):
ids: Optional[list[str]] = None,
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
persist_directory: Optional[str] = None,
host: Optional[str] = None,
port: Optional[int] = None,
headers: Optional[dict[str, str]] = None,
chroma_cloud_api_key: Optional[str] = None,
tenant: Optional[str] = None,
database: Optional[str] = None,
client_settings: Optional[chromadb.config.Settings] = None,
client: Optional[chromadb.ClientAPI] = None, # Add this line
collection_metadata: Optional[dict] = None,
collection_configuration: Optional[CreateCollectionConfiguration] = None,
*,
ssl: bool = False,
**kwargs: Any,
) -> Chroma:
"""Create a Chroma vectorstore from a list of documents.
@@ -1230,14 +1340,25 @@ class Chroma(VectorStore):
Args:
collection_name: Name of the collection to create.
persist_directory: Directory to persist the collection.
host: Hostname of a deployed Chroma server.
port: Connection port for a deployed Chroma server. Default is 8000.
ssl: Whether to establish an SSL connection with a deployed Chroma server.
Default is False.
headers: HTTP headers to send to a deployed Chroma server.
chroma_cloud_api_key: Chroma Cloud API key.
tenant: Tenant ID. Required for Chroma Cloud connections.
Default is 'default_tenant' for local Chroma servers.
database: Database name. Required for Chroma Cloud connections.
Default is 'default_database'.
ids : List of document IDs. Defaults to None.
documents: List of documents to add to the vectorstore.
embedding: Embedding function. Defaults to None.
client_settings: Chroma client settings.
client: Chroma client. Documentation:
https://docs.trychroma.com/reference/python/client
collection_metadata: Collection configurations.
Defaults to None.
collection_metadata: Collection configurations. Defaults to None.
collection_configuration: Index configuration for the collection.
Defaults to None.
kwargs: Additional keyword arguments to initialize a Chroma client.
Returns:
@@ -1254,9 +1375,17 @@ class Chroma(VectorStore):
ids=ids,
collection_name=collection_name,
persist_directory=persist_directory,
host=host,
port=port,
ssl=ssl,
headers=headers,
chroma_cloud_api_key=chroma_cloud_api_key,
tenant=tenant,
database=database,
client_settings=client_settings,
client=client,
collection_metadata=collection_metadata,
collection_configuration=collection_configuration,
**kwargs,
)