databricks: mv to partner repo (#25788)

This commit is contained in:
Erick Friis 2024-08-27 18:51:17 -07:00 committed by GitHub
parent 2e5c379632
commit 1023fbc98a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 0 additions and 5437 deletions

View File

@ -1 +0,0 @@
__pycache__

View File

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2024 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -1,62 +0,0 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
integration_test integration_tests: TEST_FILE = tests/integration_tests/
# unit tests are run with the --disable-socket flag to prevent network calls
test tests:
poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
# integration tests are run without the --disable-socket flag to allow network calls
integration_test integration_tests:
poetry run pytest $(TEST_FILE)
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/databricks --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_databricks
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
poetry run ruff check .
poetry run ruff format $(PYTHON_FILES) --diff
poetry run ruff check --select I $(PYTHON_FILES)
mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
poetry run ruff format $(PYTHON_FILES)
poetry run ruff check --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
check_imports: $(shell find langchain_databricks -name '*.py')
poetry run python ./scripts/check_imports.py $^
######################
# HELP
######################
help:
@echo '----'
@echo 'check_imports - check imports'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'

View File

@ -1,24 +0,0 @@
# langchain-databricks
This package contains the LangChain integration with Databricks
## Installation
```bash
pip install -U langchain-databricks
```
And you should configure credentials by setting the following environment variables:
* TODO: fill this out
## Chat Models
`ChatDatabricks` class exposes chat models from Databricks.
```python
from langchain_databricks import ChatDatabricks
llm = ChatDatabricks()
llm.invoke("Sing a ballad of LangChain.")
```

View File

@ -1,19 +0,0 @@
from importlib import metadata
from langchain_databricks.chat_models import ChatDatabricks
from langchain_databricks.embeddings import DatabricksEmbeddings
from langchain_databricks.vectorstores import DatabricksVectorSearch
try:
__version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
# Case where package metadata is not available.
__version__ = ""
del metadata # optional, avoids polluting the results of dir(__package__)
__all__ = [
"ChatDatabricks",
"DatabricksEmbeddings",
"DatabricksVectorSearch",
"__version__",
]

View File

@ -1,556 +0,0 @@
"""Databricks chat models."""
import json
import logging
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Type,
Union,
)
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessage,
ToolMessageChunk,
)
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers.openai_tools import (
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
PrivateAttr,
)
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_databricks.utils import get_deployment_client
logger = logging.getLogger(__name__)
class ChatDatabricks(BaseChatModel):
"""Databricks chat model integration.
Setup:
Install ``langchain-databricks``.
.. code-block:: bash
pip install -U langchain-databricks
If you are outside Databricks, set the Databricks workspace hostname and personal access token to environment variables:
.. code-block:: bash
export DATABRICKS_HOSTNAME="https://your-databricks-workspace"
export DATABRICKS_TOKEN="your-personal-access-token"
Key init args completion params:
endpoint: str
Name of Databricks Model Serving endpoint to query.
target_uri: str
The target URI to use. Defaults to ``databricks``.
temperature: float
Sampling temperature. Higher values make the model more creative.
n: Optional[int]
The number of completion choices to generate.
stop: Optional[List[str]]
List of strings to stop generation at.
max_tokens: Optional[int]
Max number of tokens to generate.
extra_params: Optional[Dict[str, Any]]
Any extra parameters to pass to the endpoint.
Instantiate:
.. code-block:: python
from langchain_databricks import ChatDatabricks
llm = ChatDatabricks(
endpoint="databricks-meta-llama-3-1-405b-instruct",
temperature=0,
max_tokens=500,
)
Invoke:
.. code-block:: python
messages = [
("system", "You are a helpful translator. Translate the user sentence to French."),
("human", "I love programming."),
]
llm.invoke(messages)
.. code-block:: python
AIMessage(
content="J'adore la programmation.",
response_metadata={
'prompt_tokens': 32,
'completion_tokens': 9,
'total_tokens': 41
},
id='run-64eebbdd-88a8-4a25-b508-21e9a5f146c5-0'
)
Stream:
.. code-block:: python
for chunk in llm.stream(messages):
print(chunk)
.. code-block:: python
content='J' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
content="'" id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
content='ad' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
content='ore' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
content=' la' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
content=' programm' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
content='ation' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
content='.' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
content='' response_metadata={'finish_reason': 'stop'} id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
.. code-block:: python
stream = llm.stream(messages)
full = next(stream)
for chunk in stream:
full += chunk
full
.. code-block:: python
AIMessageChunk(
content="J'adore la programmation.",
response_metadata={
'finish_reason': 'stop'
},
id='run-4cef851f-6223-424f-ad26-4a54e5852aa5'
)
Async:
.. code-block:: python
await llm.ainvoke(messages)
# stream:
# async for chunk in llm.astream(messages)
# batch:
# await llm.abatch([messages])
.. code-block:: python
AIMessage(
content="J'adore la programmation.",
response_metadata={
'prompt_tokens': 32,
'completion_tokens': 9,
'total_tokens': 41
},
id='run-e4bb043e-772b-4e1d-9f98-77ccc00c0271-0'
)
Tool calling:
.. code-block:: python
from langchain_core.pydantic_v1 import BaseModel, Field
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
class GetPopulation(BaseModel):
'''Get the current population in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?")
ai_msg.tool_calls
.. code-block:: python
[
{
'name': 'GetWeather',
'args': {
'location': 'Los Angeles, CA'
},
'id': 'call_ea0a6004-8e64-4ae8-a192-a40e295bfa24',
'type': 'tool_call'
}
]
To use tool calls, your model endpoint must support ``tools`` parameter. See [Function calling on Databricks](https://python.langchain.com/v0.2/docs/integrations/chat/databricks/#function-calling-on-databricks) for more information.
""" # noqa: E501
endpoint: str
"""Name of Databricks Model Serving endpoint to query."""
target_uri: str = "databricks"
"""The target URI to use. Defaults to ``databricks``."""
temperature: float = 0.0
"""Sampling temperature. Higher values make the model more creative."""
n: int = 1
"""The number of completion choices to generate."""
stop: Optional[List[str]] = None
"""List of strings to stop generation at."""
max_tokens: Optional[int] = None
"""The maximum number of tokens to generate."""
extra_params: dict = Field(default_factory=dict)
"""Any extra parameters to pass to the endpoint."""
_client: Any = PrivateAttr()
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._client = get_deployment_client(self.target_uri)
@property
def _default_params(self) -> Dict[str, Any]:
params: Dict[str, Any] = {
"target_uri": self.target_uri,
"endpoint": self.endpoint,
"temperature": self.temperature,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens,
"extra_params": self.extra_params,
}
return params
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
data = self._prepare_inputs(messages, stop, **kwargs)
resp = self._client.predict(endpoint=self.endpoint, inputs=data)
return self._convert_response_to_chat_result(resp)
def _prepare_inputs(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
data: Dict[str, Any] = {
"messages": [_convert_message_to_dict(msg) for msg in messages],
"temperature": self.temperature,
"n": self.n,
**self.extra_params,
**kwargs,
}
if stop := self.stop or stop:
data["stop"] = stop
if self.max_tokens is not None:
data["max_tokens"] = self.max_tokens
return data
def _convert_response_to_chat_result(
self, response: Mapping[str, Any]
) -> ChatResult:
generations = [
ChatGeneration(
message=_convert_dict_to_message(choice["message"]),
generation_info=choice.get("usage", {}),
)
for choice in response["choices"]
]
usage = response.get("usage", {})
return ChatResult(generations=generations, llm_output=usage)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
data = self._prepare_inputs(messages, stop, **kwargs)
first_chunk_role = None
for chunk in self._client.predict_stream(endpoint=self.endpoint, inputs=data):
if chunk["choices"]:
choice = chunk["choices"][0]
chunk_delta = choice["delta"]
if first_chunk_role is None:
first_chunk_role = chunk_delta.get("role")
chunk_message = _convert_dict_to_message_chunk(
chunk_delta, first_chunk_role
)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
if logprobs := choice.get("logprobs"):
generation_info["logprobs"] = logprobs
chunk = ChatGenerationChunk(
message=chunk_message, generation_info=generation_info or None
)
if run_manager:
run_manager.on_llm_new_token(
chunk.text, chunk=chunk, logprobs=logprobs
)
yield chunk
else:
# Handle the case where choices are empty if needed
continue
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
Assumes model is compatible with OpenAI tool-calling API.
Args:
tools: A list of tool definitions to bind to this chat model.
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
tool_choice: Which tool to require the model to call.
Options are:
name of the tool (str): calls corresponding tool;
"auto": automatically selects a tool (including no tool);
"none": model does not generate any tool calls and instead must
generate a standard assistant message;
"required": the model picks the most relevant tool in tools and
must generate a tool call;
or a dict of the form:
{"type": "function", "function": {"name": <<tool_name>>}}.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
if tool_choice:
if isinstance(tool_choice, str):
# tool_choice is a tool/function name
if tool_choice not in ("auto", "none", "required"):
tool_choice = {
"type": "function",
"function": {"name": tool_choice},
}
elif isinstance(tool_choice, dict):
tool_names = [
formatted_tool["function"]["name"]
for formatted_tool in formatted_tools
]
if not any(
tool_name == tool_choice["function"]["name"]
for tool_name in tool_names
):
raise ValueError(
f"Tool choice {tool_choice} was specified, but the only "
f"provided tools were {tool_names}."
)
else:
raise ValueError(
f"Unrecognized tool_choice type. Expected str, bool or dict. "
f"Received: {tool_choice}"
)
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs)
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "chat-databricks"
### Conversion function to convert Pydantic models to dictionaries and vice versa. ###
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict = {"content": message.content}
# OpenAI supports "name" field in messages.
if (name := message.name or message.additional_kwargs.get("name")) is not None:
message_dict["name"] = name
if id := message.id:
message_dict["id"] = id
if isinstance(message, ChatMessage):
return {"role": message.role, **message_dict}
elif isinstance(message, HumanMessage):
return {"role": "user", **message_dict}
elif isinstance(message, AIMessage):
if tool_calls := _get_tool_calls_from_ai_message(message):
message_dict["tool_calls"] = tool_calls # type: ignore[assignment]
# If tool calls present, content null value should be None not empty string.
message_dict["content"] = message_dict["content"] or None # type: ignore[assignment]
return {"role": "assistant", **message_dict}
elif isinstance(message, SystemMessage):
return {"role": "system", **message_dict}
elif isinstance(message, ToolMessage):
return {
"role": "tool",
"tool_call_id": message.tool_call_id,
**message_dict,
}
elif (
isinstance(message, FunctionMessage)
or "function_call" in message.additional_kwargs
):
raise ValueError(
"Function messages are not supported by Databricks. Please"
" create a feature request at https://github.com/mlflow/mlflow/issues."
)
else:
raise ValueError(f"Got unknown message type: {type(message)}")
def _get_tool_calls_from_ai_message(message: AIMessage) -> List[Dict]:
tool_calls = [
{
"type": "function",
"id": tc["id"],
"function": {
"name": tc["name"],
"arguments": json.dumps(tc["args"]),
},
}
for tc in message.tool_calls
]
invalid_tool_calls = [
{
"type": "function",
"id": tc["id"],
"function": {
"name": tc["name"],
"arguments": tc["args"],
},
}
for tc in message.invalid_tool_calls
]
if tool_calls or invalid_tool_calls:
return tool_calls + invalid_tool_calls
# Get tool calls from additional kwargs if present.
return [
{
k: v
for k, v in tool_call.items() # type: ignore[union-attr]
if k in {"id", "type", "function"}
}
for tool_call in message.additional_kwargs.get("tool_calls", [])
]
def _convert_dict_to_message(_dict: Dict) -> BaseMessage:
role = _dict["role"]
content = _dict.get("content")
content = content if content is not None else ""
if role == "user":
return HumanMessage(content=content)
elif role == "system":
return SystemMessage(content=content)
elif role == "assistant":
additional_kwargs: Dict = {}
tool_calls = []
invalid_tool_calls = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
)
return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
id=_dict.get("id"),
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
else:
return ChatMessage(content=content, role=role)
def _convert_dict_to_message_chunk(
_dict: Mapping[str, Any], default_role: str
) -> BaseMessageChunk:
role = _dict.get("role", default_role)
content = _dict.get("content")
content = content if content is not None else ""
if role == "user":
return HumanMessageChunk(content=content)
elif role == "system":
return SystemMessageChunk(content=content)
elif role == "tool":
return ToolMessageChunk(
content=content, tool_call_id=_dict["tool_call_id"], id=_dict.get("id")
)
elif role == "assistant":
additional_kwargs: Dict = {}
tool_call_chunks = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
try:
tool_call_chunks = [
tool_call_chunk(
name=tc["function"].get("name"),
args=tc["function"].get("arguments"),
id=tc.get("id"),
index=tc["index"],
)
for tc in raw_tool_calls
]
except KeyError:
pass
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
id=_dict.get("id"),
tool_call_chunks=tool_call_chunks,
)
else:
return ChatMessageChunk(content=content, role=role)

