diff --git a/docs/docs/modules/data_connection/indexing.ipynb b/docs/docs/modules/data_connection/indexing.ipynb index 233396da408..5e6e2645d1f 100644 --- a/docs/docs/modules/data_connection/indexing.ipynb +++ b/docs/docs/modules/data_connection/indexing.ipynb @@ -60,7 +60,7 @@ " * document addition by id (`add_documents` method with `ids` argument)\n", " * delete by id (`delete` method with `ids` argument)\n", "\n", - "Compatible Vectorstores: `AnalyticDB`, `AstraDB`, `AwaDB`, `Bagel`, `Cassandra`, `Chroma`, `DashVector`, `DeepLake`, `Dingo`, `ElasticVectorSearch`, `ElasticsearchStore`, `FAISS`, `MyScale`, `PGVector`, `Pinecone`, `Qdrant`, `Redis`, `ScaNN`, `SupabaseVectorStore`, `TimescaleVector`, `Vald`, `Vearch`, `VespaStore`, `Weaviate`, `ZepVectorStore`.\n", + "Compatible Vectorstores: `AnalyticDB`, `AstraDB`, `AwaDB`, `Bagel`, `Cassandra`, `Chroma`, `DashVector`, `DatabricksVectorSearch`, `DeepLake`, `Dingo`, `ElasticVectorSearch`, `ElasticsearchStore`, `FAISS`, `MyScale`, `PGVector`, `Pinecone`, `Qdrant`, `Redis`, `ScaNN`, `SupabaseVectorStore`, `TimescaleVector`, `Vald`, `Vearch`, `VespaStore`, `Weaviate`, `ZepVectorStore`.\n", " \n", "## Caution\n", "\n", diff --git a/libs/langchain/langchain/vectorstores/__init__.py b/libs/langchain/langchain/vectorstores/__init__.py index 863c7024985..6c79ccfdf40 100644 --- a/libs/langchain/langchain/vectorstores/__init__.py +++ b/libs/langchain/langchain/vectorstores/__init__.py @@ -140,6 +140,12 @@ def _import_dashvector() -> Any: return DashVector +def _import_databricks_vector_search() -> Any: + from langchain.vectorstores.databricks_vector_search import DatabricksVectorSearch + + return DatabricksVectorSearch + + def _import_deeplake() -> Any: from langchain.vectorstores.deeplake import DeepLake @@ -461,6 +467,8 @@ def __getattr__(name: str) -> Any: return _import_clickhouse() elif name == "DashVector": return _import_dashvector() + elif name == "DatabricksVectorSearch": + return _import_databricks_vector_search() elif name == "DeepLake": return _import_deeplake() elif name == "Dingo": @@ -575,6 +583,7 @@ __all__ = [ "Clickhouse", "ClickhouseSettings", "DashVector", + "DatabricksVectorSearch", "DeepLake", "Dingo", "DocArrayHnswSearch", diff --git a/libs/langchain/langchain/vectorstores/databricks_vector_search.py b/libs/langchain/langchain/vectorstores/databricks_vector_search.py new file mode 100644 index 00000000000..48ddba4c92d --- /dev/null +++ b/libs/langchain/langchain/vectorstores/databricks_vector_search.py @@ -0,0 +1,473 @@ +from __future__ import annotations + +import json +import logging +import uuid +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Type + +from langchain_core.documents import Document +from langchain_core.embeddings import Embeddings +from langchain_core.vectorstores import VST, VectorStore + +if TYPE_CHECKING: + from databricks.vector_search.client import VectorSearchIndex + +logger = logging.getLogger(__name__) + + +class DatabricksVectorSearch(VectorStore): + """`Databricks Vector Search` vector store. + + To use, you should have the ``databricks-vectorsearch`` python package installed. + + Example: + .. 
code-block:: python
+
+            from langchain.vectorstores import DatabricksVectorSearch
+            from databricks.vector_search.client import VectorSearchClient
+
+            vs_client = VectorSearchClient()
+            vs_index = vs_client.get_index(
+                endpoint_name="vs_endpoint",
+                index_name="ml.llm.index"
+            )
+            vectorstore = DatabricksVectorSearch(vs_index)
+
+    Args:
+        index: A Databricks Vector Search index object.
+        embedding: The embedding model.
+            Required for direct-access index or delta-sync index
+            with self-managed embeddings.
+        text_column: The name of the text column to use for the embeddings.
+            Required for direct-access index or delta-sync index
+            with self-managed embeddings.
+            Make sure the text column specified is in the index.
+        columns: The list of column names to get when doing the search.
+            Defaults to ``[primary_key, text_column]``.
+
+    A delta-sync index with Databricks-managed embeddings manages ingestion,
+    deletion, and embedding for you.
+    Manual ingestion/deletion of documents/texts is not supported for a
+    delta-sync index.
+
+    If you want to use a delta-sync index with self-managed embeddings, you need to
+    provide the embedding model and the text column name to use for the embeddings.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.vectorstores import DatabricksVectorSearch
+            from databricks.vector_search.client import VectorSearchClient
+            from langchain.embeddings.openai import OpenAIEmbeddings
+
+            vs_client = VectorSearchClient()
+            vs_index = vs_client.get_index(
+                endpoint_name="vs_endpoint",
+                index_name="ml.llm.index"
+            )
+            vectorstore = DatabricksVectorSearch(
+                index=vs_index,
+                embedding=OpenAIEmbeddings(),
+                text_column="document_content"
+            )
+
+    If you want to manage document ingestion/deletion yourself, you can use a
+    direct-access index.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.vectorstores import DatabricksVectorSearch
+            from databricks.vector_search.client import VectorSearchClient
+            from langchain.embeddings.openai import OpenAIEmbeddings
+
+            vs_client = VectorSearchClient()
+            vs_index = vs_client.get_index(
+                endpoint_name="vs_endpoint",
+                index_name="ml.llm.index"
+            )
+            vectorstore = DatabricksVectorSearch(
+                index=vs_index,
+                embedding=OpenAIEmbeddings(),
+                text_column="document_content"
+            )
+            vectorstore.add_texts(
+                texts=["text1", "text2"]
+            )
+
+    For more information on Databricks Vector Search, see `Databricks Vector Search
+    documentation <https://docs.databricks.com/en/generative-ai/vector-search.html>`.
+
+    """
+
+    def __init__(
+        self,
+        index: VectorSearchIndex,
+        *,
+        embedding: Optional[Embeddings] = None,
+        text_column: Optional[str] = None,
+        columns: Optional[List[str]] = None,
+    ):
+        try:
+            from databricks.vector_search.client import VectorSearchIndex
+        except ImportError as e:
+            raise ImportError(
+                "Could not import databricks-vectorsearch python package. "
+                "Please install it with `pip install databricks-vectorsearch`."
+            ) from e
+        # index
+        self.index = index
+        if not isinstance(index, VectorSearchIndex):
+            raise TypeError("index must be of type VectorSearchIndex.")
+
+        # index_details
+        index_details = self.index.describe()
+        self.primary_key = index_details["primary_key"]
+        self.index_type = index_details.get("index_type")
+        self._delta_sync_index_spec = index_details.get(
+            "delta_sync_index_spec", dict()
+        )
+        self._direct_access_index_spec = index_details.get(
+            "direct_access_index_spec", dict()
+        )
+
+        # text_column
+        if self._is_databricks_managed_embeddings():
+            index_source_column = self._embedding_source_column_name()
+            # check if input text column matches the source column of the index
+            if text_column is not None and text_column != index_source_column:
+                raise ValueError(
+                    f"text_column '{text_column}' does not match with the "
+                    f"source column of the index: '{index_source_column}'."
+                )
+            self.text_column = index_source_column
+        else:
+            self._require_arg(text_column, "text_column")
+            self.text_column = text_column
+
+        # columns
+        self.columns = columns or []
+        # add primary key column and source column if not in columns
+        if self.primary_key not in self.columns:
+            self.columns.append(self.primary_key)
+        if self.text_column and self.text_column not in self.columns:
+            self.columns.append(self.text_column)
+
+        # Validate specified columns are in the index
+        if self._is_direct_access_index():
+            index_schema = self._index_schema()
+            if index_schema:
+                for col in self.columns:
+                    if col not in index_schema:
+                        raise ValueError(
+                            f"column '{col}' is not in the index's schema."
+                        )
+
+        # embedding model
+        if not self._is_databricks_managed_embeddings():
+            # embedding model is required for direct-access index
+            # or delta-sync index with self-managed embedding
+            self._require_arg(embedding, "embedding")
+            self._embedding = embedding
+            # validate dimension matches
+            index_embedding_dimension = self._embedding_vector_column_dimension()
+            if index_embedding_dimension is not None:
+                inferred_embedding_dimension = self._infer_embedding_dimension()
+                if inferred_embedding_dimension != index_embedding_dimension:
+                    raise ValueError(
+                        f"embedding model's dimension '{inferred_embedding_dimension}' "
+                        f"does not match with the index's dimension "
+                        f"'{index_embedding_dimension}'."
+                    )
+        else:
+            if embedding is not None:
+                logger.warning(
+                    "embedding model is not used in delta-sync index with "
+                    "Databricks-managed embeddings."
+                )
+            self._embedding = None
+
+    @classmethod
+    def from_texts(
+        cls: Type[VST],
+        texts: List[str],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        **kwargs: Any,
+    ) -> VST:
+        raise NotImplementedError(
+            "`from_texts` is not supported. "
+            "Use `add_texts` to add to existing direct-access index."
+        )
+
+    def add_texts(
+        self,
+        texts: Iterable[str],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[Any]] = None,
+        **kwargs: Any,
+    ) -> List[str]:
+        """Add texts to the index.
+
+        Only supported for a direct-access index.
+
+        Args:
+            texts: List of texts to add.
+            metadatas: List of metadata for each text. Defaults to None.
+            ids: List of ids for each text. Defaults to None.
+                If not provided, a random uuid will be generated for each text.
+
+        Returns:
+            List of ids from adding the texts into the index.
+        """
+        self._op_require_direct_access_index("add_texts")
+        assert self.embeddings is not None, "embedding model is required."
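+        # A direct-access index always uses self-managed embeddings, so the
+        # constructor has already required an embedding model; the assert
+        # above only narrows the Optional type for type checkers.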
+        # Wrap to a list if the input is a single string
+        if isinstance(texts, str):
+            texts = [texts]
+        texts = list(texts)
+        vectors = self.embeddings.embed_documents(texts)
+        ids = ids or [str(uuid.uuid4()) for _ in texts]
+        metadatas = metadatas or [{} for _ in texts]
+
+        updates = [
+            {
+                self.primary_key: id_,
+                self.text_column: text,
+                self._embedding_vector_column_name(): vector,
+                **metadata,
+            }
+            for text, vector, id_, metadata in zip(texts, vectors, ids, metadatas)
+        ]
+
+        upsert_resp = self.index.upsert(updates)
+        if upsert_resp.get("status") in ("PARTIAL_SUCCESS", "FAILURE"):
+            failed_ids = upsert_resp.get("result", dict()).get(
+                "failed_primary_keys", []
+            )
+            if upsert_resp.get("status") == "FAILURE":
+                logger.error("Failed to add texts to the index.")
+            else:
+                logger.warning("Some texts failed to be added to the index.")
+            return [id_ for id_ in ids if id_ not in failed_ids]
+
+        return ids
+
+    @property
+    def embeddings(self) -> Optional[Embeddings]:
+        """Access the query embedding object if available."""
+        return self._embedding
+
+    def delete(self, ids: Optional[List[Any]] = None, **kwargs: Any) -> Optional[bool]:
+        """Delete documents from the index.
+
+        Only supported for a direct-access index.
+
+        Args:
+            ids: List of ids of documents to delete.
+
+        Returns:
+            True if successful.
+        """
+        self._op_require_direct_access_index("delete")
+        if ids is None:
+            raise ValueError("ids must be provided.")
+        self.index.delete(ids)
+        return True
+
+    def similarity_search(
+        self, query: str, k: int = 4, filters: Optional[Any] = None, **kwargs: Any
+    ) -> List[Document]:
+        """Return docs most similar to query.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            filters: Filters to apply to the query. Defaults to None.
+
+        Returns:
+            List of Documents most similar to the query.
+        """
+        docs_with_score = self.similarity_search_with_score(
+            query=query, k=k, filters=filters, **kwargs
+        )
+        return [doc for doc, _ in docs_with_score]
+
+    def similarity_search_with_score(
+        self, query: str, k: int = 4, filters: Optional[Any] = None, **kwargs: Any
+    ) -> List[Tuple[Document, float]]:
+        """Return docs most similar to query, along with scores.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            filters: Filters to apply to the query. Defaults to None.
+
+        Returns:
+            List of Documents most similar to the query, with a score for each.
+        """
+        if self._is_databricks_managed_embeddings():
+            query_text = query
+            query_vector = None
+        else:
+            assert self.embeddings is not None, "embedding model is required."
+            query_text = None
+            query_vector = self.embeddings.embed_query(query)
+
+        search_resp = self.index.similarity_search(
+            columns=self.columns,
+            query_text=query_text,
+            query_vector=query_vector,
+            filters=filters,
+            num_results=k,
+        )
+        return self._parse_search_response(search_resp)
+
+    def similarity_search_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        filters: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs most similar to embedding vector.
+
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            filters: Filters to apply to the query. Defaults to None.
+
+        Returns:
+            List of Documents most similar to the embedding.
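+
+        Example (a minimal sketch; assumes ``vectorstore`` wraps an index with
+        self-managed embeddings and that ``embedding_model`` is the same
+        embedding model the index vectors were created with):
+
+        .. code-block:: python
+
+            query_vector = embedding_model.embed_query("what is databricks?")
+            docs = vectorstore.similarity_search_by_vector(query_vector, k=5)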
+ """ + docs_with_score = self.similarity_search_by_vector_with_score( + embedding=embedding, k=k, filters=filters, **kwargs + ) + return [doc for doc, _ in docs_with_score] + + def similarity_search_by_vector_with_score( + self, + embedding: List[float], + k: int = 4, + filters: Optional[Any] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + """Return docs most similar to embedding vector, along with scores. + + Args: + embedding: Embedding to look up documents similar to. + k: Number of Documents to return. Defaults to 4. + filters: Filters to apply to the query. Defaults to None. + + Returns: + List of Documents most similar to the embedding and score for each. + """ + if self._is_databricks_managed_embeddings(): + raise ValueError( + "`similarity_search_by_vector` is not supported for index with " + "Databricks-managed embeddings." + ) + search_resp = self.index.similarity_search( + columns=self.columns, + query_vector=embedding, + filters=filters, + num_results=k, + ) + return self._parse_search_response(search_resp) + + def _parse_search_response(self, search_resp: dict) -> List[Tuple[Document, float]]: + """Parse the search response into a list of Documents with score.""" + columns = [ + col["name"] + for col in search_resp.get("manifest", dict()).get("columns", []) + ] + docs_with_score = [] + for result in search_resp.get("result", dict()).get("data_array", []): + doc_id = result[columns.index(self.primary_key)] + text_content = result[columns.index(self.text_column)] + metadata = { + col: value + for col, value in zip(columns[:-1], result[:-1]) + if col not in [self.primary_key, self.text_column] + } + metadata[self.primary_key] = doc_id + score = result[-1] + doc = Document(page_content=text_content, metadata=metadata) + docs_with_score.append((doc, score)) + return docs_with_score + + def _index_schema(self) -> Optional[dict]: + """Return the index schema as a dictionary. + Return None if no schema found. + """ + if self._is_direct_access_index(): + schema_json = self._direct_access_index_spec.get("schema_json") + if schema_json is not None: + return json.loads(schema_json) + return None + + def _embedding_vector_column_name(self) -> Optional[str]: + """Return the name of the embedding vector column. + None if the index is not a self-managed embedding index. + """ + return self._embedding_vector_column().get("name") + + def _embedding_vector_column_dimension(self) -> Optional[int]: + """Return the dimension of the embedding vector column. + None if the index is not a self-managed embedding index. + """ + return self._embedding_vector_column().get("embedding_dimension") + + def _embedding_vector_column(self) -> dict: + """Return the embedding vector column configs as a dictionary. + Empty if the index is not a self-managed embedding index. + """ + index_spec = ( + self._delta_sync_index_spec + if self._is_delta_sync_index() + else self._direct_access_index_spec + ) + return next(iter(index_spec.get("embedding_vector_columns") or list()), dict()) + + def _embedding_source_column_name(self) -> Optional[str]: + """Return the name of the embedding source column. + None if the index is not a Databricks-managed embedding index. + """ + return self._embedding_source_column().get("name") + + def _embedding_source_column(self) -> dict: + """Return the embedding source column configs as a dictionary. + Empty if the index is not a Databricks-managed embedding index. 
+        """
+        index_spec = self._delta_sync_index_spec
+        return next(iter(index_spec.get("embedding_source_columns") or list()), dict())
+
+    def _is_delta_sync_index(self) -> bool:
+        """Return True if the index is a delta-sync index."""
+        return self.index_type == "DELTA_SYNC"
+
+    def _is_direct_access_index(self) -> bool:
+        """Return True if the index is a direct-access index."""
+        return self.index_type == "DIRECT_ACCESS"
+
+    def _is_databricks_managed_embeddings(self) -> bool:
+        """Return True if the embeddings are managed by Databricks Vector Search."""
+        return (
+            self._is_delta_sync_index()
+            and self._embedding_source_column_name() is not None
+        )
+
+    def _infer_embedding_dimension(self) -> int:
+        """Infer the embedding dimension from the embedding function."""
+        assert self.embeddings is not None, "embedding model is required."
+        return len(self.embeddings.embed_query("test"))
+
+    def _op_require_direct_access_index(self, op_name: str) -> None:
+        """Raise ValueError if the index is not a direct-access index.
+
+        The operation named `op_name` is only supported on a direct-access index.
+        """
+        if not self._is_direct_access_index():
+            raise ValueError(f"`{op_name}` is only supported for direct-access index.")
+
+    @staticmethod
+    def _require_arg(arg: Any, arg_name: str) -> None:
+        """Raise ValueError if the required arg named `arg_name` is None or empty."""
+        if not arg:
+            raise ValueError(f"`{arg_name}` is required for this index.")
diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock
index 933e5b87064..3db514d2623 100644
--- a/libs/langchain/poetry.lock
+++ b/libs/langchain/poetry.lock
@@ -1568,6 +1568,17 @@ click = ">=4.0"
 [package.extras]
 test = ["pytest-cov"]
 
+[[package]]
+name = "cloudpickle"
+version = "2.2.1"
+description = "Extended pickling support for Python objects"
+optional = true
+python-versions = ">=3.6"
+files = [
+    {file = "cloudpickle-2.2.1-py3-none-any.whl", hash = "sha256:61f594d1f4c295fa5cd9014ceb3a1fc4a70b0de1164b94fbc2d854ccba056f9f"},
+    {file = "cloudpickle-2.2.1.tar.gz", hash = "sha256:d89684b8de9e34a2a43b3460fbca07d09d6e25ce858df4d5a44240403b6178f5"},
+]
+
 [[package]]
 name = "codespell"
 version = "2.2.6"
@@ -1794,6 +1805,41 @@ grpcio = [
 numpy = "*"
 protobuf = ">=3.8.0,<4.0.0"
 
+[[package]]
+name = "databricks-cli"
+version = "0.18.0"
+description = "A command line interface for Databricks"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "databricks-cli-0.18.0.tar.gz", hash = "sha256:87569709eda9af3e9db8047b691e420b5e980c62ef01675575c0d2b9b4211eb7"},
+    {file = "databricks_cli-0.18.0-py2.py3-none-any.whl", hash = "sha256:1176a5f42d3e8af4abfc915446fb23abc44513e325c436725f5898cbb9e3384b"},
+]
+
+[package.dependencies]
+click = ">=7.0"
+oauthlib = ">=3.1.0"
+pyjwt = ">=1.7.0"
+requests = ">=2.17.3"
+six = ">=1.10.0"
+tabulate = ">=0.7.7"
+urllib3 = ">=1.26.7,<3"
+
+[[package]]
+name = "databricks-vectorsearch"
+version = "0.21"
+description = "Databricks Vector Search Client"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "databricks_vectorsearch-0.21-py3-none-any.whl", hash = "sha256:18265affdb38d44e7ec4cc95f8267379c5109bdb6e75bb61a729f126b2433868"},
+]
+
+[package.dependencies]
+mlflow-skinny = ">=2.4.0,<3"
+protobuf = ">=3.12.0,<5"
+requests = ">=2"
+
 [[package]]
 name = "dataclasses-json"
 version = "0.6.1"
@@ -4719,6 +4765,39 @@ files = [
     {file = "mistune-3.0.2.tar.gz", hash = "sha256:fc7f93ded930c92394ef2cb6f04a8aabab4117a91449e72dcc8dfa646a508be8"},
 ]
 
+[[package]]
+name = "mlflow-skinny"
+version = "2.8.1"
+description = "MLflow: A Platform for ML Development and 
Productionization" +optional = true +python-versions = ">=3.8" +files = [ + {file = "mlflow-skinny-2.8.1.tar.gz", hash = "sha256:8f46462e2df5ffd93a7f7d92ad1d3d7335adbe5e8e999543a3879963ae576d33"}, + {file = "mlflow_skinny-2.8.1-py3-none-any.whl", hash = "sha256:8e2a1a5b8f1e2a3437c1fab972115a4df25934cd07cd83b8eb70202af8ad814a"}, +] + +[package.dependencies] +click = ">=7.0,<9" +cloudpickle = "<3" +databricks-cli = ">=0.8.7,<1" +entrypoints = "<1" +gitpython = ">=2.1.0,<4" +importlib-metadata = ">=3.7.0,<4.7.0 || >4.7.0,<7" +packaging = "<24" +protobuf = ">=3.12.0,<5" +pytz = "<2024" +pyyaml = ">=5.1,<7" +requests = ">=2.17.3,<3" +sqlparse = ">=0.4.0,<1" + +[package.extras] +aliyun-oss = ["aliyunstoreplugin"] +databricks = ["azure-storage-file-datalake (>12)", "boto3 (>1)", "google-cloud-storage (>=1.30.0)"] +extras = ["azureml-core (>=1.2.0)", "boto3", "google-cloud-storage (>=1.30.0)", "kubernetes", "mlserver (>=1.2.0,!=1.3.1)", "mlserver-mlflow (>=1.2.0,!=1.3.1)", "prometheus-flask-exporter", "pyarrow", "pysftp", "requests-auth-aws-sigv4", "virtualenv"] +gateway = ["aiohttp (<4)", "boto3 (>=1.28.56,<2)", "fastapi (<1)", "pydantic (>=1.0,<3)", "uvicorn[standard] (<1)", "watchfiles (<1)"] +sqlserver = ["mlflow-dbstore"] +xethub = ["mlflow-xethub"] + [[package]] name = "mmh3" version = "3.1.0" @@ -9271,6 +9350,22 @@ files = [ {file = "sqlparams-5.1.0.tar.gz", hash = "sha256:1abe87a0684567265b2b86f5a482d5c37db237c0268d4c81774ffedce4300199"}, ] +[[package]] +name = "sqlparse" +version = "0.4.4" +description = "A non-validating SQL parser." +optional = true +python-versions = ">=3.5" +files = [ + {file = "sqlparse-0.4.4-py3-none-any.whl", hash = "sha256:5430a4fe2ac7d0f93e66f1efc6e1338a41884b7ddf2a350cedd20ccc4d9d28f3"}, + {file = "sqlparse-0.4.4.tar.gz", hash = "sha256:d446183e84b8349fa3061f0fe7f06ca94ba65b426946ffebe6e3e8295332420c"}, +] + +[package.extras] +dev = ["build", "flake8"] +doc = ["sphinx"] +test = ["pytest", "pytest-cov"] + [[package]] name = "stack-data" version = "0.6.3" @@ -9368,6 +9463,20 @@ files = [ [package.dependencies] pytest = ">=7.0.0,<8.0.0" +[[package]] +name = "tabulate" +version = "0.9.0" +description = "Pretty-print tabular data" +optional = true +python-versions = ">=3.7" +files = [ + {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"}, + {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"}, +] + +[package.extras] +widechars = ["wcwidth"] + [[package]] name = "telethon" version = "1.31.1" @@ -11075,7 +11184,7 @@ cli = ["typer"] cohere = ["cohere"] docarray = ["docarray"] embeddings = ["sentence-transformers"] -extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"] +extended-testing = 
["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "databricks-vectorsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"] javascript = ["esprima"] llms = ["clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"] openai = ["openai", "tiktoken"] @@ -11085,4 +11194,4 @@ text-helpers = ["chardet"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "37e62f668e1acddc4e462fdac5f694af3916b6edbd1ccde0a54c9a57524d6c92" +content-hash = "d57493dcdb7c864d71aa43463a57491f0c9cbd8fa8674d21c0b11117e8d7ea67" diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index 52d43f1e48f..d18b343b499 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -144,6 +144,7 @@ google-cloud-documentai = {version = "^2.20.1", optional = true} fireworks-ai = {version = "^0.6.0", optional = true, python = ">=3.9,<4.0"} javelin-sdk = {version = "^0.1.8", optional = true} msal = {version = "^1.25.0", optional = true} +databricks-vectorsearch = {version = "^0.21", optional = true} [tool.poetry.group.test.dependencies] @@ -381,6 +382,7 @@ extended_testing = [ "rspace_client", "fireworks-ai", "javelin-sdk", + "databricks-vectorsearch", ] [tool.ruff] diff --git a/libs/langchain/tests/unit_tests/indexes/test_indexing.py b/libs/langchain/tests/unit_tests/indexes/test_indexing.py index aacc138d95b..6d2022989de 100644 --- a/libs/langchain/tests/unit_tests/indexes/test_indexing.py +++ b/libs/langchain/tests/unit_tests/indexes/test_indexing.py @@ -1123,6 +1123,7 @@ def test_compatible_vectorstore_documentation() -> None: "Cassandra", "Chroma", "DashVector", + "DatabricksVectorSearch", "DeepLake", "Dingo", "ElasticVectorSearch", diff --git a/libs/langchain/tests/unit_tests/vectorstores/test_databricks_vector_search.py b/libs/langchain/tests/unit_tests/vectorstores/test_databricks_vector_search.py new file mode 100644 index 00000000000..699a9f8a5a0 --- /dev/null +++ b/libs/langchain/tests/unit_tests/vectorstores/test_databricks_vector_search.py @@ -0,0 +1,526 @@ +import random +import uuid +from typing import List +from unittest.mock import MagicMock + +import pytest + +from langchain.vectorstores import DatabricksVectorSearch +from tests.integration_tests.vectorstores.fake_embeddings import ( + FakeEmbeddings, + fake_texts, +) + +DEFAULT_VECTOR_DIMENSION = 4 + + +class FakeEmbeddingsWithDimension(FakeEmbeddings): + """Fake embeddings functionality for testing.""" + + def __init__(self, dimension: int = DEFAULT_VECTOR_DIMENSION): + super().__init__() + self.dimension = dimension + + def embed_documents(self, embedding_texts: List[str]) -> List[List[float]]: + """Return simple embeddings.""" + return [ + [float(1.0)] * (self.dimension - 1) + [float(i)] + for i in range(len(embedding_texts)) + ] + + def 
embed_query(self, text: str) -> List[float]: + """Return simple embeddings.""" + return [float(1.0)] * (self.dimension - 1) + [float(0.0)] + + +DEFAULT_EMBEDDING_MODEL = FakeEmbeddingsWithDimension() +DEFAULT_TEXT_COLUMN = "text" +DEFAULT_VECTOR_COLUMN = "text_vector" +DEFAULT_PRIMARY_KEY = "id" + +DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS = { + "name": "ml.llm.index", + "endpoint_name": "vector_search_endpoint", + "index_type": "DELTA_SYNC", + "primary_key": DEFAULT_PRIMARY_KEY, + "delta_sync_index_spec": { + "source_table": "ml.llm.source_table", + "pipeline_type": "CONTINUOUS", + "embedding_source_columns": [ + { + "name": DEFAULT_TEXT_COLUMN, + "embedding_model_endpoint_name": "openai-text-embedding", + } + ], + }, +} + +DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS = { + "name": "ml.llm.index", + "endpoint_name": "vector_search_endpoint", + "index_type": "DELTA_SYNC", + "primary_key": DEFAULT_PRIMARY_KEY, + "delta_sync_index_spec": { + "source_table": "ml.llm.source_table", + "pipeline_type": "CONTINUOUS", + "embedding_vector_columns": [ + { + "name": DEFAULT_VECTOR_COLUMN, + "embedding_dimension": DEFAULT_VECTOR_DIMENSION, + } + ], + }, +} + +DIRECT_ACCESS_INDEX = { + "name": "ml.llm.index", + "endpoint_name": "vector_search_endpoint", + "index_type": "DIRECT_ACCESS", + "primary_key": DEFAULT_PRIMARY_KEY, + "direct_access_index_spec": { + "embedding_vector_columns": [ + { + "name": DEFAULT_VECTOR_COLUMN, + "embedding_dimension": DEFAULT_VECTOR_DIMENSION, + } + ], + "schema_json": f"{{" + f'"{DEFAULT_PRIMARY_KEY}": "int", ' + f'"feat1": "str", ' + f'"feat2": "float", ' + f'"text": "string", ' + f'"{DEFAULT_VECTOR_COLUMN}": "array"' + f"}}", + }, +} + +ALL_INDEXES = [ + DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS, + DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, + DIRECT_ACCESS_INDEX, +] + +EXAMPLE_SEARCH_RESPONSE = { + "manifest": { + "column_count": 3, + "columns": [ + {"name": DEFAULT_PRIMARY_KEY}, + {"name": DEFAULT_TEXT_COLUMN}, + {"name": "score"}, + ], + }, + "result": { + "row_count": len(fake_texts), + "data_array": sorted( + [[str(uuid.uuid4()), s, random.uniform(0, 1)] for s in fake_texts], + key=lambda x: x[2], # type: ignore + reverse=True, + ), + }, + "next_page_token": "", +} + + +def mock_index(index_details: dict) -> MagicMock: + from databricks.vector_search.client import VectorSearchIndex + + index = MagicMock(spec=VectorSearchIndex) + index.describe.return_value = index_details + return index + + +def default_databricks_vector_search(index: MagicMock) -> DatabricksVectorSearch: + return DatabricksVectorSearch( + index, + embedding=DEFAULT_EMBEDDING_MODEL, + text_column=DEFAULT_TEXT_COLUMN, + ) + + +@pytest.mark.requires("databricks", "databricks.vector_search") +def test_init_delta_sync_with_managed_embeddings() -> None: + index = mock_index(DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS) + vectorsearch = DatabricksVectorSearch(index) + assert vectorsearch.index == index + + +@pytest.mark.requires("databricks", "databricks.vector_search") +def test_init_delta_sync_with_self_managed_embeddings() -> None: + index = mock_index(DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS) + vectorsearch = DatabricksVectorSearch( + index, + embedding=DEFAULT_EMBEDDING_MODEL, + text_column=DEFAULT_TEXT_COLUMN, + ) + assert vectorsearch.index == index + + +@pytest.mark.requires("databricks", "databricks.vector_search") +def test_init_direct_access_index() -> None: + index = mock_index(DIRECT_ACCESS_INDEX) + vectorsearch = DatabricksVectorSearch( + index, + embedding=DEFAULT_EMBEDDING_MODEL, + text_column=DEFAULT_TEXT_COLUMN, + 
) + assert vectorsearch.index == index + + +@pytest.mark.requires("databricks", "databricks.vector_search") +def test_init_fail_no_index() -> None: + with pytest.raises(TypeError): + DatabricksVectorSearch() + + +@pytest.mark.requires("databricks", "databricks.vector_search") +def test_init_fail_index_none() -> None: + with pytest.raises(TypeError) as ex: + DatabricksVectorSearch(None) + assert "index must be of type VectorSearchIndex." in str(ex.value) + + +@pytest.mark.requires("databricks", "databricks.vector_search") +def test_init_fail_text_column_mismatch() -> None: + index = mock_index(DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS) + with pytest.raises(ValueError) as ex: + DatabricksVectorSearch( + index, + text_column="some_other_column", + ) + assert ( + f"text_column 'some_other_column' does not match with the source column of the " + f"index: '{DEFAULT_TEXT_COLUMN}'." in str(ex.value) + ) + + +@pytest.mark.requires("databricks", "databricks.vector_search") +@pytest.mark.parametrize( + "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX] +) +def test_init_fail_no_text_column(index_details: dict) -> None: + index = mock_index(index_details) + with pytest.raises(ValueError) as ex: + DatabricksVectorSearch( + index, + embedding=DEFAULT_EMBEDDING_MODEL, + ) + assert "`text_column` is required for this index." in str(ex.value) + + +@pytest.mark.requires("databricks", "databricks.vector_search") +@pytest.mark.parametrize("index_details", [DIRECT_ACCESS_INDEX]) +def test_init_fail_columns_not_in_schema(index_details: dict) -> None: + index = mock_index(index_details) + with pytest.raises(ValueError) as ex: + DatabricksVectorSearch( + index, + embedding=DEFAULT_EMBEDDING_MODEL, + text_column=DEFAULT_TEXT_COLUMN, + columns=["some_random_column"], + ) + assert "column 'some_random_column' is not in the index's schema." in str(ex.value) + + +@pytest.mark.requires("databricks", "databricks.vector_search") +@pytest.mark.parametrize( + "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX] +) +def test_init_fail_no_embedding(index_details: dict) -> None: + index = mock_index(index_details) + with pytest.raises(ValueError) as ex: + DatabricksVectorSearch( + index, + text_column=DEFAULT_TEXT_COLUMN, + ) + assert "`embedding` is required for this index." in str(ex.value) + + +@pytest.mark.requires("databricks", "databricks.vector_search") +@pytest.mark.parametrize( + "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX] +) +def test_init_fail_embedding_dim_mismatch(index_details: dict) -> None: + index = mock_index(index_details) + with pytest.raises(ValueError) as ex: + DatabricksVectorSearch( + index, + text_column=DEFAULT_TEXT_COLUMN, + embedding=FakeEmbeddingsWithDimension(DEFAULT_VECTOR_DIMENSION + 1), + ) + assert ( + f"embedding model's dimension '{DEFAULT_VECTOR_DIMENSION + 1}' does not match " + f"with the index's dimension '{DEFAULT_VECTOR_DIMENSION}'" + ) in str(ex.value) + + +@pytest.mark.requires("databricks", "databricks.vector_search") +def test_from_texts_not_supported() -> None: + with pytest.raises(NotImplementedError) as ex: + DatabricksVectorSearch.from_texts(fake_texts, FakeEmbeddings()) + assert ( + "`from_texts` is not supported. " + "Use `add_texts` to add to existing direct-access index." 
+ ) in str(ex.value) + + +@pytest.mark.requires("databricks", "databricks.vector_search") +@pytest.mark.parametrize( + "index_details", + [DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS, DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS], +) +def test_add_texts_not_supported_for_delta_sync_index(index_details: dict) -> None: + index = mock_index(index_details) + vectorsearch = default_databricks_vector_search(index) + with pytest.raises(ValueError) as ex: + vectorsearch.add_texts(fake_texts) + assert "`add_texts` is only supported for direct-access index." in str(ex.value) + + +def is_valid_uuid(val: str) -> bool: + try: + uuid.UUID(str(val)) + return True + except ValueError: + return False + + +@pytest.mark.requires("databricks", "databricks.vector_search") +def test_add_texts() -> None: + index = mock_index(DIRECT_ACCESS_INDEX) + vectorsearch = DatabricksVectorSearch( + index, + embedding=DEFAULT_EMBEDDING_MODEL, + text_column=DEFAULT_TEXT_COLUMN, + ) + ids = [idx for idx, i in enumerate(fake_texts)] + vectors = DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts) + + added_ids = vectorsearch.add_texts(fake_texts, ids=ids) + index.upsert.assert_called_once_with( + [ + { + DEFAULT_PRIMARY_KEY: id_, + DEFAULT_TEXT_COLUMN: text, + DEFAULT_VECTOR_COLUMN: vector, + } + for text, vector, id_ in zip(fake_texts, vectors, ids) + ] + ) + assert len(added_ids) == len(fake_texts) + assert added_ids == ids + + +@pytest.mark.requires("databricks", "databricks.vector_search") +def test_add_texts_handle_single_text() -> None: + index = mock_index(DIRECT_ACCESS_INDEX) + vectorsearch = DatabricksVectorSearch( + index, + embedding=DEFAULT_EMBEDDING_MODEL, + text_column=DEFAULT_TEXT_COLUMN, + ) + vectors = DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts) + + added_ids = vectorsearch.add_texts(fake_texts[0]) + index.upsert.assert_called_once_with( + [ + { + DEFAULT_PRIMARY_KEY: id_, + DEFAULT_TEXT_COLUMN: text, + DEFAULT_VECTOR_COLUMN: vector, + } + for text, vector, id_ in zip(fake_texts, vectors, added_ids) + ] + ) + assert len(added_ids) == 1 + assert is_valid_uuid(added_ids[0]) + + +@pytest.mark.requires("databricks", "databricks.vector_search") +def test_add_texts_with_default_id() -> None: + index = mock_index(DIRECT_ACCESS_INDEX) + vectorsearch = default_databricks_vector_search(index) + vectors = DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts) + + added_ids = vectorsearch.add_texts(fake_texts) + index.upsert.assert_called_once_with( + [ + { + DEFAULT_PRIMARY_KEY: id_, + DEFAULT_TEXT_COLUMN: text, + DEFAULT_VECTOR_COLUMN: vector, + } + for text, vector, id_ in zip(fake_texts, vectors, added_ids) + ] + ) + assert len(added_ids) == len(fake_texts) + assert all([is_valid_uuid(id_) for id_ in added_ids]) + + +@pytest.mark.requires("databricks", "databricks.vector_search") +def test_add_texts_with_metadata() -> None: + index = mock_index(DIRECT_ACCESS_INDEX) + vectorsearch = default_databricks_vector_search(index) + vectors = DEFAULT_EMBEDDING_MODEL.embed_documents(fake_texts) + metadatas = [{"feat1": str(i), "feat2": i + 1000} for i in range(len(fake_texts))] + + added_ids = vectorsearch.add_texts(fake_texts, metadatas=metadatas) + index.upsert.assert_called_once_with( + [ + { + DEFAULT_PRIMARY_KEY: id_, + DEFAULT_TEXT_COLUMN: text, + DEFAULT_VECTOR_COLUMN: vector, + **metadata, + } + for text, vector, id_, metadata in zip( + fake_texts, vectors, added_ids, metadatas + ) + ] + ) + assert len(added_ids) == len(fake_texts) + assert all([is_valid_uuid(id_) for id_ in added_ids]) + + +@pytest.mark.requires("databricks", 
"databricks.vector_search") +@pytest.mark.parametrize( + "index_details", + [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX], +) +def test_embeddings_property(index_details: dict) -> None: + index = mock_index(index_details) + vectorsearch = default_databricks_vector_search(index) + assert vectorsearch.embeddings == DEFAULT_EMBEDDING_MODEL + + +@pytest.mark.requires("databricks", "databricks.vector_search") +@pytest.mark.parametrize( + "index_details", + [DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS, DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS], +) +def test_delete_not_supported_for_delta_sync_index(index_details: dict) -> None: + index = mock_index(index_details) + vectorsearch = default_databricks_vector_search(index) + with pytest.raises(ValueError) as ex: + vectorsearch.delete(["some id"]) + assert "`delete` is only supported for direct-access index." in str(ex.value) + + +@pytest.mark.requires("databricks", "databricks.vector_search") +def test_delete() -> None: + index = mock_index(DIRECT_ACCESS_INDEX) + vectorsearch = default_databricks_vector_search(index) + + vectorsearch.delete(["some id"]) + index.delete.assert_called_once_with(["some id"]) + + +@pytest.mark.requires("databricks", "databricks.vector_search") +def test_delete_fail_no_ids() -> None: + index = mock_index(DIRECT_ACCESS_INDEX) + vectorsearch = default_databricks_vector_search(index) + + with pytest.raises(ValueError) as ex: + vectorsearch.delete() + assert "ids must be provided." in str(ex.value) + + +@pytest.mark.requires("databricks", "databricks.vector_search") +@pytest.mark.parametrize("index_details", ALL_INDEXES) +def test_similarity_search(index_details: dict) -> None: + index = mock_index(index_details) + index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE + vectorsearch = default_databricks_vector_search(index) + query = "foo" + filters = {"some filter": True} + limit = 7 + + search_result = vectorsearch.similarity_search(query, k=limit, filters=filters) + if index_details == DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS: + index.similarity_search.assert_called_once_with( + columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN], + query_text=query, + query_vector=None, + filters=filters, + num_results=limit, + ) + else: + index.similarity_search.assert_called_once_with( + columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN], + query_text=None, + query_vector=DEFAULT_EMBEDDING_MODEL.embed_query(query), + filters=filters, + num_results=limit, + ) + assert len(search_result) == len(fake_texts) + assert sorted([d.page_content for d in search_result]) == sorted(fake_texts) + assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result]) + + +@pytest.mark.requires("databricks", "databricks.vector_search") +@pytest.mark.parametrize( + "index_details", [DELTA_SYNC_INDEX_SELF_MANAGED_EMBEDDINGS, DIRECT_ACCESS_INDEX] +) +def test_similarity_search_by_vector(index_details: dict) -> None: + index = mock_index(index_details) + index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE + vectorsearch = default_databricks_vector_search(index) + query_embedding = DEFAULT_EMBEDDING_MODEL.embed_query("foo") + filters = {"some filter": True} + limit = 7 + + search_result = vectorsearch.similarity_search_by_vector( + query_embedding, k=limit, filters=filters + ) + index.similarity_search.assert_called_once_with( + columns=[DEFAULT_PRIMARY_KEY, DEFAULT_TEXT_COLUMN], + query_vector=query_embedding, + filters=filters, + num_results=limit, + ) + assert len(search_result) == len(fake_texts) + assert sorted([d.page_content 
for d in search_result]) == sorted(fake_texts) + assert all([DEFAULT_PRIMARY_KEY in d.metadata for d in search_result]) + + +@pytest.mark.requires("databricks", "databricks.vector_search") +@pytest.mark.parametrize("index_details", ALL_INDEXES) +def test_similarity_search_empty_result(index_details: dict) -> None: + index = mock_index(index_details) + index.similarity_search.return_value = { + "manifest": { + "column_count": 3, + "columns": [ + {"name": DEFAULT_PRIMARY_KEY}, + {"name": DEFAULT_TEXT_COLUMN}, + {"name": "score"}, + ], + }, + "result": { + "row_count": 0, + "data_array": [], + }, + "next_page_token": "", + } + vectorsearch = default_databricks_vector_search(index) + + search_result = vectorsearch.similarity_search("foo") + assert len(search_result) == 0 + + +@pytest.mark.requires("databricks", "databricks.vector_search") +def test_similarity_search_by_vector_not_supported_for_managed_embedding() -> None: + index = mock_index(DELTA_SYNC_INDEX_MANAGED_EMBEDDINGS) + index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE + vectorsearch = default_databricks_vector_search(index) + query_embedding = DEFAULT_EMBEDDING_MODEL.embed_query("foo") + filters = {"some filter": True} + limit = 7 + + with pytest.raises(ValueError) as ex: + vectorsearch.similarity_search_by_vector( + query_embedding, k=limit, filters=filters + ) + assert ( + "`similarity_search_by_vector` is not supported for index with " + "Databricks-managed embeddings." in str(ex.value) + ) diff --git a/libs/langchain/tests/unit_tests/vectorstores/test_public_api.py b/libs/langchain/tests/unit_tests/vectorstores/test_public_api.py index 3773b933c17..24669e53374 100644 --- a/libs/langchain/tests/unit_tests/vectorstores/test_public_api.py +++ b/libs/langchain/tests/unit_tests/vectorstores/test_public_api.py @@ -17,6 +17,7 @@ _EXPECTED = [ "Clickhouse", "ClickhouseSettings", "DashVector", + "DatabricksVectorSearch", "DeepLake", "Dingo", "DocArrayHnswSearch",
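
Taken together, these changes register the new `DatabricksVectorSearch` class in the public API and docs, add `databricks-vectorsearch` as an optional dependency, and unit-test all three index flavors. A minimal end-to-end sketch of the intended usage follows, assuming a real Databricks workspace; the endpoint, index, and column names are the placeholder values from the docstring examples above:

    from databricks.vector_search.client import VectorSearchClient

    from langchain.embeddings.openai import OpenAIEmbeddings
    from langchain.vectorstores import DatabricksVectorSearch

    # Connect to an existing direct-access index (placeholder names).
    vs_client = VectorSearchClient()
    vs_index = vs_client.get_index(
        endpoint_name="vs_endpoint",
        index_name="ml.llm.index",
    )

    # A direct-access index needs a self-managed embedding model and the
    # name of the text column declared in the index schema.
    vectorstore = DatabricksVectorSearch(
        index=vs_index,
        embedding=OpenAIEmbeddings(),
        text_column="document_content",
    )

    # Ingest documents, query them back, then delete by id.
    ids = vectorstore.add_texts(["text1", "text2"])
    docs = vectorstore.similarity_search("text1", k=1)
    vectorstore.delete(ids)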