mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 13:36:15 +00:00
feat(chroma): Add Chroma Cloud support (#32125)
* Adding support for more Chroma client options (`HttpClient` and `CloudClient`). This includes adding arguments necessary for instantiating these clients. * Adding support for Chroma's new persisted collection configuration (we moved index configuration into this new construct). * Delegate `Settings` configuration to Chroma's client constructors.
This commit is contained in:
@@ -20,13 +20,15 @@ from typing import (
|
||||
import chromadb
|
||||
import chromadb.config
|
||||
import numpy as np
|
||||
from chromadb import Settings
|
||||
from chromadb.api import CreateCollectionConfiguration
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.utils import xor_args
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.api.types import ID, OneOrMany, Where, WhereDocument
|
||||
from chromadb.api.types import Where, WhereDocument
|
||||
|
||||
logger = logging.getLogger()
|
||||
DEFAULT_K = 4 # Number of Documents to return.
|
||||
@@ -167,6 +169,20 @@ class Chroma(VectorStore):
|
||||
Chroma client settings.
|
||||
persist_directory: Optional[str]
|
||||
Directory to persist the collection.
|
||||
host: Optional[str]
|
||||
Hostname of a deployed Chroma server.
|
||||
port: Optional[int]
|
||||
Connection port for a deployed Chroma server. Default is 8000.
|
||||
ssl: Optional[bool]
|
||||
Whether to establish an SSL connection with a deployed Chroma server. Default is False.
|
||||
headers: Optional[dict[str, str]]
|
||||
HTTP headers to send to a deployed Chroma server.
|
||||
chroma_cloud_api_key: Optional[str]
|
||||
Chroma Cloud API key.
|
||||
tenant: Optional[str]
|
||||
Tenant ID. Required for Chroma Cloud connections. Default is 'default_tenant' for local Chroma servers.
|
||||
database: Optional[str]
|
||||
Database name. Required for Chroma Cloud connections. Default is 'default_database'.
|
||||
|
||||
Instantiate:
|
||||
.. code-block:: python
|
||||
@@ -284,11 +300,20 @@ class Chroma(VectorStore):
|
||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||
embedding_function: Optional[Embeddings] = None,
|
||||
persist_directory: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
chroma_cloud_api_key: Optional[str] = None,
|
||||
tenant: Optional[str] = None,
|
||||
database: Optional[str] = None,
|
||||
client_settings: Optional[chromadb.config.Settings] = None,
|
||||
collection_metadata: Optional[dict] = None,
|
||||
collection_configuration: Optional[CreateCollectionConfiguration] = None,
|
||||
client: Optional[chromadb.ClientAPI] = None,
|
||||
relevance_score_fn: Optional[Callable[[float], float]] = None,
|
||||
create_collection_if_not_exists: Optional[bool] = True, # noqa: FBT001, FBT002
|
||||
*,
|
||||
ssl: bool = False,
|
||||
) -> None:
|
||||
"""Initialize with a Chroma client.
|
||||
|
||||
@@ -296,8 +321,20 @@ class Chroma(VectorStore):
|
||||
collection_name: Name of the collection to create.
|
||||
embedding_function: Embedding class object. Used to embed texts.
|
||||
persist_directory: Directory to persist the collection.
|
||||
host: Hostname of a deployed Chroma server.
|
||||
port: Connection port for a deployed Chroma server. Default is 8000.
|
||||
ssl: Whether to establish an SSL connection with a deployed Chroma server.
|
||||
Default is False.
|
||||
headers: HTTP headers to send to a deployed Chroma server.
|
||||
chroma_cloud_api_key: Chroma Cloud API key.
|
||||
tenant: Tenant ID. Required for Chroma Cloud connections.
|
||||
Default is 'default_tenant' for local Chroma servers.
|
||||
database: Database name. Required for Chroma Cloud connections.
|
||||
Default is 'default_database'.
|
||||
client_settings: Chroma client settings
|
||||
collection_metadata: Collection configurations.
|
||||
collection_configuration: Index configuration for the collection.
|
||||
Defaults to None.
|
||||
client: Chroma client. Documentation:
|
||||
https://docs.trychroma.com/reference/python/client
|
||||
relevance_score_fn: Function to calculate relevance score from distance.
|
||||
@@ -305,37 +342,73 @@ class Chroma(VectorStore):
|
||||
create_collection_if_not_exists: Whether to create collection
|
||||
if it doesn't exist. Defaults to True.
|
||||
"""
|
||||
if client is not None:
|
||||
self._client_settings = client_settings
|
||||
self._client = client
|
||||
self._persist_directory = persist_directory
|
||||
else:
|
||||
if client_settings:
|
||||
# If client_settings is provided with persist_directory specified,
|
||||
# then it is "in-memory and persisting to disk" mode.
|
||||
client_settings.persist_directory = (
|
||||
persist_directory or client_settings.persist_directory
|
||||
)
|
||||
client_settings.is_persistent = (
|
||||
client_settings.persist_directory is not None
|
||||
)
|
||||
_tenant = tenant or chromadb.DEFAULT_TENANT
|
||||
_database = database or chromadb.DEFAULT_DATABASE
|
||||
_settings = client_settings or Settings()
|
||||
|
||||
_client_settings = client_settings
|
||||
elif persist_directory:
|
||||
_client_settings = chromadb.config.Settings(is_persistent=True)
|
||||
_client_settings.persist_directory = persist_directory
|
||||
else:
|
||||
_client_settings = chromadb.config.Settings()
|
||||
self._client_settings = _client_settings
|
||||
self._client = chromadb.Client(_client_settings)
|
||||
self._persist_directory = (
|
||||
_client_settings.persist_directory or persist_directory
|
||||
client_args = {
|
||||
"persist_directory": persist_directory,
|
||||
"host": host,
|
||||
"chroma_cloud_api_key": chroma_cloud_api_key,
|
||||
}
|
||||
|
||||
if sum(arg is not None for arg in client_args.values()) > 1:
|
||||
provided = [
|
||||
name for name, value in client_args.items() if value is not None
|
||||
]
|
||||
msg = (
|
||||
f"Only one of 'persist_directory', 'host' and 'chroma_cloud_api_key' "
|
||||
f"is allowed, but got {','.join(provided)}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
if client is not None:
|
||||
self._client = client
|
||||
|
||||
# PersistentClient
|
||||
elif persist_directory is not None:
|
||||
self._client = chromadb.PersistentClient(
|
||||
path=persist_directory,
|
||||
settings=_settings,
|
||||
tenant=_tenant,
|
||||
database=_database,
|
||||
)
|
||||
|
||||
# HttpClient
|
||||
elif host is not None:
|
||||
_port = port or 8000
|
||||
self._client = chromadb.HttpClient(
|
||||
host=host,
|
||||
port=_port,
|
||||
ssl=ssl,
|
||||
headers=headers,
|
||||
settings=_settings,
|
||||
tenant=_tenant,
|
||||
database=_database,
|
||||
)
|
||||
|
||||
# CloudClient
|
||||
elif chroma_cloud_api_key is not None:
|
||||
if not tenant or not database:
|
||||
msg = (
|
||||
"Must provide tenant and database values to connect to Chroma Cloud"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
self._client = chromadb.CloudClient(
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
api_key=chroma_cloud_api_key,
|
||||
settings=_settings,
|
||||
)
|
||||
|
||||
else:
|
||||
self._client = chromadb.Client(settings=_settings)
|
||||
|
||||
self._embedding_function = embedding_function
|
||||
self._chroma_collection: Optional[chromadb.Collection] = None
|
||||
self._collection_name = collection_name
|
||||
self._collection_metadata = collection_metadata
|
||||
self._collection_configuration = collection_configuration
|
||||
if create_collection_if_not_exists:
|
||||
self.__ensure_collection()
|
||||
else:
|
||||
@@ -348,6 +421,7 @@ class Chroma(VectorStore):
|
||||
name=self._collection_name,
|
||||
embedding_function=None,
|
||||
metadata=self._collection_metadata,
|
||||
configuration=self._collection_configuration,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -405,7 +479,8 @@ class Chroma(VectorStore):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def encode_image(self, uri: str) -> str:
|
||||
@staticmethod
|
||||
def encode_image(uri: str) -> str:
|
||||
"""Get base64 string from image URI."""
|
||||
with open(uri, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
@@ -415,7 +490,6 @@ class Chroma(VectorStore):
|
||||
uris: list[str],
|
||||
metadatas: Optional[list[dict]] = None,
|
||||
ids: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> list[str]:
|
||||
"""Run more images through the embeddings and add to the vectorstore.
|
||||
|
||||
@@ -424,7 +498,6 @@ class Chroma(VectorStore):
|
||||
metadatas: Optional list of metadatas.
|
||||
When querying, you can filter on this metadata.
|
||||
ids: Optional list of IDs. (Items without IDs will be assigned UUIDs)
|
||||
kwargs: Additional keyword arguments to pass.
|
||||
|
||||
Returns:
|
||||
List of IDs of the added images.
|
||||
@@ -635,7 +708,7 @@ class Chroma(VectorStore):
|
||||
List of Documents most similar to the query vector.
|
||||
"""
|
||||
results = self.__query_collection(
|
||||
query_embeddings=embedding,
|
||||
query_embeddings=[embedding],
|
||||
n_results=k,
|
||||
where=filter,
|
||||
where_document=where_document,
|
||||
@@ -658,7 +731,7 @@ class Chroma(VectorStore):
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
filter: Filter by metadata. Defaults to None.
|
||||
where_document: dict used to filter by the documents.
|
||||
E.g. {"$contains": "hello"}}.
|
||||
E.g. {"$contains": "hello"}.
|
||||
kwargs: Additional keyword arguments to pass to Chroma collection query.
|
||||
|
||||
Returns:
|
||||
@@ -666,7 +739,7 @@ class Chroma(VectorStore):
|
||||
in float for each. Lower score represents more similarity.
|
||||
"""
|
||||
results = self.__query_collection(
|
||||
query_embeddings=embedding,
|
||||
query_embeddings=[embedding],
|
||||
n_results=k,
|
||||
where=filter,
|
||||
where_document=where_document,
|
||||
@@ -765,10 +838,10 @@ class Chroma(VectorStore):
|
||||
"""Select the relevance score function based on collections distance metric.
|
||||
|
||||
The most similar documents will have the lowest relevance score. Default
|
||||
relevance score function is euclidean distance. Distance metric must be
|
||||
provided in `collection_metadata` during initialization of Chroma object.
|
||||
Example: collection_metadata={"hnsw:space": "cosine"}. Available distance
|
||||
metrics are: 'cosine', 'l2' and 'ip'.
|
||||
relevance score function is Euclidean distance. Distance metric must be
|
||||
provided in `collection_configuration` during initialization of Chroma object.
|
||||
Example: collection_configuration={"hnsw": {"space": "cosine"}}.
|
||||
Available distance metrics are: 'cosine', 'l2' and 'ip'.
|
||||
|
||||
Returns:
|
||||
The relevance score function.
|
||||
@@ -779,12 +852,15 @@ class Chroma(VectorStore):
|
||||
if self.override_relevance_score_fn:
|
||||
return self.override_relevance_score_fn
|
||||
|
||||
distance = "l2"
|
||||
distance_key = "hnsw:space"
|
||||
metadata = self._collection.metadata
|
||||
hnsw_config = self._collection.configuration.get("hnsw")
|
||||
hnsw_distance: Optional[str] = hnsw_config.get("space") if hnsw_config else None
|
||||
|
||||
if metadata and distance_key in metadata:
|
||||
distance = metadata[distance_key]
|
||||
spann_config = self._collection.configuration.get("spann")
|
||||
spann_distance: Optional[str] = (
|
||||
spann_config.get("space") if spann_config else None
|
||||
)
|
||||
|
||||
distance = hnsw_distance or spann_distance
|
||||
|
||||
if distance == "cosine":
|
||||
return self._cosine_relevance_score_fn
|
||||
@@ -826,24 +902,22 @@ class Chroma(VectorStore):
|
||||
Raises:
|
||||
ValueError: If the embedding function does not support image embeddings.
|
||||
"""
|
||||
if self._embedding_function is None or not hasattr(
|
||||
self._embedding_function,
|
||||
"embed_image",
|
||||
if self._embedding_function is not None and hasattr(
|
||||
self._embedding_function, "embed_image"
|
||||
):
|
||||
msg = "The embedding function must support image embedding."
|
||||
raise ValueError(msg)
|
||||
# Obtain image embedding
|
||||
# Assuming embed_image returns a single embedding
|
||||
image_embedding = self._embedding_function.embed_image(uris=[uri])
|
||||
|
||||
# Obtain image embedding
|
||||
# Assuming embed_image returns a single embedding
|
||||
image_embedding = self._embedding_function.embed_image(uris=[uri])
|
||||
|
||||
# Perform similarity search based on the obtained embedding
|
||||
return self.similarity_search_by_vector(
|
||||
embedding=image_embedding,
|
||||
k=k,
|
||||
filter=filter,
|
||||
**kwargs,
|
||||
)
|
||||
# Perform similarity search based on the obtained embedding
|
||||
return self.similarity_search_by_vector(
|
||||
embedding=image_embedding,
|
||||
k=k,
|
||||
filter=filter,
|
||||
**kwargs,
|
||||
)
|
||||
msg = "The embedding function must support image embedding."
|
||||
raise ValueError(msg)
|
||||
|
||||
def similarity_search_by_image_with_relevance_score(
|
||||
self,
|
||||
@@ -870,24 +944,22 @@ class Chroma(VectorStore):
|
||||
Raises:
|
||||
ValueError: If the embedding function does not support image embeddings.
|
||||
"""
|
||||
if self._embedding_function is None or not hasattr(
|
||||
self._embedding_function,
|
||||
"embed_image",
|
||||
if self._embedding_function is not None and hasattr(
|
||||
self._embedding_function, "embed_image"
|
||||
):
|
||||
msg = "The embedding function must support image embedding."
|
||||
raise ValueError(msg)
|
||||
# Obtain image embedding
|
||||
# Assuming embed_image returns a single embedding
|
||||
image_embedding = self._embedding_function.embed_image(uris=[uri])
|
||||
|
||||
# Obtain image embedding
|
||||
# Assuming embed_image returns a single embedding
|
||||
image_embedding = self._embedding_function.embed_image(uris=[uri])
|
||||
|
||||
# Perform similarity search based on the obtained embedding
|
||||
return self.similarity_search_by_vector_with_relevance_scores(
|
||||
embedding=image_embedding,
|
||||
k=k,
|
||||
filter=filter,
|
||||
**kwargs,
|
||||
)
|
||||
# Perform similarity search based on the obtained embedding
|
||||
return self.similarity_search_by_vector_with_relevance_scores(
|
||||
embedding=image_embedding,
|
||||
k=k,
|
||||
filter=filter,
|
||||
**kwargs,
|
||||
)
|
||||
msg = "The embedding function must support image embedding."
|
||||
raise ValueError(msg)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
@@ -922,7 +994,7 @@ class Chroma(VectorStore):
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
results = self.__query_collection(
|
||||
query_embeddings=embedding,
|
||||
query_embeddings=[embedding],
|
||||
n_results=fetch_k,
|
||||
where=filter,
|
||||
where_document=where_document,
|
||||
@@ -1005,7 +1077,7 @@ class Chroma(VectorStore):
|
||||
|
||||
def get(
|
||||
self,
|
||||
ids: Optional[OneOrMany[ID]] = None,
|
||||
ids: Optional[Union[str, list[str]]] = None,
|
||||
where: Optional[Where] = None,
|
||||
limit: Optional[int] = None,
|
||||
offset: Optional[int] = None,
|
||||
@@ -1066,7 +1138,7 @@ class Chroma(VectorStore):
|
||||
Returns:
|
||||
List of Documents.
|
||||
|
||||
.. versionadded:: 0.2.1
|
||||
... versionadded:: 0.2.1
|
||||
"""
|
||||
results = self.get(ids=list(ids))
|
||||
return [
|
||||
@@ -1107,16 +1179,16 @@ class Chroma(VectorStore):
|
||||
embeddings = self._embedding_function.embed_documents(text)
|
||||
|
||||
if hasattr(
|
||||
self._collection._client, # noqa: SLF001
|
||||
self._client,
|
||||
"get_max_batch_size",
|
||||
) or hasattr( # for Chroma 0.5.1 and above
|
||||
self._collection._client, # noqa: SLF001
|
||||
self._client,
|
||||
"max_batch_size",
|
||||
): # for Chroma 0.4.10 and above
|
||||
from chromadb.utils.batch_utils import create_batches
|
||||
|
||||
for batch in create_batches(
|
||||
api=self._collection._client, # noqa: SLF001
|
||||
api=self._client,
|
||||
ids=ids,
|
||||
metadatas=metadata, # type: ignore[arg-type]
|
||||
documents=text,
|
||||
@@ -1145,9 +1217,18 @@ class Chroma(VectorStore):
|
||||
ids: Optional[list[str]] = None,
|
||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||
persist_directory: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
chroma_cloud_api_key: Optional[str] = None,
|
||||
tenant: Optional[str] = None,
|
||||
database: Optional[str] = None,
|
||||
client_settings: Optional[chromadb.config.Settings] = None,
|
||||
client: Optional[chromadb.ClientAPI] = None,
|
||||
collection_metadata: Optional[dict] = None,
|
||||
collection_configuration: Optional[CreateCollectionConfiguration] = None,
|
||||
*,
|
||||
ssl: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Chroma:
|
||||
"""Create a Chroma vectorstore from raw documents.
|
||||
@@ -1159,14 +1240,26 @@ class Chroma(VectorStore):
|
||||
texts: List of texts to add to the collection.
|
||||
collection_name: Name of the collection to create.
|
||||
persist_directory: Directory to persist the collection.
|
||||
host: Hostname of a deployed Chroma server.
|
||||
port: Connection port for a deployed Chroma server.
|
||||
Default is 8000.
|
||||
ssl: Whether to establish an SSL connection with a deployed Chroma server.
|
||||
Default is False.
|
||||
headers: HTTP headers to send to a deployed Chroma server.
|
||||
chroma_cloud_api_key: Chroma Cloud API key.
|
||||
tenant: Tenant ID. Required for Chroma Cloud connections.
|
||||
Default is 'default_tenant' for local Chroma servers.
|
||||
database: Database name. Required for Chroma Cloud connections.
|
||||
Default is 'default_database'.
|
||||
embedding: Embedding function. Defaults to None.
|
||||
metadatas: List of metadatas. Defaults to None.
|
||||
ids: List of document IDs. Defaults to None.
|
||||
client_settings: Chroma client settings.
|
||||
client: Chroma client. Documentation:
|
||||
https://docs.trychroma.com/reference/python/client
|
||||
collection_metadata: Collection configurations.
|
||||
Defaults to None.
|
||||
collection_metadata: Collection configurations. Defaults to None.
|
||||
collection_configuration: Index configuration for the collection.
|
||||
Defaults to None.
|
||||
kwargs: Additional keyword arguments to initialize a Chroma client.
|
||||
|
||||
Returns:
|
||||
@@ -1176,9 +1269,17 @@ class Chroma(VectorStore):
|
||||
collection_name=collection_name,
|
||||
embedding_function=embedding,
|
||||
persist_directory=persist_directory,
|
||||
host=host,
|
||||
port=port,
|
||||
ssl=ssl,
|
||||
headers=headers,
|
||||
chroma_cloud_api_key=chroma_cloud_api_key,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
client_settings=client_settings,
|
||||
client=client,
|
||||
collection_metadata=collection_metadata,
|
||||
collection_configuration=collection_configuration,
|
||||
**kwargs,
|
||||
)
|
||||
if ids is None:
|
||||
@@ -1217,9 +1318,18 @@ class Chroma(VectorStore):
|
||||
ids: Optional[list[str]] = None,
|
||||
collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
|
||||
persist_directory: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[int] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
chroma_cloud_api_key: Optional[str] = None,
|
||||
tenant: Optional[str] = None,
|
||||
database: Optional[str] = None,
|
||||
client_settings: Optional[chromadb.config.Settings] = None,
|
||||
client: Optional[chromadb.ClientAPI] = None, # Add this line
|
||||
collection_metadata: Optional[dict] = None,
|
||||
collection_configuration: Optional[CreateCollectionConfiguration] = None,
|
||||
*,
|
||||
ssl: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Chroma:
|
||||
"""Create a Chroma vectorstore from a list of documents.
|
||||
@@ -1230,14 +1340,25 @@ class Chroma(VectorStore):
|
||||
Args:
|
||||
collection_name: Name of the collection to create.
|
||||
persist_directory: Directory to persist the collection.
|
||||
host: Hostname of a deployed Chroma server.
|
||||
port: Connection port for a deployed Chroma server. Default is 8000.
|
||||
ssl: Whether to establish an SSL connection with a deployed Chroma server.
|
||||
Default is False.
|
||||
headers: HTTP headers to send to a deployed Chroma server.
|
||||
chroma_cloud_api_key: Chroma Cloud API key.
|
||||
tenant: Tenant ID. Required for Chroma Cloud connections.
|
||||
Default is 'default_tenant' for local Chroma servers.
|
||||
database: Database name. Required for Chroma Cloud connections.
|
||||
Default is 'default_database'.
|
||||
ids: List of document IDs. Defaults to None.
|
||||
documents: List of documents to add to the vectorstore.
|
||||
embedding: Embedding function. Defaults to None.
|
||||
client_settings: Chroma client settings.
|
||||
client: Chroma client. Documentation:
|
||||
https://docs.trychroma.com/reference/python/client
|
||||
collection_metadata: Collection configurations.
|
||||
Defaults to None.
|
||||
collection_metadata: Collection configurations. Defaults to None.
|
||||
collection_configuration: Index configuration for the collection.
|
||||
Defaults to None.
|
||||
kwargs: Additional keyword arguments to initialize a Chroma client.
|
||||
|
||||
Returns:
|
||||
@@ -1254,9 +1375,17 @@ class Chroma(VectorStore):
|
||||
ids=ids,
|
||||
collection_name=collection_name,
|
||||
persist_directory=persist_directory,
|
||||
host=host,
|
||||
port=port,
|
||||
ssl=ssl,
|
||||
headers=headers,
|
||||
chroma_cloud_api_key=chroma_cloud_api_key,
|
||||
tenant=tenant,
|
||||
database=database,
|
||||
client_settings=client_settings,
|
||||
client=client,
|
||||
collection_metadata=collection_metadata,
|
||||
collection_configuration=collection_configuration,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user