View File

@ -1,91 +0,0 @@
from typing import Any, Dict, Iterator, List
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
from langchain_databricks.utils import get_deployment_client
class DatabricksEmbeddings(Embeddings, BaseModel):
"""Databricks embedding model integration.
Setup:
Install ``langchain-databricks``.
.. code-block:: bash
pip install -U langchain-databricks
If you are outside Databricks, set the Databricks workspace
hostname and personal access token to environment variables:
.. code-block:: bash
export DATABRICKS_HOSTNAME="https://your-databricks-workspace"
export DATABRICKS_TOKEN="your-personal-access-token"
Key init args completion params:
endpoint: str
Name of Databricks Model Serving endpoint to query.
target_uri: str
The target URI to use. Defaults to ``databricks``.
query_params: Dict[str, str]
The parameters to use for queries.
documents_params: Dict[str, str]
The parameters to use for documents.
Instantiate:
.. code-block:: python
from langchain_databricks import DatabricksEmbeddings
embed = DatabricksEmbeddings(
endpoint="databricks-bge-large-en",
)
Embed single text:
.. code-block:: python
input_text = "The meaning of life is 42"
embed.embed_query(input_text)
.. code-block:: python
[
0.01605224609375,
-0.0298309326171875,
...
]
"""
endpoint: str
"""The endpoint to use."""
target_uri: str = "databricks"
"""The parameters to use for queries."""
query_params: Dict[str, Any] = {}
"""The parameters to use for documents."""
documents_params: Dict[str, Any] = {}
"""The target URI to use."""
_client: Any = PrivateAttr()
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._client = get_deployment_client(self.target_uri)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
return self._embed(texts, params=self.documents_params)
def embed_query(self, text: str) -> List[float]:
return self._embed([text], params=self.query_params)[0]
def _embed(self, texts: List[str], params: Dict[str, str]) -> List[List[float]]:
embeddings: List[List[float]] = []
for txt in _chunk(texts, 20):
resp = self._client.predict(
endpoint=self.endpoint,
inputs={"input": txt, **params}, # type: ignore[arg-type]
)
embeddings.extend(r["embedding"] for r in resp["data"])
return embeddings
def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
for i in range(0, len(texts), size):
yield texts[i : i + size]

View File

