From 09769373b38c9594336f7cadc90766ddf036bd1d Mon Sep 17 00:00:00 2001 From: itaismith <42597089+itaismith@users.noreply.github.com> Date: Tue, 22 Jul 2025 12:14:15 -0700 Subject: [PATCH] feat(chroma): Add Chroma Cloud support (#32125) * Adding support for more Chroma client options (`HttpClient` and `CloudClient`). This includes adding arguments necessary for instantiating these clients. * Adding support for Chroma's new persisted collection configuration (we moved index configuration into this new construct). * Delegate `Settings` configuration to Chroma's client constructors. --- .../integrations/vectorstores/chroma.ipynb | 335 +++++++++++++----- .../chroma/langchain_chroma/vectorstores.py | 293 ++++++++++----- 2 files changed, 467 insertions(+), 161 deletions(-) diff --git a/docs/docs/integrations/vectorstores/chroma.ipynb b/docs/docs/integrations/vectorstores/chroma.ipynb index e77a47f788c..9a711bc2f29 100644 --- a/docs/docs/integrations/vectorstores/chroma.ipynb +++ b/docs/docs/integrations/vectorstores/chroma.ipynb @@ -11,6 +11,13 @@ "\n", ">[Chroma](https://docs.trychroma.com/getting-started) is a AI-native open-source vector database focused on developer productivity and happiness. Chroma is licensed under Apache 2.0. View the full docs of `Chroma` at [this page](https://docs.trychroma.com/reference/py-collection), and find the API reference for the LangChain integration at [this page](https://python.langchain.com/api_reference/chroma/vectorstores/langchain_chroma.vectorstores.Chroma.html).\n", "\n", + ":::info Chroma Cloud\n", + "\n", + "Chroma Cloud powers serverless vector and full-text search. It's extremely fast, cost-effective, scalable and painless. Create a DB and try it out in under 30 seconds with $5 of free credits.\n", + "\n", + "[Get started with Chroma Cloud](https://trychroma.com/signup)\n", + ":::\n", + "\n", "## Setup\n", "\n", "To access `Chroma` vector stores you'll need to install the `langchain-chroma` integration package." 
@@ -33,7 +40,15 @@ "source": [ "### Credentials\n", "\n", - "You can use the `Chroma` vector store without any credentials, simply installing the package above is enough!" + "You can use the `Chroma` vector store without any credentials, simply installing the package above is enough!\n", + "\n", + "If you are a [Chroma Cloud](https://trychroma.com/signup) user, set your `CHROMA_TENANT`, `CHROMA_DATABASE`, and `CHROMA_API_KEY` environment variables.\n", + "\n", + "When you install the `chromadb` package you also get access to the Chroma CLI, which can set these for you. First, [login](https://docs.trychroma.com/docs/cli/login) via the CLI, and then use the [`connect` command](https://docs.trychroma.com/docs/cli/db):\n", + "\n", + "```bash\n", + "chroma db connect [db_name] --env-file\n", + "```" ] }, { @@ -73,7 +88,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "id": "d3ed0a9a", "metadata": {}, "outputs": [], @@ -85,9 +100,19 @@ "embeddings = OpenAIEmbeddings(model=\"text-embedding-3-large\")" ] }, + { + "cell_type": "markdown", + "id": "c6a43e25-227c-4e89-909f-3654fe2710fc", + "metadata": {}, + "source": [ + "#### Running Locally (In-Memory)\n", + "\n", + "You can get a Chroma server running in memory by simply instantiating a `Chroma` instance with a collection name and your embeddings provider:" + ] + }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "id": "3ea11a7b", "metadata": {}, "outputs": [], @@ -97,7 +122,104 @@ "vector_store = Chroma(\n", " collection_name=\"example_collection\",\n", " embedding_function=embeddings,\n", - " persist_directory=\"./chroma_langchain_db\", # Where to save data locally, remove if not necessary\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "92d04cda-e8cc-48aa-9680-470304e3ff4c", + "metadata": {}, + "source": [ + "If you don't need data persistence, this is a great option for experimenting while building your AI application with Langchain." 
+ ] + }, + { + "cell_type": "markdown", + "id": "ad6adc53-4b3f-458e-8e2e-efcc3f99f0c5", + "metadata": {}, + "source": [ + "#### Running Locally (with Data Persistence)\n", + "\n", + "You can provide the `persist_directory` argument to save your data across multiple runs of your program:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a858e77-fd6d-44f0-840f-8f71eaeae6f7", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_chroma import Chroma\n", + "\n", + "vector_store = Chroma(\n", + " collection_name=\"example_collection\",\n", + " embedding_function=embeddings,\n", + " persist_directory=\"./chroma_langchain_db\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "47bf272e-af0b-450e-8a86-3e8292273cde", + "metadata": {}, + "source": [ + "#### Connecting to a Chroma Server\n", + "\n", + "If you have a Chroma server running locally, or you have [deployed](https://docs.trychroma.com/guides/deploy/client-server-mode) one yourself, you can connect to it by providing the `host` argument.\n", + "\n", + "For example, you can start a Chroma server running locally with `chroma run`, and then connect it with `host='localhost'`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "679d619f-b8ee-4abb-8ac0-77ec859ddff1", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_chroma import Chroma\n", + "\n", + "vector_store = Chroma(\n", + " collection_name=\"example_collection\",\n", + " embedding_function=embeddings,\n", + " host=\"localhost\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e3c06ed9-c010-4764-bd6e-2a0c71201d5b", + "metadata": {}, + "source": [ + "For other deployments you can use the `port`, `ssl`, and `headers` arguments to customize your connection." + ] + }, + { + "cell_type": "markdown", + "id": "0f3238e1-ca57-482d-878d-b09bd2c8015c", + "metadata": {}, + "source": [ + "#### Chroma Cloud\n", + "\n", + "Chroma Cloud users can also build with Langchain. 
Provide your `Chroma` instance with your Chroma Cloud API key, tenant, and DB name:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e080d2d2-c501-467e-9842-e2045d86cdb5", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_chroma import Chroma\n", + "\n", + "vector_store = Chroma(\n", + " collection_name=\"example_collection\",\n", + " embedding_function=embeddings,\n", + " chroma_cloud_api_key=os.getenv(\"CHROMA_API_KEY\"),\n", + " tenant=os.getenv(\"CHROMA_TENANT\"),\n", + " database=os.getenv(\"CHROMA_DATABASE\"),\n", ")" ] }, @@ -111,21 +233,132 @@ "You can also initialize from a `Chroma` client, which is particularly useful if you want easier access to the underlying database." ] }, + { + "cell_type": "markdown", + "id": "38e9f893-60df-4a4f-b570-2d1c463cc1e4", + "metadata": {}, + "source": [ + "#### Running Locally (In-Memory)" + ] + }, { "cell_type": "code", - "execution_count": 3, - "id": "3fe4457f", + "execution_count": null, + "id": "09bfb62f-7c6b-43d3-a69a-0601899c6942", "metadata": {}, "outputs": [], "source": [ "import chromadb\n", "\n", - "persistent_client = chromadb.PersistentClient()\n", - "collection = persistent_client.get_or_create_collection(\"collection_name\")\n", - "collection.add(ids=[\"1\", \"2\", \"3\"], documents=[\"a\", \"b\", \"c\"])\n", + "client = chromadb.Client()" + ] + }, + { + "cell_type": "markdown", + "id": "f3eac2de-0cca-4d57-b67d-04cc78bb59c1", + "metadata": {}, + "source": [ + "#### Running Locally (with Data Persistence)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "ffc7f2ad-0d6c-4911-a4cf-a82bf7649478", + "metadata": {}, + "outputs": [], + "source": [ + "import chromadb\n", "\n", + "client = chromadb.PersistentClient(path=\"./chroma_langchain_db\")" + ] + }, + { + "cell_type": "markdown", + "id": "41cc98d5-94f3-4a2f-903e-61c4a38d8f9c", + "metadata": {}, + "source": [ + "#### Connecting to a Chroma Server\n", + "\n", + "For example, if you are running a Chroma 
server locally (using `chroma run`):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb5828e3-c0a5-4f97-8d2e-23d82257743e", + "metadata": {}, + "outputs": [], + "source": [ + "import chromadb\n", + "\n", + "client = chromadb.HttpClient(host=\"localhost\", port=8000, ssl=False)" + ] + }, + { + "cell_type": "markdown", + "id": "254ecfdb-f247-4a3d-a52a-e515b17b7ba2", + "metadata": {}, + "source": [ + "#### Chroma Cloud" + ] + }, + { + "cell_type": "markdown", + "id": "fbbf8042-7ae7-4221-96e3-dc2048dd0f45", + "metadata": {}, + "source": [ + "After setting your `CHROMA_API_KEY`, `CHROMA_TENANT`, and `CHROMA_DATABASE`, you can simply instantiate:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89e86a01-a347-4041-a4a1-01eecd299235", + "metadata": {}, + "outputs": [], + "source": [ + "import chromadb\n", + "\n", + "client = chromadb.CloudClient()" + ] + }, + { + "cell_type": "markdown", + "id": "8fdd8bbb-45ab-43d8-bdc1-7220b14cfc52", + "metadata": {}, + "source": [ + "#### Access your Chroma DB" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6da21a1a-8d0d-4a4b-bac5-008839e89540", + "metadata": {}, + "outputs": [], + "source": [ + "collection = client.get_or_create_collection(\"collection_name\")\n", + "collection.add(ids=[\"1\", \"2\", \"3\"], documents=[\"a\", \"b\", \"c\"])" + ] + }, + { + "cell_type": "markdown", + "id": "581906ba-8082-450c-a3c4-19284539980b", + "metadata": {}, + "source": [ + "#### Create a Chroma Vectorstore" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3fe4457f", + "metadata": {}, + "outputs": [], + "source": [ "vector_store_from_client = Chroma(\n", - " client=persistent_client,\n", + " client=client,\n", " collection_name=\"collection_name\",\n", " embedding_function=embeddings,\n", ")" @@ -147,30 +380,10 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "id": "da279339", "metadata": {}, - "outputs": [ - { - "data": { 
- "text/plain": [ - "['f22ed484-6db3-4b76-adb1-18a777426cd6',\n", - " 'e0d5bab4-6453-4511-9a37-023d9d288faa',\n", - " '877d76b8-3580-4d9e-a13f-eed0fa3d134a',\n", - " '26eaccab-81ce-4c0a-8e76-bf542647df18',\n", - " 'bcaa8239-7986-4050-bf40-e14fb7dab997',\n", - " 'cdc44b38-a83f-4e49-b249-7765b334e09d',\n", - " 'a7a35354-2687-4bc2-8242-3849a4d18d34',\n", - " '8780caf1-d946-4f27-a707-67d037e9e1d8',\n", - " 'dec6af2a-7326-408f-893d-7d7d717dfda9',\n", - " '3b18e210-bb59-47a0-8e17-c8e51176ea5e']" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "from uuid import uuid4\n", "\n", @@ -265,7 +478,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "ef5dbd1e", "metadata": {}, "outputs": [], @@ -301,7 +514,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "56f17791", "metadata": {}, "outputs": [], @@ -327,19 +540,10 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "e2b96fcf", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "* Building an exciting new project with LangChain - come check it out! [{'source': 'tweet'}]\n", - "* LangGraph is the best framework for building stateful, agentic applications! [{'source': 'tweet'}]\n" - ] - } - ], + "outputs": [], "source": [ "results = vector_store.similarity_search(\n", " \"LangChain provides abstractions to make working with LLMs easy\",\n", @@ -362,18 +566,10 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "2768a331", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "* [SIM=1.726390] The stock market is down 500 points today due to fears of a recession. 
[{'source': 'news'}]\n" - ] - } - ], + "outputs": [], "source": [ "results = vector_store.similarity_search_with_score(\n", " \"Will it be hot tomorrow?\", k=1, filter={\"source\": \"news\"}\n", @@ -394,18 +590,10 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "8ea434a5", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "* I had chocolate chip pancakes and fried eggs for breakfast this morning. [{'source': 'tweet'}]\n" - ] - } - ], + "outputs": [], "source": [ "results = vector_store.similarity_search_by_vector(\n", " embedding=embeddings.embed_query(\"I love green eggs and ham!\"), k=1\n", @@ -430,21 +618,10 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "7b6f7867", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[Document(metadata={'source': 'news'}, page_content='Robbers broke into the city bank and stole $1 million in cash.')]" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "retriever = vector_store.as_retriever(\n", " search_type=\"mmr\", search_kwargs={\"k\": 1, \"fetch_k\": 5}\n", @@ -493,7 +670,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.12.0" } }, "nbformat": 4, diff --git a/libs/partners/chroma/langchain_chroma/vectorstores.py b/libs/partners/chroma/langchain_chroma/vectorstores.py index 419159b8f31..693f4a10cd3 100644 --- a/libs/partners/chroma/langchain_chroma/vectorstores.py +++ b/libs/partners/chroma/langchain_chroma/vectorstores.py @@ -20,13 +20,15 @@ from typing import ( import chromadb import chromadb.config import numpy as np +from chromadb import Settings +from chromadb.api import CreateCollectionConfiguration from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.utils import xor_args from 
langchain_core.vectorstores import VectorStore if TYPE_CHECKING: - from chromadb.api.types import ID, OneOrMany, Where, WhereDocument + from chromadb.api.types import Where, WhereDocument logger = logging.getLogger() DEFAULT_K = 4 # Number of Documents to return. @@ -167,6 +169,20 @@ class Chroma(VectorStore): Chroma client settings. persist_directory: Optional[str] Directory to persist the collection. + host: Optional[str] + Hostname of a deployed Chroma server. + port: Optional[int] + Connection port for a deployed Chroma server. Default is 8000. + ssl: Optional[bool] + Whether to establish an SSL connection with a deployed Chroma server. Default is False. + headers: Optional[dict[str, str]] + HTTP headers to send to a deployed Chroma server. + chroma_cloud_api_key: Optional[str] + Chroma Cloud API key. + tenant: Optional[str] + Tenant ID. Required for Chroma Cloud connections. Default is 'default_tenant' for local Chroma servers. + database: Optional[str] + Database name. Required for Chroma Cloud connections. Default is 'default_database'. Instantiate: .. code-block:: python @@ -284,11 +300,20 @@ class Chroma(VectorStore): collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, embedding_function: Optional[Embeddings] = None, persist_directory: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + headers: Optional[dict[str, str]] = None, + chroma_cloud_api_key: Optional[str] = None, + tenant: Optional[str] = None, + database: Optional[str] = None, client_settings: Optional[chromadb.config.Settings] = None, collection_metadata: Optional[dict] = None, + collection_configuration: Optional[CreateCollectionConfiguration] = None, client: Optional[chromadb.ClientAPI] = None, relevance_score_fn: Optional[Callable[[float], float]] = None, create_collection_if_not_exists: Optional[bool] = True, # noqa: FBT001, FBT002 + *, + ssl: bool = False, ) -> None: """Initialize with a Chroma client. 
@@ -296,8 +321,20 @@ class Chroma(VectorStore): collection_name: Name of the collection to create. embedding_function: Embedding class object. Used to embed texts. persist_directory: Directory to persist the collection. + host: Hostname of a deployed Chroma server. + port: Connection port for a deployed Chroma server. Default is 8000. + ssl: Whether to establish an SSL connection with a deployed Chroma server. + Default is False. + headers: HTTP headers to send to a deployed Chroma server. + chroma_cloud_api_key: Chroma Cloud API key. + tenant: Tenant ID. Required for Chroma Cloud connections. + Default is 'default_tenant' for local Chroma servers. + database: Database name. Required for Chroma Cloud connections. + Default is 'default_database'. client_settings: Chroma client settings collection_metadata: Collection configurations. + collection_configuration: Index configuration for the collection. + Defaults to None. client: Chroma client. Documentation: https://docs.trychroma.com/reference/python/client relevance_score_fn: Function to calculate relevance score from distance. @@ -305,37 +342,73 @@ class Chroma(VectorStore): create_collection_if_not_exists: Whether to create collection if it doesn't exist. Defaults to True. """ - if client is not None: - self._client_settings = client_settings - self._client = client - self._persist_directory = persist_directory - else: - if client_settings: - # If client_settings is provided with persist_directory specified, - # then it is "in-memory and persisting to disk" mode. 
- client_settings.persist_directory = ( - persist_directory or client_settings.persist_directory - ) - client_settings.is_persistent = ( - client_settings.persist_directory is not None - ) + _tenant = tenant or chromadb.DEFAULT_TENANT + _database = database or chromadb.DEFAULT_DATABASE + _settings = client_settings or Settings() - _client_settings = client_settings - elif persist_directory: - _client_settings = chromadb.config.Settings(is_persistent=True) - _client_settings.persist_directory = persist_directory - else: - _client_settings = chromadb.config.Settings() - self._client_settings = _client_settings - self._client = chromadb.Client(_client_settings) - self._persist_directory = ( - _client_settings.persist_directory or persist_directory + client_args = { + "persist_directory": persist_directory, + "host": host, + "chroma_cloud_api_key": chroma_cloud_api_key, + } + + if sum(arg is not None for arg in client_args.values()) > 1: + provided = [ + name for name, value in client_args.items() if value is not None + ] + msg = ( + f"Only one of 'persist_directory', 'host' and 'chroma_cloud_api_key' " + f"is allowed, but got {','.join(provided)}" ) + raise ValueError(msg) + + if client is not None: + self._client = client + + # PersistentClient + elif persist_directory is not None: + self._client = chromadb.PersistentClient( + path=persist_directory, + settings=_settings, + tenant=_tenant, + database=_database, + ) + + # HttpClient + elif host is not None: + _port = port or 8000 + self._client = chromadb.HttpClient( + host=host, + port=_port, + ssl=ssl, + headers=headers, + settings=_settings, + tenant=_tenant, + database=_database, + ) + + # CloudClient + elif chroma_cloud_api_key is not None: + if not tenant or not database: + msg = ( + "Must provide tenant and database values to connect to Chroma Cloud" + ) + raise ValueError(msg) + self._client = chromadb.CloudClient( + tenant=tenant, + database=database, + api_key=chroma_cloud_api_key, + settings=_settings, + ) 
+ + else: + self._client = chromadb.Client(settings=_settings) self._embedding_function = embedding_function self._chroma_collection: Optional[chromadb.Collection] = None self._collection_name = collection_name self._collection_metadata = collection_metadata + self._collection_configuration = collection_configuration if create_collection_if_not_exists: self.__ensure_collection() else: @@ -348,6 +421,7 @@ class Chroma(VectorStore): name=self._collection_name, embedding_function=None, metadata=self._collection_metadata, + configuration=self._collection_configuration, ) @property @@ -405,7 +479,8 @@ class Chroma(VectorStore): **kwargs, ) - def encode_image(self, uri: str) -> str: + @staticmethod + def encode_image(uri: str) -> str: """Get base64 string from image URI.""" with open(uri, "rb") as image_file: return base64.b64encode(image_file.read()).decode("utf-8") @@ -415,7 +490,6 @@ class Chroma(VectorStore): uris: list[str], metadatas: Optional[list[dict]] = None, ids: Optional[list[str]] = None, - **kwargs: Any, ) -> list[str]: """Run more images through the embeddings and add to the vectorstore. @@ -424,7 +498,6 @@ class Chroma(VectorStore): metadatas: Optional list of metadatas. When querying, you can filter on this metadata. ids: Optional list of IDs. (Items without IDs will be assigned UUIDs) - kwargs: Additional keyword arguments to pass. Returns: List of IDs of the added images. @@ -635,7 +708,7 @@ class Chroma(VectorStore): List of Documents most similar to the query vector. """ results = self.__query_collection( - query_embeddings=embedding, + query_embeddings=[embedding], n_results=k, where=filter, where_document=where_document, @@ -658,7 +731,7 @@ class Chroma(VectorStore): k: Number of Documents to return. Defaults to 4. filter: Filter by metadata. Defaults to None. where_document: dict used to filter by the documents. - E.g. {"$contains": "hello"}}. + E.g. {"$contains": "hello"}. kwargs: Additional keyword arguments to pass to Chroma collection query. 
Returns: @@ -666,7 +739,7 @@ class Chroma(VectorStore): in float for each. Lower score represents more similarity. """ results = self.__query_collection( - query_embeddings=embedding, + query_embeddings=[embedding], n_results=k, where=filter, where_document=where_document, @@ -765,10 +838,10 @@ class Chroma(VectorStore): """Select the relevance score function based on collections distance metric. The most similar documents will have the lowest relevance score. Default - relevance score function is euclidean distance. Distance metric must be - provided in `collection_metadata` during initialization of Chroma object. - Example: collection_metadata={"hnsw:space": "cosine"}. Available distance - metrics are: 'cosine', 'l2' and 'ip'. + relevance score function is Euclidean distance. Distance metric must be + provided in `collection_configuration` during initialization of Chroma object. + Example: collection_configuration={"hnsw": {"space": "cosine"}}. + Available distance metrics are: 'cosine', 'l2' and 'ip'. Returns: The relevance score function. @@ -779,12 +852,15 @@ class Chroma(VectorStore): if self.override_relevance_score_fn: return self.override_relevance_score_fn - distance = "l2" - distance_key = "hnsw:space" - metadata = self._collection.metadata + hnsw_config = self._collection.configuration.get("hnsw") + hnsw_distance: Optional[str] = hnsw_config.get("space") if hnsw_config else None - if metadata and distance_key in metadata: - distance = metadata[distance_key] + spann_config = self._collection.configuration.get("spann") + spann_distance: Optional[str] = ( + spann_config.get("space") if spann_config else None + ) + + distance = hnsw_distance or spann_distance if distance == "cosine": return self._cosine_relevance_score_fn @@ -826,24 +902,22 @@ class Chroma(VectorStore): Raises: ValueError: If the embedding function does not support image embeddings. 
""" - if self._embedding_function is None or not hasattr( - self._embedding_function, - "embed_image", + if self._embedding_function is not None and hasattr( + self._embedding_function, "embed_image" ): - msg = "The embedding function must support image embedding." - raise ValueError(msg) + # Obtain image embedding + # Assuming embed_image returns a single embedding + image_embedding = self._embedding_function.embed_image(uris=[uri]) - # Obtain image embedding - # Assuming embed_image returns a single embedding - image_embedding = self._embedding_function.embed_image(uris=[uri]) - - # Perform similarity search based on the obtained embedding - return self.similarity_search_by_vector( - embedding=image_embedding, - k=k, - filter=filter, - **kwargs, - ) + # Perform similarity search based on the obtained embedding + return self.similarity_search_by_vector( + embedding=image_embedding, + k=k, + filter=filter, + **kwargs, + ) + msg = "The embedding function must support image embedding." + raise ValueError(msg) def similarity_search_by_image_with_relevance_score( self, @@ -870,24 +944,22 @@ class Chroma(VectorStore): Raises: ValueError: If the embedding function does not support image embeddings. """ - if self._embedding_function is None or not hasattr( - self._embedding_function, - "embed_image", + if self._embedding_function is not None and hasattr( + self._embedding_function, "embed_image" ): - msg = "The embedding function must support image embedding." 
- raise ValueError(msg) + # Obtain image embedding + # Assuming embed_image returns a single embedding + image_embedding = self._embedding_function.embed_image(uris=[uri]) - # Obtain image embedding - # Assuming embed_image returns a single embedding - image_embedding = self._embedding_function.embed_image(uris=[uri]) - - # Perform similarity search based on the obtained embedding - return self.similarity_search_by_vector_with_relevance_scores( - embedding=image_embedding, - k=k, - filter=filter, - **kwargs, - ) + # Perform similarity search based on the obtained embedding + return self.similarity_search_by_vector_with_relevance_scores( + embedding=image_embedding, + k=k, + filter=filter, + **kwargs, + ) + msg = "The embedding function must support image embedding." + raise ValueError(msg) def max_marginal_relevance_search_by_vector( self, @@ -922,7 +994,7 @@ class Chroma(VectorStore): List of Documents selected by maximal marginal relevance. """ results = self.__query_collection( - query_embeddings=embedding, + query_embeddings=[embedding], n_results=fetch_k, where=filter, where_document=where_document, @@ -1005,7 +1077,7 @@ class Chroma(VectorStore): def get( self, - ids: Optional[OneOrMany[ID]] = None, + ids: Optional[Union[str, list[str]]] = None, where: Optional[Where] = None, limit: Optional[int] = None, offset: Optional[int] = None, @@ -1066,7 +1138,7 @@ class Chroma(VectorStore): Returns: List of Documents. - .. versionadded:: 0.2.1 + .. 
versionadded:: 0.2.1 """ results = self.get(ids=list(ids)) return [ @@ -1107,16 +1179,16 @@ class Chroma(VectorStore): embeddings = self._embedding_function.embed_documents(text) if hasattr( - self._collection._client, # noqa: SLF001 + self._client, "get_max_batch_size", ) or hasattr( # for Chroma 0.5.1 and above - self._collection._client, # noqa: SLF001 + self._client, "max_batch_size", ): # for Chroma 0.4.10 and above from chromadb.utils.batch_utils import create_batches for batch in create_batches( - api=self._collection._client, # noqa: SLF001 + api=self._client, ids=ids, metadatas=metadata, # type: ignore[arg-type] documents=text, @@ -1145,9 +1217,18 @@ class Chroma(VectorStore): ids: Optional[list[str]] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, persist_directory: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + headers: Optional[dict[str, str]] = None, + chroma_cloud_api_key: Optional[str] = None, + tenant: Optional[str] = None, + database: Optional[str] = None, client_settings: Optional[chromadb.config.Settings] = None, client: Optional[chromadb.ClientAPI] = None, collection_metadata: Optional[dict] = None, + collection_configuration: Optional[CreateCollectionConfiguration] = None, + *, + ssl: bool = False, **kwargs: Any, ) -> Chroma: """Create a Chroma vectorstore from a raw documents. @@ -1159,14 +1240,26 @@ class Chroma(VectorStore): texts: List of texts to add to the collection. collection_name: Name of the collection to create. persist_directory: Directory to persist the collection. + host: Hostname of a deployed Chroma server. + port: Connection port for a deployed Chroma server. + Default is 8000. + ssl: Whether to establish an SSL connection with a deployed Chroma server. + Default is False. + headers: HTTP headers to send to a deployed Chroma server. + chroma_cloud_api_key: Chroma Cloud API key. + tenant: Tenant ID. Required for Chroma Cloud connections. 
+ Default is 'default_tenant' for local Chroma servers. + database: Database name. Required for Chroma Cloud connections. + Default is 'default_database'. embedding: Embedding function. Defaults to None. metadatas: List of metadatas. Defaults to None. ids: List of document IDs. Defaults to None. client_settings: Chroma client settings. client: Chroma client. Documentation: https://docs.trychroma.com/reference/python/client - collection_metadata: Collection configurations. - Defaults to None. + collection_metadata: Collection configurations. Defaults to None. + collection_configuration: Index configuration for the collection. + Defaults to None. kwargs: Additional keyword arguments to initialize a Chroma client. Returns: @@ -1176,9 +1269,17 @@ class Chroma(VectorStore): collection_name=collection_name, embedding_function=embedding, persist_directory=persist_directory, + host=host, + port=port, + ssl=ssl, + headers=headers, + chroma_cloud_api_key=chroma_cloud_api_key, + tenant=tenant, + database=database, client_settings=client_settings, client=client, collection_metadata=collection_metadata, + collection_configuration=collection_configuration, **kwargs, ) if ids is None: @@ -1217,9 +1318,18 @@ class Chroma(VectorStore): ids: Optional[list[str]] = None, collection_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME, persist_directory: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + headers: Optional[dict[str, str]] = None, + chroma_cloud_api_key: Optional[str] = None, + tenant: Optional[str] = None, + database: Optional[str] = None, client_settings: Optional[chromadb.config.Settings] = None, client: Optional[chromadb.ClientAPI] = None, # Add this line collection_metadata: Optional[dict] = None, + collection_configuration: Optional[CreateCollectionConfiguration] = None, + *, + ssl: bool = False, **kwargs: Any, ) -> Chroma: """Create a Chroma vectorstore from a list of documents. 
@@ -1230,14 +1340,25 @@ class Chroma(VectorStore): Args: collection_name: Name of the collection to create. persist_directory: Directory to persist the collection. + host: Hostname of a deployed Chroma server. + port: Connection port for a deployed Chroma server. Default is 8000. + ssl: Whether to establish an SSL connection with a deployed Chroma server. + Default is False. + headers: HTTP headers to send to a deployed Chroma server. + chroma_cloud_api_key: Chroma Cloud API key. + tenant: Tenant ID. Required for Chroma Cloud connections. + Default is 'default_tenant' for local Chroma servers. + database: Database name. Required for Chroma Cloud connections. + Default is 'default_database'. ids : List of document IDs. Defaults to None. documents: List of documents to add to the vectorstore. embedding: Embedding function. Defaults to None. client_settings: Chroma client settings. client: Chroma client. Documentation: https://docs.trychroma.com/reference/python/client - collection_metadata: Collection configurations. - Defaults to None. + collection_metadata: Collection configurations. Defaults to None. + collection_configuration: Index configuration for the collection. + Defaults to None. kwargs: Additional keyword arguments to initialize a Chroma client. Returns: @@ -1254,9 +1375,17 @@ class Chroma(VectorStore): ids=ids, collection_name=collection_name, persist_directory=persist_directory, + host=host, + port=port, + ssl=ssl, + headers=headers, + chroma_cloud_api_key=chroma_cloud_api_key, + tenant=tenant, + database=database, client_settings=client_settings, client=client, collection_metadata=collection_metadata, + collection_configuration=collection_configuration, **kwargs, )