@ -1,101 +0,0 @@
from typing import Any, List, Union
from urllib.parse import urlparse
import numpy as np
def get_deployment_client(target_uri: str) -> Any:
if (target_uri != "databricks") and (urlparse(target_uri).scheme != "databricks"):
raise ValueError(
"Invalid target URI. The target URI must be a valid databricks URI."
)
try:
from mlflow.deployments import get_deploy_client # type: ignore[import-untyped]
return get_deploy_client(target_uri)
except ImportError as e:
raise ImportError(
"Failed to create the client. "
"Please run `pip install mlflow` to install "
"required dependencies."
) from e
# Utility function for Maximal Marginal Relevance (MMR) reranking.
# Copied from langchain_community/vectorstores/utils.py to avoid cross-dependency
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
def maximal_marginal_relevance(
query_embedding: np.ndarray,
embedding_list: list,
lambda_mult: float = 0.5,
k: int = 4,
) -> List[int]:
"""Calculate maximal marginal relevance.
Args:
query_embedding: Query embedding.
embedding_list: List of embeddings to select from.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
k: Number of Documents to return. Defaults to 4.
Returns:
List of indices of embeddings selected by maximal marginal relevance.
"""
if min(k, len(embedding_list)) <= 0:
return []
if query_embedding.ndim == 1:
query_embedding = np.expand_dims(query_embedding, axis=0)
similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
most_similar = int(np.argmax(similarity_to_query))
idxs = [most_similar]
selected = np.array([embedding_list[most_similar]])
while len(idxs) < min(k, len(embedding_list)):
best_score = -np.inf
idx_to_add = -1
similarity_to_selected = cosine_similarity(embedding_list, selected)
for i, query_score in enumerate(similarity_to_query):
if i in idxs:
continue
redundant_score = max(similarity_to_selected[i])
equation_score = (
lambda_mult * query_score - (1 - lambda_mult) * redundant_score
)
if equation_score > best_score:
best_score = equation_score
idx_to_add = i
idxs.append(idx_to_add)
selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
return idxs
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
"""Row-wise cosine similarity between two equal-width matrices.
Raises:
ValueError: If the number of columns in X and Y are not the same.
"""
if len(X) == 0 or len(Y) == 0:
return np.array([])
X = np.array(X)
Y = np.array(Y)
if X.shape[1] != Y.shape[1]:
raise ValueError(
"Number of columns in X and Y must be the same. X has shape"
f"{X.shape} "
f"and Y has shape {Y.shape}."
)
X_norm = np.linalg.norm(X, axis=1)
Y_norm = np.linalg.norm(Y, axis=1)
# Ignore divide by zero errors run time warnings as those are handled below.
with np.errstate(divide="ignore", invalid="ignore"):
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
return similarity

View File

@ -1,837 +0,0 @@
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from enum import Enum
from functools import partial
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
)
import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VST, VectorStore
from langchain_databricks.utils import maximal_marginal_relevance
logger = logging.getLogger(__name__)
class IndexType(str, Enum):
DIRECT_ACCESS = "DIRECT_ACCESS"
DELTA_SYNC = "DELTA_SYNC"
_DIRECT_ACCESS_ONLY_MSG = "`%s` is only supported for direct-access index."
_NON_MANAGED_EMB_ONLY_MSG = (
"`%s` is not supported for index with Databricks-managed embeddings."
)
class DatabricksVectorSearch(VectorStore):
"""Databricks vector store integration.
Setup:
Install ``langchain-databricks`` and ``databricks-vectorsearch`` python packages.
.. code-block:: bash
pip install -U langchain-databricks databricks-vectorsearch
If you don't have a Databricks Vector Search endpoint already, you can create one by following the instructions here: https://docs.databricks.com/en/generative-ai/create-query-vector-search.html
If you are outside Databricks, set the Databricks workspace
hostname and personal access token to environment variables:
.. code-block:: bash
export DATABRICKS_HOSTNAME="https://your-databricks-workspace"
export DATABRICKS_TOKEN="your-personal-access-token"
Key init args indexing params:
endpoint: The name of the Databricks Vector Search endpoint.
index_name: The name of the index to use. Format: "catalog.schema.index".
embedding: The embedding model.
Required for direct-access index or delta-sync index
with self-managed embeddings.
text_column: The name of the text column to use for the embeddings.
Required for direct-access index or delta-sync index
with self-managed embeddings.
Make sure the text column specified is in the index.
columns: The list of column names to get when doing the search.
Defaults to ``[primary_key, text_column]``.
Instantiate:
`DatabricksVectorSearch` supports two types of indexes:
* **Delta Sync Index** automatically syncs with a source Delta Table, automatically and incrementally updating the index as the underlying data in the Delta Table changes.
* **Direct Vector Access Index** supports direct read and write of vectors and metadata. The user is responsible for updating this table using the REST API or the Python SDK.
Also for delta-sync index, you can choose to use Databricks-managed embeddings or self-managed embeddings (via LangChain embeddings classes).
If you are using a delta-sync index with Databricks-managed embeddings:
.. code-block:: python
from langchain_databricks.vectorstores import DatabricksVectorSearch
vector_store = DatabricksVectorSearch(
endpoint="<your-endpoint-name>",
index_name="<your-index-name>"
)
If you are using a direct-access index or a delta-sync index with self-managed embeddings,
you also need to provide the embedding model and text column in your source table to
use for the embeddings:
.. code-block:: python
from langchain_openai import OpenAIEmbeddings
vector_store = DatabricksVectorSearch(
endpoint="<your-endpoint-name>",
index_name="<your-index-name>",
embedding=OpenAIEmbeddings(),
text_column="document_content"
)
Add Documents:
.. code-block:: python
from langchain_core.documents import Document
document_1 = Document(page_content="foo", metadata={"baz": "bar"})
document_2 = Document(page_content="thud", metadata={"bar": "baz"})
document_3 = Document(page_content="i will be deleted :(")
documents = [document_1, document_2, document_3]
ids = ["1", "2", "3"]
vector_store.add_documents(documents=documents, ids=ids)
Delete Documents:
.. code-block:: python
vector_store.delete(ids=["3"])
.. note::
The `delete` method is only supported for direct-access index.
Search:
.. code-block:: python
results = vector_store.similarity_search(query="thud",k=1)
for doc in results:
print(f"* {doc.page_content} [{doc.metadata}]")
.. code-block:: python
* thud [{'id': '2'}]
.. note:
By default, similarity search only returns the primary key and text column.
If you want to retrieve the custom metadata associated with the document,
pass the additional columns in the `columns` parameter when initializing the vector store.
.. code-block:: python
vector_store = DatabricksVectorSearch(
endpoint="<your-endpoint-name>",
index_name="<your-index-name>",
columns=["baz", "bar"],
)
vector_store.similarity_search(query="thud",k=1)
# Output: * thud [{'bar': 'baz', 'baz': None, 'id': '2'}]
Search with filter:
.. code-block:: python
results = vector_store.similarity_search(query="thud",k=1,filter={"bar": "baz"})
for doc in results:
print(f"* {doc.page_content} [{doc.metadata}]")
.. code-block:: python
* thud [{'id': '2'}]
Search with score:
.. code-block:: python
results = vector_store.similarity_search_with_score(query="qux",k=1)
for doc, score in results:
print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")
.. code-block:: python
* [SIM=0.748804] foo [{'id': '1'}]
Async:
.. code-block:: python
# add documents
await vector_store.aadd_documents(documents=documents, ids=ids)
# delete documents
await vector_store.adelete(ids=["3"])
# search
results = vector_store.asimilarity_search(query="thud",k=1)
# search with score
results = await vector_store.asimilarity_search_with_score(query="qux",k=1)
for doc,score in results:
print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")
.. code-block:: python
* [SIM=0.748807] foo [{'id': '1'}]
Use as Retriever:
.. code-block:: python
retriever = vector_store.as_retriever(
search_type="mmr",
search_kwargs={"k": 1, "fetch_k": 2, "lambda_mult": 0.5},
)
retriever.invoke("thud")
.. code-block:: python
[Document(metadata={'id': '2'}, page_content='thud')]
""" # noqa: E501
def __init__(
self,
endpoint: str,
index_name: str,
embedding: Optional[Embeddings] = None,
text_column: Optional[str] = None,
columns: Optional[List[str]] = None,
):
try:
from databricks.vector_search.client import ( # type: ignore[import]
VectorSearchClient,
)
except ImportError as e:
raise ImportError(
"Could not import databricks-vectorsearch python package. "
"Please install it with `pip install databricks-vectorsearch`."
) from e
self.index = VectorSearchClient().get_index(endpoint, index_name)
self._index_details = IndexDetails(self.index)
_validate_embedding(embedding, self._index_details)
self._embeddings = embedding
self._text_column = _validate_and_get_text_column(
text_column, self._index_details
)
self._columns = _validate_and_get_return_columns(
columns or [], self._text_column, self._index_details
)
self._primary_key = self._index_details.primary_key
@property
def embeddings(self) -> Optional[Embeddings]:
"""Access the query embedding object if available."""
return self._embeddings
@classmethod
def from_texts(
cls: Type[VST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[Dict]] = None,
**kwargs: Any,
) -> VST:
raise NotImplementedError(
"`from_texts` is not supported. "
"Use `add_texts` to add to existing direct-access index."
)
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[Dict]] = None,
ids: Optional[List[Any]] = None,
**kwargs: Any,
) -> List[str]:
"""Add texts to the index.
.. note::
This method is only supported for a direct-access index.
Args:
texts: List of texts to add.
metadatas: List of metadata for each text. Defaults to None.
ids: List of ids for each text. Defaults to None.
If not provided, a random uuid will be generated for each text.
Returns:
List of ids from adding the texts into the index.
"""
if self._index_details.is_delta_sync_index():
raise NotImplementedError(_DIRECT_ACCESS_ONLY_MSG % "add_texts")
# Wrap to list if input texts is a single string
if isinstance(texts, str):
texts = [texts]
texts = list(texts)
vectors = self._embeddings.embed_documents(texts) # type: ignore[union-attr]
ids = ids or [str(uuid.uuid4()) for _ in texts]
metadatas = metadatas or [{} for _ in texts]
updates = [
{
self._primary_key: id_,
self._text_column: text,
self._index_details.embedding_vector_column["name"]: vector,
**metadata,
}
for text, vector, id_, metadata in zip(texts, vectors, ids, metadatas)
]
upsert_resp = self.index.upsert(updates)
if upsert_resp.get("status") in ("PARTIAL_SUCCESS", "FAILURE"):
failed_ids = upsert_resp.get("result", dict()).get(
"failed_primary_keys", []
)
if upsert_resp.get("status") == "FAILURE":
logger.error("Failed to add texts to the index.")
else:
logger.warning("Some texts failed to be added to the index.")
return [id_ for id_ in ids if id_ not in failed_ids]
return ids
async def aadd_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
**kwargs: Any,
) -> List[str]:
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.add_texts, **kwargs), texts, metadatas
)
def delete(self, ids: Optional[List[Any]] = None, **kwargs: Any) -> Optional[bool]:
"""Delete documents from the index.
.. note::
This method is only supported for a direct-access index.
Args:
ids: List of ids of documents to delete.
Returns:
True if successful.
"""
if self._index_details.is_delta_sync_index():
raise NotImplementedError(_DIRECT_ACCESS_ONLY_MSG % "delete")
if ids is None:
raise ValueError("ids must be provided.")
self.index.delete(ids)
return True
def similarity_search(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
*,
query_type: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs most similar to query.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
List of Documents most similar to the embedding.
"""
docs_with_score = self.similarity_search_with_score(
query=query,
k=k,
filter=filter,
query_type=query_type,
**kwargs,
)
return [doc for doc, _ in docs_with_score]
async def asimilarity_search(
self, query: str, k: int = 4, **kwargs: Any
) -> List[Document]:
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(self.similarity_search, query, k=k, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, func)
def similarity_search_with_score(
self,
query: str,
k: int = 4,
filter: Optional[Dict[str, Any]] = None,
*,
query_type: Optional[str] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to query, along with scores.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
List of Documents most similar to the embedding and score for each.
"""
if self._index_details.is_databricks_managed_embeddings():
query_text = query
query_vector = None
else:
# The value for `query_text` needs to be specified only for hybrid search.
if query_type is not None and query_type.upper() == "HYBRID":
query_text = query
else:
query_text = None
query_vector = self._embeddings.embed_query(query) # type: ignore[union-attr]
search_resp = self.index.similarity_search(
columns=self._columns,
query_text=query_text,
query_vector=query_vector,
filters=filter,
num_results=k,
query_type=query_type,
)
return self._parse_search_response(search_resp)
def _select_relevance_score_fn(self) -> Callable[[float], float]:
"""
Databricks Vector search uses a normalized score 1/(1+d) where d
is the L2 distance. Hence, we simply return the identity function.
"""
return lambda score: score
async def asimilarity_search_with_score(
self, *args: Any, **kwargs: Any
) -> List[Tuple[Document, float]]:
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(self.similarity_search_with_score, *args, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, func)
def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Any] = None,
*,
query_type: Optional[str] = None,
query: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs most similar to embedding vector.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
List of Documents most similar to the embedding.
"""
if self._index_details.is_databricks_managed_embeddings():
raise NotImplementedError(
_NON_MANAGED_EMB_ONLY_MSG % "similarity_search_by_vector"
)
docs_with_score = self.similarity_search_by_vector_with_score(
embedding=embedding,
k=k,
filter=filter,
query_type=query_type,
query=query,
**kwargs,
)
return [doc for doc, _ in docs_with_score]
async def asimilarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
) -> List[Document]:
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs)
return await asyncio.get_event_loop().run_in_executor(None, func)
def similarity_search_by_vector_with_score(
self,
embedding: List[float],
k: int = 4,
filter: Optional[Any] = None,
*,
query_type: Optional[str] = None,
query: Optional[str] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs most similar to embedding vector, along with scores.
.. note::
This method is not supported for index with Databricks-managed embeddings.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
List of Documents most similar to the embedding and score for each.
"""
if self._index_details.is_databricks_managed_embeddings():
raise NotImplementedError(
_NON_MANAGED_EMB_ONLY_MSG % "similarity_search_by_vector_with_score"
)
if query_type is not None and query_type.upper() == "HYBRID":
if query is None:
raise ValueError(
"A value for `query` must be specified for hybrid search."
)
query_text = query
else:
if query is not None:
raise ValueError(
(
"Cannot specify both `embedding` and "
'`query` unless `query_type="HYBRID"'
)
)
query_text = None
search_resp = self.index.similarity_search(
columns=self._columns,
query_vector=embedding,
query_text=query_text,
filters=filter,
num_results=k,
query_type=query_type,
)
return self._parse_search_response(search_resp)
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Dict[str, Any]] = None,
*,
query_type: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
.. note::
This method is not supported for index with Databricks-managed embeddings.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
List of Documents selected by maximal marginal relevance.
"""
if self._index_details.is_databricks_managed_embeddings():
raise NotImplementedError(
_NON_MANAGED_EMB_ONLY_MSG % "max_marginal_relevance_search"
)
query_vector = self._embeddings.embed_query(query) # type: ignore[union-attr]
docs = self.max_marginal_relevance_search_by_vector(
query_vector,
k,
fetch_k,
lambda_mult=lambda_mult,
filter=filter,
query_type=query_type,
)
return docs
async def amax_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
# This is a temporary workaround to make the similarity search
# asynchronous. The proper solution is to make the similarity search
# asynchronous in the vector store implementations.
func = partial(
self.max_marginal_relevance_search,
query,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
**kwargs,
)
return await asyncio.get_event_loop().run_in_executor(None, func)
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
filter: Optional[Any] = None,
*,
query_type: Optional[str] = None,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
.. note::
This method is not supported for index with Databricks-managed embeddings.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
filter: Filters to apply to the query. Defaults to None.
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
Returns:
List of Documents selected by maximal marginal relevance.
"""
if self._index_details.is_databricks_managed_embeddings():
raise NotImplementedError(
_NON_MANAGED_EMB_ONLY_MSG % "max_marginal_relevance_search_by_vector"
)
embedding_column = self._index_details.embedding_vector_column["name"]
search_resp = self.index.similarity_search(
columns=list(set(self._columns + [embedding_column])),
query_text=None,
query_vector=embedding,
filters=filter,
num_results=fetch_k,
query_type=query_type,
)
embeddings_result_index = (
search_resp.get("manifest").get("columns").index({"name": embedding_column})
)
embeddings = [
doc[embeddings_result_index]
for doc in search_resp.get("result").get("data_array")
]
mmr_selected = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32),
embeddings,
k=k,
lambda_mult=lambda_mult,
)
ignore_cols: List = (
[embedding_column] if embedding_column not in self._columns else []
)
candidates = self._parse_search_response(search_resp, ignore_cols=ignore_cols)
selected_results = [r[0] for i, r in enumerate(candidates) if i in mmr_selected]
return selected_results
async def amax_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError
def _parse_search_response(
self, search_resp: Dict, ignore_cols: Optional[List[str]] = None
) -> List[Tuple[Document, float]]:
"""Parse the search response into a list of Documents with score."""
if ignore_cols is None:
ignore_cols = []
columns = [
col["name"]
for col in search_resp.get("manifest", dict()).get("columns", [])
]
docs_with_score = []
for result in search_resp.get("result", dict()).get("data_array", []):
doc_id = result[columns.index(self._primary_key)]
text_content = result[columns.index(self._text_column)]
ignore_cols = [self._primary_key, self._text_column] + ignore_cols
metadata = {
col: value
for col, value in zip(columns[:-1], result[:-1])
if col not in ignore_cols
}
metadata[self._primary_key] = doc_id
score = result[-1]
doc = Document(page_content=text_content, metadata=metadata)
docs_with_score.append((doc, score))
return docs_with_score
def _validate_and_get_text_column(
text_column: Optional[str], index_details: IndexDetails
) -> str:
if index_details.is_databricks_managed_embeddings():
index_source_column: str = index_details.embedding_source_column["name"]
# check if input text column matches the source column of the index
if text_column is not None:
raise ValueError(
f"The index '{index_details.name}' has the source column configured as "
f"'{index_source_column}'. Do not pass the `text_column` parameter."
)
return index_source_column
else:
if text_column is None:
raise ValueError("The `text_column` parameter is required for this index.")
return text_column
def _validate_and_get_return_columns(
columns: List[str], text_column: str, index_details: IndexDetails
) -> List[str]:
"""
Get a list of columns to retrieve from the index.
If the index is direct-access index, validate the given columns against the schema.
"""
# add primary key column and source column if not in columns
if index_details.primary_key not in columns:
columns.append(index_details.primary_key)
if text_column and text_column not in columns:
columns.append(text_column)
# Validate specified columns are in the index
if index_details.is_direct_access_index() and (
index_schema := index_details.schema
):
if missing_columns := [c for c in columns if c not in index_schema]:
raise ValueError(
"Some columns specified in `columns` are not "
f"in the index schema: {missing_columns}"
)
return columns
def _validate_embedding(
embedding: Optional[Embeddings], index_details: IndexDetails
) -> None:
if index_details.is_databricks_managed_embeddings():
if embedding is not None:
raise ValueError(
f"The index '{index_details.name}' uses Databricks-managed embeddings. "
"Do not pass the `embedding` parameter when initializing vector store."
)
else:
if not embedding:
raise ValueError(
"The `embedding` parameter is required for a direct-access index "
"or delta-sync index with self-managed embedding."
)
_validate_embedding_dimension(embedding, index_details)
def _validate_embedding_dimension(
embeddings: Embeddings, index_details: IndexDetails
) -> None:
"""validate if the embedding dimension matches with the index's configuration."""
if index_embedding_dimension := index_details.embedding_vector_column.get(
"embedding_dimension"
):
# Infer the embedding dimension from the embedding function."""
actual_dimension = len(embeddings.embed_query("test"))
if actual_dimension != index_embedding_dimension:
raise ValueError(
f"The specified embedding model's dimension '{actual_dimension}' does "
f"not match with the index configuration '{index_embedding_dimension}'."
)
class IndexDetails:
"""An utility class to store the configuration details of an index."""
def __init__(self, index: Any):
self._index_details = index.describe()
@property
def name(self) -> str:
return self._index_details["name"]
@property
def schema(self) -> Optional[Dict]:
if self.is_direct_access_index():
schema_json = self.index_spec.get("schema_json")
if schema_json is not None:
return json.loads(schema_json)
return None
@property
def primary_key(self) -> str:
return self._index_details["primary_key"]
@property
def index_spec(self) -> Dict:
return (
self._index_details.get("delta_sync_index_spec", {})
if self.is_delta_sync_index()
else self._index_details.get("direct_access_index_spec", {})
)
@property
def embedding_vector_column(self) -> Dict:
if vector_columns := self.index_spec.get("embedding_vector_columns"):
return vector_columns[0]
return {}
@property
def embedding_source_column(self) -> Dict:
if source_columns := self.index_spec.get("embedding_source_columns"):
return source_columns[0]
return {}
def is_delta_sync_index(self) -> bool:
return self._index_details["index_type"] == IndexType.DELTA_SYNC.value
def is_direct_access_index(self) -> bool:
return self._index_details["index_type"] == IndexType.DIRECT_ACCESS.value
def is_databricks_managed_embeddings(self) -> bool:
return (
self.is_delta_sync_index()
and self.embedding_source_column.get("name") is not None
)

File diff suppressed because it is too large Load Diff

View File

@ -1,100 +0,0 @@
[tool.poetry]
name = "langchain-databricks"
version = "0.1.0"
description = "An integration package connecting Databricks and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"
[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/databricks"
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22databricks%3D%3D0%22&expanded=true"
[tool.poetry.dependencies]
# TODO: Replace <3.12 to <4.0 once https://github.com/mlflow/mlflow/commit/04370119fcc1b2ccdbcd9a50198ab00566d58cd2 is released
python = ">=3.8.1,<3.12"
langchain-core = "^0.2.0"
mlflow = ">=2.9"
# MLflow depends on following libraries, which require different version for Python 3.8 vs 3.12
numpy = [
{version = ">=1.26.0", python = ">=3.12"},
{version = ">=1.24.0", python = "<3.12"},
]
scipy = [
{version = ">=1.11", python = ">=3.12"},
{version = "<2", python = "<3.12"}
]
databricks-vectorsearch = "^0.40"
[tool.poetry.group.test]
optional = true
[tool.poetry.group.test.dependencies]
pytest = "^7.4.3"
pytest-asyncio = "^0.23.2"
pytest-socket = "^0.7.0"
langchain-core = { path = "../../core", develop = true }
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.6"
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.test_integration.dependencies]
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.5"
[tool.poetry.group.typing.dependencies]
mypy = "^1.10"
langchain-core = { path = "../../core", develop = true }
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
langchain-core = { path = "../../core", develop = true }
[tool.ruff.lint]
select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
"T201", # print
]
[tool.mypy]
disallow_untyped_defs = "True"
[tool.coverage.run]
omit = ["tests/*"]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
addopts = "--strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"

View File

@ -1,17 +0,0 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
SourceFileLoader("x", file).load_module()
except Exception:
has_failure = True
print(file) # noqa: T201
traceback.print_exc()
print() # noqa: T201
sys.exit(1 if has_failure else 0)

View File

@ -1,27 +0,0 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository
# Check if a path argument is provided
if [ $# -ne 1 ]; then
echo "Usage: $0 /path/to/repository"
exit 1
fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1
fi

View File

@ -1,18 +0,0 @@
#!/bin/bash
set -eu
# Initialize a variable to keep track of errors
errors=0
# make sure not importing from langchain, langchain_experimental, or langchain_community
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1))
# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
exit 1
else
exit 0
fi

View File

@ -1,7 +0,0 @@
import pytest
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass

View File

@ -1,321 +0,0 @@
"""Test chat model integration."""
import json
from typing import Generator
from unittest import mock
import mlflow # type: ignore # noqa: F401
import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessageChunk,
)
from langchain_core.messages.tool import ToolCallChunk
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_databricks.chat_models import (
ChatDatabricks,
_convert_dict_to_message,
_convert_dict_to_message_chunk,
_convert_message_to_dict,
)
_MOCK_CHAT_RESPONSE = {
"id": "chatcmpl_id",
"object": "chat.completion",
"created": 1721875529,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "To calculate the result of 36939 multiplied by 8922.4, "
"I get:\n\n36939 x 8922.4 = 329,511,111.6",
},
"finish_reason": "stop",
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
}
_MOCK_STREAM_RESPONSE = [
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "36939"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 20, "total_tokens": 50},
},
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "x"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 22, "total_tokens": 52},
},
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "8922.4"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 24, "total_tokens": 54},
},
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": " = "},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 28, "total_tokens": 58},
},
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "329,511,111.6"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 30, "total_tokens": 60},
},
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": ""},
"finish_reason": "stop",
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
},
]
@pytest.fixture(autouse=True)
def mock_client() -> Generator:
client = mock.MagicMock()
client.predict.return_value = _MOCK_CHAT_RESPONSE
client.predict_stream.return_value = _MOCK_STREAM_RESPONSE
with mock.patch("mlflow.deployments.get_deploy_client", return_value=client):
yield
@pytest.fixture
def llm() -> ChatDatabricks:
return ChatDatabricks(
endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks"
)
def test_chat_mlflow_predict(llm: ChatDatabricks) -> None:
res = llm.invoke(
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "36939 * 8922.4"},
]
)
assert res.content == _MOCK_CHAT_RESPONSE["choices"][0]["message"]["content"] # type: ignore[index]
def test_chat_mlflow_stream(llm: ChatDatabricks) -> None:
res = llm.stream(
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "36939 * 8922.4"},
]
)
for chunk, expected in zip(res, _MOCK_STREAM_RESPONSE):
assert chunk.content == expected["choices"][0]["delta"]["content"] # type: ignore[index]
def test_chat_mlflow_bind_tools(llm: ChatDatabricks) -> None:
class GetWeather(BaseModel):
"""Get the current weather in a given location"""
location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
class GetPopulation(BaseModel):
"""Get the current population in a given location"""
location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
response = llm_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?"
)
assert isinstance(response, AIMessage)
### Test data conversion functions ###
@pytest.mark.parametrize(
("role", "expected_output"),
[
("user", HumanMessage("foo")),
("system", SystemMessage("foo")),
("assistant", AIMessage("foo")),
("any_role", ChatMessage(content="foo", role="any_role")),
],
)
def test_convert_message(role: str, expected_output: BaseMessage) -> None:
message = {"role": role, "content": "foo"}
result = _convert_dict_to_message(message)
assert result == expected_output
# convert back
dict_result = _convert_message_to_dict(result)
assert dict_result == message
def test_convert_message_with_tool_calls() -> None:
ID = "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12"
tool_calls = [
{
"id": ID,
"type": "function",
"function": {
"name": "main__test__python_exec",
"arguments": '{"code": "result = 36939 * 8922.4"}',
},
}
]
message_with_tools = {
"role": "assistant",
"content": None,
"tool_calls": tool_calls,
"id": ID,
}
result = _convert_dict_to_message(message_with_tools)
expected_output = AIMessage(
content="",
additional_kwargs={"tool_calls": tool_calls},
id=ID,
tool_calls=[
{
"name": tool_calls[0]["function"]["name"], # type: ignore[index]
"args": json.loads(tool_calls[0]["function"]["arguments"]), # type: ignore[index]
"id": ID,
"type": "tool_call",
}
],
)
assert result == expected_output
# convert back
dict_result = _convert_message_to_dict(result)
assert dict_result == message_with_tools
@pytest.mark.parametrize(
("role", "expected_output"),
[
("user", HumanMessageChunk(content="foo")),
("system", SystemMessageChunk(content="foo")),
("assistant", AIMessageChunk(content="foo")),
("any_role", ChatMessageChunk(content="foo", role="any_role")),
],
)
def test_convert_message_chunk(role: str, expected_output: BaseMessage) -> None:
delta = {"role": role, "content": "foo"}
result = _convert_dict_to_message_chunk(delta, "default_role")
assert result == expected_output
# convert back
dict_result = _convert_message_to_dict(result)
assert dict_result == delta
def test_convert_message_chunk_with_tool_calls() -> None:
delta_with_tools = {
"role": "assistant",
"content": None,
"tool_calls": [{"index": 0, "function": {"arguments": " }"}}],
}
result = _convert_dict_to_message_chunk(delta_with_tools, "role")
expected_output = AIMessageChunk(
content="",
additional_kwargs={"tool_calls": delta_with_tools["tool_calls"]},
id=None,
tool_call_chunks=[ToolCallChunk(name=None, args=" }", id=None, index=0)],
)
assert result == expected_output
def test_convert_tool_message_chunk() -> None:
delta = {
"role": "tool",
"content": "foo",
"tool_call_id": "tool_call_id",
"id": "some_id",
}
result = _convert_dict_to_message_chunk(delta, "default_role")
expected_output = ToolMessageChunk(
content="foo", id="some_id", tool_call_id="tool_call_id"
)
assert result == expected_output
# convert back
dict_result = _convert_message_to_dict(result)
assert dict_result == delta
def test_convert_message_to_dict_function() -> None:
with pytest.raises(ValueError, match="Function messages are not supported"):
_convert_message_to_dict(FunctionMessage(content="", name="name"))

View File

@ -1,69 +0,0 @@
"""Test Together AI embeddings."""
from typing import Any, Dict, Generator
from unittest import mock
import pytest
from mlflow.deployments import BaseDeploymentClient # type: ignore[import-untyped]
from langchain_databricks import DatabricksEmbeddings
def _mock_embeddings(endpoint: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
return {
"object": "list",
"data": [
{
"object": "embedding",
"embedding": list(range(1536)),
"index": 0,
}
for _ in inputs["input"]
],
"model": "text-embedding-3-small",
"usage": {"prompt_tokens": 8, "total_tokens": 8},
}
@pytest.fixture
def mock_client() -> Generator:
client = mock.MagicMock()
client.predict.side_effect = _mock_embeddings
with mock.patch("mlflow.deployments.get_deploy_client", return_value=client):
yield client
@pytest.fixture
def embeddings() -> DatabricksEmbeddings:
return DatabricksEmbeddings(
endpoint="text-embedding-3-small",
documents_params={"fruit": "apple"},
query_params={"fruit": "banana"},
)
def test_embed_documents(
mock_client: BaseDeploymentClient, embeddings: DatabricksEmbeddings
) -> None:
documents = ["foo"] * 30
output = embeddings.embed_documents(documents)
assert len(output) == 30
assert len(output[0]) == 1536
assert mock_client.predict.call_count == 2
assert all(
call_arg[1]["inputs"]["fruit"] == "apple"
for call_arg in mock_client().predict.call_args_list
)
def test_embed_query(
mock_client: BaseDeploymentClient, embeddings: DatabricksEmbeddings
) -> None:
query = "foo bar"
output = embeddings.embed_query(query)
assert len(output) == 1536
mock_client.predict.assert_called_once()
assert mock_client.predict.call_args[1] == {
"endpoint": "text-embedding-3-small",
"inputs": {"input": [query], "fruit": "banana"},
}

View File

@ -1,12 +0,0 @@
from langchain_databricks import __all__
EXPECTED_ALL = [
"ChatDatabricks",
"DatabricksEmbeddings",
"DatabricksVectorSearch",
"__version__",
]
def test_all_imports() -> None:
assert sorted(EXPECTED_ALL) == sorted(__all__)

View File

@ -1,629 +0,0 @@
import uuid
from typing import Any, Dict, Generator, List, Optional, Set
from unittest import mock
from unittest.mock import MagicMock, patch
import pytest
from langchain_core.embeddings import Embeddings
from langchain_databricks.vectorstores import DatabricksVectorSearch
INPUT_TEXTS = ["foo", "bar", "baz"]
DEFAULT_VECTOR_DIMENSION = 4
class FakeEmbeddings(Embeddings):
"""Fake embeddings functionality for testing."""
def __init__(self, dimension: int = DEFAULT_VECTOR_DIMENSION):
super().__init__()
self.dimension = dimension
def embed_documents(self, embedding_texts: List[str]) -> List[List[float]]:
"""Return simple embeddings."""
return [
[float(1.0)] * (self.dimension - 1) + [float(i)]
for i in range(len(embedding_texts))
]
def embed_query(self, text: str) -> List[float]:
"""Return simple embeddings."""
return [float(1.0)] * (self.dimension - 1) + [float(0.0)]
EMBEDDING_MODEL = FakeEmbeddings()
### Dummy similarity_search() Response ###
EXAMPLE_SEARCH_RESPONSE = {
"manifest": {
"column_count": 3,
"columns": [
{"name": "id"},
{"name": "text"},
{"name": "text_vector"},
{"name": "score"},
],
},
"result": {
"row_count": len(INPUT_TEXTS),
"data_array": sorted(
[
[str(uuid.uuid4()), s, e, 0.5]
for s, e in zip(
INPUT_TEXTS, EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
)
],
key=lambda x: x[2], # type: ignore
reverse=True,
),
},
"next_page_token": "",
}
### Dummy Indices ####
ENDPOINT_NAME = "test-endpoint"
DIRECT_ACCESS_INDEX = "test-direct-access-index"
DELTA_SYNC_INDEX = "test-delta-sync-index"
DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX = "test-delta-sync-self-managed-index"
ALL_INDEX_NAMES = {
DIRECT_ACCESS_INDEX,
DELTA_SYNC_INDEX,
DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX,
}
INDEX_DETAILS = {
DELTA_SYNC_INDEX: {
"name": DELTA_SYNC_INDEX,
"endpoint_name": ENDPOINT_NAME,
"index_type": "DELTA_SYNC",
"primary_key": "id",
"delta_sync_index_spec": {
"source_table": "ml.llm.source_table",
"pipeline_type": "CONTINUOUS",
"embedding_source_columns": [
{
"name": "text",
"embedding_model_endpoint_name": "openai-text-embedding",
}
],
},
},
DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX: {
"name": DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX,
"endpoint_name": ENDPOINT_NAME,
"index_type": "DELTA_SYNC",
"primary_key": "id",
"delta_sync_index_spec": {
"source_table": "ml.llm.source_table",
"pipeline_type": "CONTINUOUS",
"embedding_vector_columns": [
{
"name": "text_vector",
"embedding_dimension": DEFAULT_VECTOR_DIMENSION,
}
],
},
},
DIRECT_ACCESS_INDEX: {
"name": DIRECT_ACCESS_INDEX,
"endpoint_name": ENDPOINT_NAME,
"index_type": "DIRECT_ACCESS",
"primary_key": "id",
"direct_access_index_spec": {
"embedding_vector_columns": [
{
"name": "text_vector",
"embedding_dimension": DEFAULT_VECTOR_DIMENSION,
}
],
"schema_json": f"{{"
f'"{"id"}": "int", '
f'"feat1": "str", '
f'"feat2": "float", '
f'"text": "string", '
f'"{"text_vector"}": "array<float>"'
f"}}",
},
},
}
@pytest.fixture(autouse=True)
def mock_vs_client() -> Generator:
def _get_index(endpoint: str, index_name: str) -> MagicMock:
from databricks.vector_search.client import VectorSearchIndex # type: ignore
if endpoint != ENDPOINT_NAME:
raise ValueError(f"Unknown endpoint: {endpoint}")
index = MagicMock(spec=VectorSearchIndex)
index.describe.return_value = INDEX_DETAILS[index_name]
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
return index
mock_client = MagicMock()
mock_client.get_index.side_effect = _get_index
with mock.patch(
"databricks.vector_search.client.VectorSearchClient",
return_value=mock_client,
):
yield
def init_vector_search(
index_name: str, columns: Optional[List[str]] = None
) -> DatabricksVectorSearch:
kwargs: Dict[str, Any] = {
"endpoint": ENDPOINT_NAME,
"index_name": index_name,
"columns": columns,
}
if index_name != DELTA_SYNC_INDEX:
kwargs.update(
{
"embedding": EMBEDDING_MODEL,
"text_column": "text",
}
)
return DatabricksVectorSearch(**kwargs) # type: ignore[arg-type]
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
def test_init(index_name: str) -> None:
vectorsearch = init_vector_search(index_name)
assert vectorsearch.index.describe() == INDEX_DETAILS[index_name]
def test_init_fail_text_column_mismatch() -> None:
with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' has"):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=DELTA_SYNC_INDEX,
text_column="some_other_column",
)
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_init_fail_no_text_column(index_name: str) -> None:
with pytest.raises(ValueError, match="The `text_column` parameter is required"):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=index_name,
embedding=EMBEDDING_MODEL,
)
def test_init_fail_columns_not_in_schema() -> None:
columns = ["some_random_column"]
with pytest.raises(ValueError, match="Some columns specified in `columns`"):
init_vector_search(DIRECT_ACCESS_INDEX, columns=columns)
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_init_fail_no_embedding(index_name: str) -> None:
with pytest.raises(ValueError, match="The `embedding` parameter is required"):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=index_name,
text_column="text",
)
def test_init_fail_embedding_already_specified_in_source() -> None:
with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' uses"):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=DELTA_SYNC_INDEX,
embedding=EMBEDDING_MODEL,
)
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_init_fail_embedding_dim_mismatch(index_name: str) -> None:
with pytest.raises(
ValueError, match="embedding model's dimension '1000' does not match"
):
DatabricksVectorSearch(
endpoint=ENDPOINT_NAME,
index_name=index_name,
text_column="text",
embedding=FakeEmbeddings(1000),
)
def test_from_texts_not_supported() -> None:
with pytest.raises(NotImplementedError, match="`from_texts` is not supported"):
DatabricksVectorSearch.from_texts(INPUT_TEXTS, EMBEDDING_MODEL)
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DIRECT_ACCESS_INDEX})
def test_add_texts_not_supported_for_delta_sync_index(index_name: str) -> None:
vectorsearch = init_vector_search(index_name)
with pytest.raises(
NotImplementedError,
match="`add_texts` is only supported for direct-access index.",
):
vectorsearch.add_texts(INPUT_TEXTS)
def is_valid_uuid(val: str) -> bool:
try:
uuid.UUID(str(val))
return True
except ValueError:
return False
def test_add_texts() -> None:
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
ids = [idx for idx, i in enumerate(INPUT_TEXTS)]
vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
added_ids = vectorsearch.add_texts(INPUT_TEXTS, ids=ids)
vectorsearch.index.upsert.assert_called_once_with(
[
{
"id": id_,
"text": text,
"text_vector": vector,
}
for text, vector, id_ in zip(INPUT_TEXTS, vectors, ids)
]
)
assert len(added_ids) == len(INPUT_TEXTS)
assert added_ids == ids
def test_add_texts_handle_single_text() -> None:
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
added_ids = vectorsearch.add_texts(INPUT_TEXTS[0])
vectorsearch.index.upsert.assert_called_once_with(
[
{
"id": id_,
"text": text,
"text_vector": vector,
}
for text, vector, id_ in zip(INPUT_TEXTS, vectors, added_ids)
]
)
assert len(added_ids) == 1
assert is_valid_uuid(added_ids[0])
def test_add_texts_with_default_id() -> None:
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
added_ids = vectorsearch.add_texts(INPUT_TEXTS)
vectorsearch.index.upsert.assert_called_once_with(
[
{
"id": id_,
"text": text,
"text_vector": vector,
}
for text, vector, id_ in zip(INPUT_TEXTS, vectors, added_ids)
]
)
assert len(added_ids) == len(INPUT_TEXTS)
assert all([is_valid_uuid(id_) for id_ in added_ids])
def test_add_texts_with_metadata() -> None:
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
metadatas = [{"feat1": str(i), "feat2": i + 1000} for i in range(len(INPUT_TEXTS))]
added_ids = vectorsearch.add_texts(INPUT_TEXTS, metadatas=metadatas)
vectorsearch.index.upsert.assert_called_once_with(
[
{
"id": id_,
"text": text,
"text_vector": vector,
**metadata, # type: ignore[arg-type]
}
for text, vector, id_, metadata in zip(
INPUT_TEXTS, vectors, added_ids, metadatas
)
]
)
assert len(added_ids) == len(INPUT_TEXTS)
assert all([is_valid_uuid(id_) for id_ in added_ids])
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_embeddings_property(index_name: str) -> None:
vectorsearch = init_vector_search(index_name)
assert vectorsearch.embeddings == EMBEDDING_MODEL
def test_delete() -> None:
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
vectorsearch.delete(["some id"])
vectorsearch.index.delete.assert_called_once_with(["some id"])
def test_delete_fail_no_ids() -> None:
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
with pytest.raises(ValueError, match="ids must be provided."):
vectorsearch.delete()
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DIRECT_ACCESS_INDEX})
def test_delete_not_supported_for_delta_sync_index(index_name: str) -> None:
vectorsearch = init_vector_search(index_name)
with pytest.raises(
NotImplementedError, match="`delete` is only supported for direct-access"
):
vectorsearch.delete(["some id"])
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
@pytest.mark.parametrize("query_type", [None, "ANN"])
def test_similarity_search(index_name: str, query_type: Optional[str]) -> None:
vectorsearch = init_vector_search(index_name)
query = "foo"
filters = {"some filter": True}
limit = 7
search_result = vectorsearch.similarity_search(
query, k=limit, filter=filters, query_type=query_type
)
if index_name == DELTA_SYNC_INDEX:
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_text=query,
query_vector=None,
filters=filters,
num_results=limit,
query_type=query_type,
)
else:
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_text=None,
query_vector=EMBEDDING_MODEL.embed_query(query),
filters=filters,
num_results=limit,
query_type=query_type,
)
assert len(search_result) == len(INPUT_TEXTS)
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
assert all(["id" in d.metadata for d in search_result])
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
def test_similarity_search_hybrid(index_name: str) -> None:
vectorsearch = init_vector_search(index_name)
query = "foo"
filters = {"some filter": True}
limit = 7
search_result = vectorsearch.similarity_search(
query, k=limit, filter=filters, query_type="HYBRID"
)
if index_name == DELTA_SYNC_INDEX:
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_text=query,
query_vector=None,
filters=filters,
num_results=limit,
query_type="HYBRID",
)
else:
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_text=query,
query_vector=EMBEDDING_MODEL.embed_query(query),
filters=filters,
num_results=limit,
query_type="HYBRID",
)
assert len(search_result) == len(INPUT_TEXTS)
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
assert all(["id" in d.metadata for d in search_result])
def test_similarity_search_both_filter_and_filters_passed() -> None:
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
query = "foo"
filter = {"some filter": True}
filters = {"some other filter": False}
vectorsearch.similarity_search(query, filter=filter, filters=filters)
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_vector=EMBEDDING_MODEL.embed_query(query),
# `filter` should prevail over `filters`
filters=filter,
num_results=4,
query_text=None,
query_type=None,
)
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
@pytest.mark.parametrize(
"columns, expected_columns",
[
(None, {"id"}),
(["id", "text", "text_vector"], {"text_vector", "id"}),
],
)
def test_mmr_search(
index_name: str, columns: Optional[List[str]], expected_columns: Set[str]
) -> None:
vectorsearch = init_vector_search(index_name, columns=columns)
query = INPUT_TEXTS[0]
filters = {"some filter": True}
limit = 1
search_result = vectorsearch.max_marginal_relevance_search(
query, k=limit, filters=filters
)
assert [doc.page_content for doc in search_result] == [INPUT_TEXTS[0]]
assert [set(doc.metadata.keys()) for doc in search_result] == [expected_columns]
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_mmr_parameters(index_name: str) -> None:
vectorsearch = init_vector_search(index_name)
query = INPUT_TEXTS[0]
limit = 1
fetch_k = 3
lambda_mult = 0.25
filters = {"some filter": True}
with patch(
"langchain_databricks.vectorstores.maximal_marginal_relevance"
) as mock_mmr:
mock_mmr.return_value = [2]
retriever = vectorsearch.as_retriever(
search_type="mmr",
search_kwargs={
"k": limit,
"fetch_k": fetch_k,
"lambda_mult": lambda_mult,
"filter": filters,
},
)
search_result = retriever.invoke(query)
mock_mmr.assert_called_once()
assert mock_mmr.call_args[1]["lambda_mult"] == lambda_mult
assert vectorsearch.index.similarity_search.call_args[1]["num_results"] == fetch_k
assert vectorsearch.index.similarity_search.call_args[1]["filters"] == filters
assert len(search_result) == limit
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
@pytest.mark.parametrize("threshold", [0.4, 0.5, 0.8])
def test_similarity_score_threshold(index_name: str, threshold: float) -> None:
query = INPUT_TEXTS[0]
limit = len(INPUT_TEXTS)
vectorsearch = init_vector_search(index_name)
retriever = vectorsearch.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={"k": limit, "score_threshold": threshold},
)
search_result = retriever.invoke(query)
if threshold <= 0.5:
assert len(search_result) == len(INPUT_TEXTS)
else:
assert len(search_result) == 0
def test_standard_params() -> None:
vectorstore = init_vector_search(DIRECT_ACCESS_INDEX)
retriever = vectorstore.as_retriever()
ls_params = retriever._get_ls_params()
assert ls_params == {
"ls_retriever_name": "vectorstore",
"ls_vector_store_provider": "DatabricksVectorSearch",
"ls_embedding_provider": "FakeEmbeddings",
}
vectorstore = init_vector_search(DELTA_SYNC_INDEX)
retriever = vectorstore.as_retriever()
ls_params = retriever._get_ls_params()
assert ls_params == {
"ls_retriever_name": "vectorstore",
"ls_vector_store_provider": "DatabricksVectorSearch",
}
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
@pytest.mark.parametrize("query_type", [None, "ANN"])
def test_similarity_search_by_vector(
index_name: str, query_type: Optional[str]
) -> None:
vectorsearch = init_vector_search(index_name)
query_embedding = EMBEDDING_MODEL.embed_query("foo")
filters = {"some filter": True}
limit = 7
search_result = vectorsearch.similarity_search_by_vector(
query_embedding, k=limit, filter=filters, query_type=query_type
)
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_vector=query_embedding,
filters=filters,
num_results=limit,
query_type=query_type,
query_text=None,
)
assert len(search_result) == len(INPUT_TEXTS)
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
assert all(["id" in d.metadata for d in search_result])
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_similarity_search_by_vector_hybrid(index_name: str) -> None:
vectorsearch = init_vector_search(index_name)
query_embedding = EMBEDDING_MODEL.embed_query("foo")
filters = {"some filter": True}
limit = 7
search_result = vectorsearch.similarity_search_by_vector(
query_embedding, k=limit, filter=filters, query_type="HYBRID", query="foo"
)
vectorsearch.index.similarity_search.assert_called_once_with(
columns=["id", "text"],
query_vector=query_embedding,
filters=filters,
num_results=limit,
query_type="HYBRID",
query_text="foo",
)
assert len(search_result) == len(INPUT_TEXTS)
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
assert all(["id" in d.metadata for d in search_result])
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
def test_similarity_search_empty_result(index_name: str) -> None:
vectorsearch = init_vector_search(index_name)
vectorsearch.index.similarity_search.return_value = {
"manifest": {
"column_count": 3,
"columns": [
{"name": "id"},
{"name": "text"},
{"name": "score"},
],
},
"result": {
"row_count": 0,
"data_array": [],
},
"next_page_token": "",
}
search_result = vectorsearch.similarity_search("foo")
assert len(search_result) == 0
def test_similarity_search_by_vector_not_supported_for_managed_embedding() -> None:
vectorsearch = init_vector_search(DELTA_SYNC_INDEX)
query_embedding = EMBEDDING_MODEL.embed_query("foo")
filters = {"some filter": True}
limit = 7
with pytest.raises(
NotImplementedError, match="`similarity_search_by_vector` is not supported"
):
vectorsearch.similarity_search_by_vector(
query_embedding, k=limit, filters=filters
)