mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 13:27:36 +00:00
databricks: mv to partner repo (#25788)
This commit is contained in:
parent
2e5c379632
commit
1023fbc98a
1
libs/partners/databricks/.gitignore
vendored
1
libs/partners/databricks/.gitignore
vendored
@ -1 +0,0 @@
|
|||||||
__pycache__
|
|
@ -1,21 +0,0 @@
|
|||||||
MIT License
|
|
||||||
|
|
||||||
Copyright (c) 2024 LangChain, Inc.
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
|
||||||
in the Software without restriction, including without limitation the rights
|
|
||||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
copies of the Software, and to permit persons to whom the Software is
|
|
||||||
furnished to do so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
@ -1,62 +0,0 @@
|
|||||||
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
|
|
||||||
|
|
||||||
# Default target executed when no arguments are given to make.
|
|
||||||
all: help
|
|
||||||
|
|
||||||
# Define a variable for the test file path.
|
|
||||||
TEST_FILE ?= tests/unit_tests/
|
|
||||||
integration_test integration_tests: TEST_FILE = tests/integration_tests/
|
|
||||||
|
|
||||||
|
|
||||||
# unit tests are run with the --disable-socket flag to prevent network calls
|
|
||||||
test tests:
|
|
||||||
poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
|
|
||||||
|
|
||||||
# integration tests are run without the --disable-socket flag to allow network calls
|
|
||||||
integration_test integration_tests:
|
|
||||||
poetry run pytest $(TEST_FILE)
|
|
||||||
|
|
||||||
######################
|
|
||||||
# LINTING AND FORMATTING
|
|
||||||
######################
|
|
||||||
|
|
||||||
# Define a variable for Python and notebook files.
|
|
||||||
PYTHON_FILES=.
|
|
||||||
MYPY_CACHE=.mypy_cache
|
|
||||||
lint format: PYTHON_FILES=.
|
|
||||||
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/databricks --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
|
|
||||||
lint_package: PYTHON_FILES=langchain_databricks
|
|
||||||
lint_tests: PYTHON_FILES=tests
|
|
||||||
lint_tests: MYPY_CACHE=.mypy_cache_test
|
|
||||||
|
|
||||||
lint lint_diff lint_package lint_tests:
|
|
||||||
poetry run ruff check .
|
|
||||||
poetry run ruff format $(PYTHON_FILES) --diff
|
|
||||||
poetry run ruff check --select I $(PYTHON_FILES)
|
|
||||||
mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
|
|
||||||
|
|
||||||
format format_diff:
|
|
||||||
poetry run ruff format $(PYTHON_FILES)
|
|
||||||
poetry run ruff check --select I --fix $(PYTHON_FILES)
|
|
||||||
|
|
||||||
spell_check:
|
|
||||||
poetry run codespell --toml pyproject.toml
|
|
||||||
|
|
||||||
spell_fix:
|
|
||||||
poetry run codespell --toml pyproject.toml -w
|
|
||||||
|
|
||||||
check_imports: $(shell find langchain_databricks -name '*.py')
|
|
||||||
poetry run python ./scripts/check_imports.py $^
|
|
||||||
|
|
||||||
######################
|
|
||||||
# HELP
|
|
||||||
######################
|
|
||||||
|
|
||||||
help:
|
|
||||||
@echo '----'
|
|
||||||
@echo 'check_imports - check imports'
|
|
||||||
@echo 'format - run code formatters'
|
|
||||||
@echo 'lint - run linters'
|
|
||||||
@echo 'test - run unit tests'
|
|
||||||
@echo 'tests - run unit tests'
|
|
||||||
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
|
@ -1,24 +0,0 @@
|
|||||||
# langchain-databricks
|
|
||||||
|
|
||||||
This package contains the LangChain integration with Databricks
|
|
||||||
|
|
||||||
## Installation
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install -U langchain-databricks
|
|
||||||
```
|
|
||||||
|
|
||||||
And you should configure credentials by setting the following environment variables:
|
|
||||||
|
|
||||||
* TODO: fill this out
|
|
||||||
|
|
||||||
## Chat Models
|
|
||||||
|
|
||||||
`ChatDatabricks` class exposes chat models from Databricks.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from langchain_databricks import ChatDatabricks
|
|
||||||
|
|
||||||
llm = ChatDatabricks()
|
|
||||||
llm.invoke("Sing a ballad of LangChain.")
|
|
||||||
```
|
|
@ -1,19 +0,0 @@
|
|||||||
from importlib import metadata
|
|
||||||
|
|
||||||
from langchain_databricks.chat_models import ChatDatabricks
|
|
||||||
from langchain_databricks.embeddings import DatabricksEmbeddings
|
|
||||||
from langchain_databricks.vectorstores import DatabricksVectorSearch
|
|
||||||
|
|
||||||
try:
|
|
||||||
__version__ = metadata.version(__package__)
|
|
||||||
except metadata.PackageNotFoundError:
|
|
||||||
# Case where package metadata is not available.
|
|
||||||
__version__ = ""
|
|
||||||
del metadata # optional, avoids polluting the results of dir(__package__)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"ChatDatabricks",
|
|
||||||
"DatabricksEmbeddings",
|
|
||||||
"DatabricksVectorSearch",
|
|
||||||
"__version__",
|
|
||||||
]
|
|
@ -1,556 +0,0 @@
|
|||||||
"""Databricks chat models."""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
Iterator,
|
|
||||||
List,
|
|
||||||
Literal,
|
|
||||||
Mapping,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Type,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
|
||||||
from langchain_core.language_models import BaseChatModel
|
|
||||||
from langchain_core.language_models.base import LanguageModelInput
|
|
||||||
from langchain_core.messages import (
|
|
||||||
AIMessage,
|
|
||||||
AIMessageChunk,
|
|
||||||
BaseMessage,
|
|
||||||
BaseMessageChunk,
|
|
||||||
ChatMessage,
|
|
||||||
ChatMessageChunk,
|
|
||||||
FunctionMessage,
|
|
||||||
HumanMessage,
|
|
||||||
HumanMessageChunk,
|
|
||||||
SystemMessage,
|
|
||||||
SystemMessageChunk,
|
|
||||||
ToolMessage,
|
|
||||||
ToolMessageChunk,
|
|
||||||
)
|
|
||||||
from langchain_core.messages.tool import tool_call_chunk
|
|
||||||
from langchain_core.output_parsers.openai_tools import (
|
|
||||||
make_invalid_tool_call,
|
|
||||||
parse_tool_call,
|
|
||||||
)
|
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
|
||||||
from langchain_core.pydantic_v1 import (
|
|
||||||
BaseModel,
|
|
||||||
Field,
|
|
||||||
PrivateAttr,
|
|
||||||
)
|
|
||||||
from langchain_core.runnables import Runnable
|
|
||||||
from langchain_core.tools import BaseTool
|
|
||||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
|
||||||
|
|
||||||
from langchain_databricks.utils import get_deployment_client
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ChatDatabricks(BaseChatModel):
|
|
||||||
"""Databricks chat model integration.
|
|
||||||
|
|
||||||
Setup:
|
|
||||||
Install ``langchain-databricks``.
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
pip install -U langchain-databricks
|
|
||||||
|
|
||||||
If you are outside Databricks, set the Databricks workspace hostname and personal access token to environment variables:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
export DATABRICKS_HOSTNAME="https://your-databricks-workspace"
|
|
||||||
export DATABRICKS_TOKEN="your-personal-access-token"
|
|
||||||
|
|
||||||
Key init args — completion params:
|
|
||||||
endpoint: str
|
|
||||||
Name of Databricks Model Serving endpoint to query.
|
|
||||||
target_uri: str
|
|
||||||
The target URI to use. Defaults to ``databricks``.
|
|
||||||
temperature: float
|
|
||||||
Sampling temperature. Higher values make the model more creative.
|
|
||||||
n: Optional[int]
|
|
||||||
The number of completion choices to generate.
|
|
||||||
stop: Optional[List[str]]
|
|
||||||
List of strings to stop generation at.
|
|
||||||
max_tokens: Optional[int]
|
|
||||||
Max number of tokens to generate.
|
|
||||||
extra_params: Optional[Dict[str, Any]]
|
|
||||||
Any extra parameters to pass to the endpoint.
|
|
||||||
|
|
||||||
Instantiate:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from langchain_databricks import ChatDatabricks
|
|
||||||
llm = ChatDatabricks(
|
|
||||||
endpoint="databricks-meta-llama-3-1-405b-instruct",
|
|
||||||
temperature=0,
|
|
||||||
max_tokens=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
Invoke:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
("system", "You are a helpful translator. Translate the user sentence to French."),
|
|
||||||
("human", "I love programming."),
|
|
||||||
]
|
|
||||||
llm.invoke(messages)
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
AIMessage(
|
|
||||||
content="J'adore la programmation.",
|
|
||||||
response_metadata={
|
|
||||||
'prompt_tokens': 32,
|
|
||||||
'completion_tokens': 9,
|
|
||||||
'total_tokens': 41
|
|
||||||
},
|
|
||||||
id='run-64eebbdd-88a8-4a25-b508-21e9a5f146c5-0'
|
|
||||||
)
|
|
||||||
|
|
||||||
Stream:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
for chunk in llm.stream(messages):
|
|
||||||
print(chunk)
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
content='J' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
|
||||||
content="'" id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
|
||||||
content='ad' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
|
||||||
content='ore' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
|
||||||
content=' la' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
|
||||||
content=' programm' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
|
||||||
content='ation' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
|
||||||
content='.' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
|
||||||
content='' response_metadata={'finish_reason': 'stop'} id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
stream = llm.stream(messages)
|
|
||||||
full = next(stream)
|
|
||||||
for chunk in stream:
|
|
||||||
full += chunk
|
|
||||||
full
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
AIMessageChunk(
|
|
||||||
content="J'adore la programmation.",
|
|
||||||
response_metadata={
|
|
||||||
'finish_reason': 'stop'
|
|
||||||
},
|
|
||||||
id='run-4cef851f-6223-424f-ad26-4a54e5852aa5'
|
|
||||||
)
|
|
||||||
|
|
||||||
Async:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
await llm.ainvoke(messages)
|
|
||||||
|
|
||||||
# stream:
|
|
||||||
# async for chunk in llm.astream(messages)
|
|
||||||
|
|
||||||
# batch:
|
|
||||||
# await llm.abatch([messages])
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
AIMessage(
|
|
||||||
content="J'adore la programmation.",
|
|
||||||
response_metadata={
|
|
||||||
'prompt_tokens': 32,
|
|
||||||
'completion_tokens': 9,
|
|
||||||
'total_tokens': 41
|
|
||||||
},
|
|
||||||
id='run-e4bb043e-772b-4e1d-9f98-77ccc00c0271-0'
|
|
||||||
)
|
|
||||||
|
|
||||||
Tool calling:
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
||||||
|
|
||||||
class GetWeather(BaseModel):
|
|
||||||
'''Get the current weather in a given location'''
|
|
||||||
|
|
||||||
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
|
|
||||||
|
|
||||||
class GetPopulation(BaseModel):
|
|
||||||
'''Get the current population in a given location'''
|
|
||||||
|
|
||||||
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
|
|
||||||
|
|
||||||
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
|
|
||||||
ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?")
|
|
||||||
ai_msg.tool_calls
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
[
|
|
||||||
{
|
|
||||||
'name': 'GetWeather',
|
|
||||||
'args': {
|
|
||||||
'location': 'Los Angeles, CA'
|
|
||||||
},
|
|
||||||
'id': 'call_ea0a6004-8e64-4ae8-a192-a40e295bfa24',
|
|
||||||
'type': 'tool_call'
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
To use tool calls, your model endpoint must support ``tools`` parameter. See [Function calling on Databricks](https://python.langchain.com/v0.2/docs/integrations/chat/databricks/#function-calling-on-databricks) for more information.
|
|
||||||
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
endpoint: str
|
|
||||||
"""Name of Databricks Model Serving endpoint to query."""
|
|
||||||
target_uri: str = "databricks"
|
|
||||||
"""The target URI to use. Defaults to ``databricks``."""
|
|
||||||
temperature: float = 0.0
|
|
||||||
"""Sampling temperature. Higher values make the model more creative."""
|
|
||||||
n: int = 1
|
|
||||||
"""The number of completion choices to generate."""
|
|
||||||
stop: Optional[List[str]] = None
|
|
||||||
"""List of strings to stop generation at."""
|
|
||||||
max_tokens: Optional[int] = None
|
|
||||||
"""The maximum number of tokens to generate."""
|
|
||||||
extra_params: dict = Field(default_factory=dict)
|
|
||||||
"""Any extra parameters to pass to the endpoint."""
|
|
||||||
_client: Any = PrivateAttr()
|
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._client = get_deployment_client(self.target_uri)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _default_params(self) -> Dict[str, Any]:
|
|
||||||
params: Dict[str, Any] = {
|
|
||||||
"target_uri": self.target_uri,
|
|
||||||
"endpoint": self.endpoint,
|
|
||||||
"temperature": self.temperature,
|
|
||||||
"n": self.n,
|
|
||||||
"stop": self.stop,
|
|
||||||
"max_tokens": self.max_tokens,
|
|
||||||
"extra_params": self.extra_params,
|
|
||||||
}
|
|
||||||
return params
|
|
||||||
|
|
||||||
def _generate(
|
|
||||||
self,
|
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> ChatResult:
|
|
||||||
data = self._prepare_inputs(messages, stop, **kwargs)
|
|
||||||
resp = self._client.predict(endpoint=self.endpoint, inputs=data)
|
|
||||||
return self._convert_response_to_chat_result(resp)
|
|
||||||
|
|
||||||
def _prepare_inputs(
|
|
||||||
self,
|
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
data: Dict[str, Any] = {
|
|
||||||
"messages": [_convert_message_to_dict(msg) for msg in messages],
|
|
||||||
"temperature": self.temperature,
|
|
||||||
"n": self.n,
|
|
||||||
**self.extra_params,
|
|
||||||
**kwargs,
|
|
||||||
}
|
|
||||||
if stop := self.stop or stop:
|
|
||||||
data["stop"] = stop
|
|
||||||
if self.max_tokens is not None:
|
|
||||||
data["max_tokens"] = self.max_tokens
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
def _convert_response_to_chat_result(
|
|
||||||
self, response: Mapping[str, Any]
|
|
||||||
) -> ChatResult:
|
|
||||||
generations = [
|
|
||||||
ChatGeneration(
|
|
||||||
message=_convert_dict_to_message(choice["message"]),
|
|
||||||
generation_info=choice.get("usage", {}),
|
|
||||||
)
|
|
||||||
for choice in response["choices"]
|
|
||||||
]
|
|
||||||
usage = response.get("usage", {})
|
|
||||||
return ChatResult(generations=generations, llm_output=usage)
|
|
||||||
|
|
||||||
def _stream(
|
|
||||||
self,
|
|
||||||
messages: List[BaseMessage],
|
|
||||||
stop: Optional[List[str]] = None,
|
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Iterator[ChatGenerationChunk]:
|
|
||||||
data = self._prepare_inputs(messages, stop, **kwargs)
|
|
||||||
first_chunk_role = None
|
|
||||||
for chunk in self._client.predict_stream(endpoint=self.endpoint, inputs=data):
|
|
||||||
if chunk["choices"]:
|
|
||||||
choice = chunk["choices"][0]
|
|
||||||
|
|
||||||
chunk_delta = choice["delta"]
|
|
||||||
if first_chunk_role is None:
|
|
||||||
first_chunk_role = chunk_delta.get("role")
|
|
||||||
|
|
||||||
chunk_message = _convert_dict_to_message_chunk(
|
|
||||||
chunk_delta, first_chunk_role
|
|
||||||
)
|
|
||||||
|
|
||||||
generation_info = {}
|
|
||||||
if finish_reason := choice.get("finish_reason"):
|
|
||||||
generation_info["finish_reason"] = finish_reason
|
|
||||||
if logprobs := choice.get("logprobs"):
|
|
||||||
generation_info["logprobs"] = logprobs
|
|
||||||
|
|
||||||
chunk = ChatGenerationChunk(
|
|
||||||
message=chunk_message, generation_info=generation_info or None
|
|
||||||
)
|
|
||||||
|
|
||||||
if run_manager:
|
|
||||||
run_manager.on_llm_new_token(
|
|
||||||
chunk.text, chunk=chunk, logprobs=logprobs
|
|
||||||
)
|
|
||||||
|
|
||||||
yield chunk
|
|
||||||
else:
|
|
||||||
# Handle the case where choices are empty if needed
|
|
||||||
continue
|
|
||||||
|
|
||||||
def bind_tools(
|
|
||||||
self,
|
|
||||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
|
||||||
*,
|
|
||||||
tool_choice: Optional[
|
|
||||||
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
|
|
||||||
] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
|
||||||
"""Bind tool-like objects to this chat model.
|
|
||||||
|
|
||||||
Assumes model is compatible with OpenAI tool-calling API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tools: A list of tool definitions to bind to this chat model.
|
|
||||||
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
|
|
||||||
models, callables, and BaseTools will be automatically converted to
|
|
||||||
their schema dictionary representation.
|
|
||||||
tool_choice: Which tool to require the model to call.
|
|
||||||
Options are:
|
|
||||||
name of the tool (str): calls corresponding tool;
|
|
||||||
"auto": automatically selects a tool (including no tool);
|
|
||||||
"none": model does not generate any tool calls and instead must
|
|
||||||
generate a standard assistant message;
|
|
||||||
"required": the model picks the most relevant tool in tools and
|
|
||||||
must generate a tool call;
|
|
||||||
|
|
||||||
or a dict of the form:
|
|
||||||
{"type": "function", "function": {"name": <<tool_name>>}}.
|
|
||||||
**kwargs: Any additional parameters to pass to the
|
|
||||||
:class:`~langchain.runnable.Runnable` constructor.
|
|
||||||
"""
|
|
||||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
|
||||||
if tool_choice:
|
|
||||||
if isinstance(tool_choice, str):
|
|
||||||
# tool_choice is a tool/function name
|
|
||||||
if tool_choice not in ("auto", "none", "required"):
|
|
||||||
tool_choice = {
|
|
||||||
"type": "function",
|
|
||||||
"function": {"name": tool_choice},
|
|
||||||
}
|
|
||||||
elif isinstance(tool_choice, dict):
|
|
||||||
tool_names = [
|
|
||||||
formatted_tool["function"]["name"]
|
|
||||||
for formatted_tool in formatted_tools
|
|
||||||
]
|
|
||||||
if not any(
|
|
||||||
tool_name == tool_choice["function"]["name"]
|
|
||||||
for tool_name in tool_names
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
f"Tool choice {tool_choice} was specified, but the only "
|
|
||||||
f"provided tools were {tool_names}."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unrecognized tool_choice type. Expected str, bool or dict. "
|
|
||||||
f"Received: {tool_choice}"
|
|
||||||
)
|
|
||||||
kwargs["tool_choice"] = tool_choice
|
|
||||||
return super().bind(tools=formatted_tools, **kwargs)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _llm_type(self) -> str:
|
|
||||||
"""Return type of chat model."""
|
|
||||||
return "chat-databricks"
|
|
||||||
|
|
||||||
|
|
||||||
### Conversion function to convert Pydantic models to dictionaries and vice versa. ###
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|
||||||
message_dict = {"content": message.content}
|
|
||||||
|
|
||||||
# OpenAI supports "name" field in messages.
|
|
||||||
if (name := message.name or message.additional_kwargs.get("name")) is not None:
|
|
||||||
message_dict["name"] = name
|
|
||||||
|
|
||||||
if id := message.id:
|
|
||||||
message_dict["id"] = id
|
|
||||||
|
|
||||||
if isinstance(message, ChatMessage):
|
|
||||||
return {"role": message.role, **message_dict}
|
|
||||||
elif isinstance(message, HumanMessage):
|
|
||||||
return {"role": "user", **message_dict}
|
|
||||||
elif isinstance(message, AIMessage):
|
|
||||||
if tool_calls := _get_tool_calls_from_ai_message(message):
|
|
||||||
message_dict["tool_calls"] = tool_calls # type: ignore[assignment]
|
|
||||||
# If tool calls present, content null value should be None not empty string.
|
|
||||||
message_dict["content"] = message_dict["content"] or None # type: ignore[assignment]
|
|
||||||
return {"role": "assistant", **message_dict}
|
|
||||||
elif isinstance(message, SystemMessage):
|
|
||||||
return {"role": "system", **message_dict}
|
|
||||||
elif isinstance(message, ToolMessage):
|
|
||||||
return {
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": message.tool_call_id,
|
|
||||||
**message_dict,
|
|
||||||
}
|
|
||||||
elif (
|
|
||||||
isinstance(message, FunctionMessage)
|
|
||||||
or "function_call" in message.additional_kwargs
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"Function messages are not supported by Databricks. Please"
|
|
||||||
" create a feature request at https://github.com/mlflow/mlflow/issues."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Got unknown message type: {type(message)}")
|
|
||||||
|
|
||||||
|
|
||||||
def _get_tool_calls_from_ai_message(message: AIMessage) -> List[Dict]:
|
|
||||||
tool_calls = [
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"id": tc["id"],
|
|
||||||
"function": {
|
|
||||||
"name": tc["name"],
|
|
||||||
"arguments": json.dumps(tc["args"]),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for tc in message.tool_calls
|
|
||||||
]
|
|
||||||
|
|
||||||
invalid_tool_calls = [
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"id": tc["id"],
|
|
||||||
"function": {
|
|
||||||
"name": tc["name"],
|
|
||||||
"arguments": tc["args"],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for tc in message.invalid_tool_calls
|
|
||||||
]
|
|
||||||
|
|
||||||
if tool_calls or invalid_tool_calls:
|
|
||||||
return tool_calls + invalid_tool_calls
|
|
||||||
|
|
||||||
# Get tool calls from additional kwargs if present.
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
k: v
|
|
||||||
for k, v in tool_call.items() # type: ignore[union-attr]
|
|
||||||
if k in {"id", "type", "function"}
|
|
||||||
}
|
|
||||||
for tool_call in message.additional_kwargs.get("tool_calls", [])
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_dict_to_message(_dict: Dict) -> BaseMessage:
|
|
||||||
role = _dict["role"]
|
|
||||||
content = _dict.get("content")
|
|
||||||
content = content if content is not None else ""
|
|
||||||
|
|
||||||
if role == "user":
|
|
||||||
return HumanMessage(content=content)
|
|
||||||
elif role == "system":
|
|
||||||
return SystemMessage(content=content)
|
|
||||||
elif role == "assistant":
|
|
||||||
additional_kwargs: Dict = {}
|
|
||||||
tool_calls = []
|
|
||||||
invalid_tool_calls = []
|
|
||||||
if raw_tool_calls := _dict.get("tool_calls"):
|
|
||||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
|
||||||
for raw_tool_call in raw_tool_calls:
|
|
||||||
try:
|
|
||||||
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
|
|
||||||
except Exception as e:
|
|
||||||
invalid_tool_calls.append(
|
|
||||||
make_invalid_tool_call(raw_tool_call, str(e))
|
|
||||||
)
|
|
||||||
return AIMessage(
|
|
||||||
content=content,
|
|
||||||
additional_kwargs=additional_kwargs,
|
|
||||||
id=_dict.get("id"),
|
|
||||||
tool_calls=tool_calls,
|
|
||||||
invalid_tool_calls=invalid_tool_calls,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return ChatMessage(content=content, role=role)
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_dict_to_message_chunk(
|
|
||||||
_dict: Mapping[str, Any], default_role: str
|
|
||||||
) -> BaseMessageChunk:
|
|
||||||
role = _dict.get("role", default_role)
|
|
||||||
content = _dict.get("content")
|
|
||||||
content = content if content is not None else ""
|
|
||||||
|
|
||||||
if role == "user":
|
|
||||||
return HumanMessageChunk(content=content)
|
|
||||||
elif role == "system":
|
|
||||||
return SystemMessageChunk(content=content)
|
|
||||||
elif role == "tool":
|
|
||||||
return ToolMessageChunk(
|
|
||||||
content=content, tool_call_id=_dict["tool_call_id"], id=_dict.get("id")
|
|
||||||
)
|
|
||||||
elif role == "assistant":
|
|
||||||
additional_kwargs: Dict = {}
|
|
||||||
tool_call_chunks = []
|
|
||||||
if raw_tool_calls := _dict.get("tool_calls"):
|
|
||||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
|
||||||
try:
|
|
||||||
tool_call_chunks = [
|
|
||||||
tool_call_chunk(
|
|
||||||
name=tc["function"].get("name"),
|
|
||||||
args=tc["function"].get("arguments"),
|
|
||||||
id=tc.get("id"),
|
|
||||||
index=tc["index"],
|
|
||||||
)
|
|
||||||
for tc in raw_tool_calls
|
|
||||||
]
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
return AIMessageChunk(
|
|
||||||
content=content,
|
|
||||||
additional_kwargs=additional_kwargs,
|
|
||||||
id=_dict.get("id"),
|
|
||||||
tool_call_chunks=tool_call_chunks,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return ChatMessageChunk(content=content, role=role)
|
|
@ -1,91 +0,0 @@
|
|||||||
from typing import Any, Dict, Iterator, List
|
|
||||||
|
|
||||||
from langchain_core.embeddings import Embeddings
|
|
||||||
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
|
|
||||||
|
|
||||||
from langchain_databricks.utils import get_deployment_client
|
|
||||||
|
|
||||||
|
|
||||||
class DatabricksEmbeddings(Embeddings, BaseModel):
|
|
||||||
"""Databricks embedding model integration.
|
|
||||||
|
|
||||||
Setup:
|
|
||||||
Install ``langchain-databricks``.
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
pip install -U langchain-databricks
|
|
||||||
|
|
||||||
If you are outside Databricks, set the Databricks workspace
|
|
||||||
hostname and personal access token to environment variables:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
export DATABRICKS_HOSTNAME="https://your-databricks-workspace"
|
|
||||||
export DATABRICKS_TOKEN="your-personal-access-token"
|
|
||||||
|
|
||||||
Key init args — completion params:
|
|
||||||
endpoint: str
|
|
||||||
Name of Databricks Model Serving endpoint to query.
|
|
||||||
target_uri: str
|
|
||||||
The target URI to use. Defaults to ``databricks``.
|
|
||||||
query_params: Dict[str, str]
|
|
||||||
The parameters to use for queries.
|
|
||||||
documents_params: Dict[str, str]
|
|
||||||
The parameters to use for documents.
|
|
||||||
|
|
||||||
Instantiate:
|
|
||||||
.. code-block:: python
|
|
||||||
from langchain_databricks import DatabricksEmbeddings
|
|
||||||
embed = DatabricksEmbeddings(
|
|
||||||
endpoint="databricks-bge-large-en",
|
|
||||||
)
|
|
||||||
|
|
||||||
Embed single text:
|
|
||||||
.. code-block:: python
|
|
||||||
input_text = "The meaning of life is 42"
|
|
||||||
embed.embed_query(input_text)
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
[
|
|
||||||
0.01605224609375,
|
|
||||||
-0.0298309326171875,
|
|
||||||
...
|
|
||||||
]
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
endpoint: str
|
|
||||||
"""The endpoint to use."""
|
|
||||||
target_uri: str = "databricks"
|
|
||||||
"""The parameters to use for queries."""
|
|
||||||
query_params: Dict[str, Any] = {}
|
|
||||||
"""The parameters to use for documents."""
|
|
||||||
documents_params: Dict[str, Any] = {}
|
|
||||||
"""The target URI to use."""
|
|
||||||
_client: Any = PrivateAttr()
|
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
self._client = get_deployment_client(self.target_uri)
|
|
||||||
|
|
||||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
||||||
return self._embed(texts, params=self.documents_params)
|
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
|
||||||
return self._embed([text], params=self.query_params)[0]
|
|
||||||
|
|
||||||
def _embed(self, texts: List[str], params: Dict[str, str]) -> List[List[float]]:
|
|
||||||
embeddings: List[List[float]] = []
|
|
||||||
for txt in _chunk(texts, 20):
|
|
||||||
resp = self._client.predict(
|
|
||||||
endpoint=self.endpoint,
|
|
||||||
inputs={"input": txt, **params}, # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
embeddings.extend(r["embedding"] for r in resp["data"])
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
|
|
||||||
for i in range(0, len(texts), size):
|
|
||||||
yield texts[i : i + size]
|
|
@ -1,101 +0,0 @@
|
|||||||
from typing import Any, List, Union
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def get_deployment_client(target_uri: str) -> Any:
|
|
||||||
if (target_uri != "databricks") and (urlparse(target_uri).scheme != "databricks"):
|
|
||||||
raise ValueError(
|
|
||||||
"Invalid target URI. The target URI must be a valid databricks URI."
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
from mlflow.deployments import get_deploy_client # type: ignore[import-untyped]
|
|
||||||
|
|
||||||
return get_deploy_client(target_uri)
|
|
||||||
except ImportError as e:
|
|
||||||
raise ImportError(
|
|
||||||
"Failed to create the client. "
|
|
||||||
"Please run `pip install mlflow` to install "
|
|
||||||
"required dependencies."
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
# Utility function for Maximal Marginal Relevance (MMR) reranking.
|
|
||||||
# Copied from langchain_community/vectorstores/utils.py to avoid cross-dependency
|
|
||||||
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
|
|
||||||
|
|
||||||
|
|
||||||
def maximal_marginal_relevance(
|
|
||||||
query_embedding: np.ndarray,
|
|
||||||
embedding_list: list,
|
|
||||||
lambda_mult: float = 0.5,
|
|
||||||
k: int = 4,
|
|
||||||
) -> List[int]:
|
|
||||||
"""Calculate maximal marginal relevance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_embedding: Query embedding.
|
|
||||||
embedding_list: List of embeddings to select from.
|
|
||||||
lambda_mult: Number between 0 and 1 that determines the degree
|
|
||||||
of diversity among the results with 0 corresponding
|
|
||||||
to maximum diversity and 1 to minimum diversity.
|
|
||||||
Defaults to 0.5.
|
|
||||||
k: Number of Documents to return. Defaults to 4.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of indices of embeddings selected by maximal marginal relevance.
|
|
||||||
"""
|
|
||||||
if min(k, len(embedding_list)) <= 0:
|
|
||||||
return []
|
|
||||||
if query_embedding.ndim == 1:
|
|
||||||
query_embedding = np.expand_dims(query_embedding, axis=0)
|
|
||||||
similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
|
|
||||||
most_similar = int(np.argmax(similarity_to_query))
|
|
||||||
idxs = [most_similar]
|
|
||||||
selected = np.array([embedding_list[most_similar]])
|
|
||||||
while len(idxs) < min(k, len(embedding_list)):
|
|
||||||
best_score = -np.inf
|
|
||||||
idx_to_add = -1
|
|
||||||
similarity_to_selected = cosine_similarity(embedding_list, selected)
|
|
||||||
for i, query_score in enumerate(similarity_to_query):
|
|
||||||
if i in idxs:
|
|
||||||
continue
|
|
||||||
redundant_score = max(similarity_to_selected[i])
|
|
||||||
equation_score = (
|
|
||||||
lambda_mult * query_score - (1 - lambda_mult) * redundant_score
|
|
||||||
)
|
|
||||||
if equation_score > best_score:
|
|
||||||
best_score = equation_score
|
|
||||||
idx_to_add = i
|
|
||||||
idxs.append(idx_to_add)
|
|
||||||
selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
|
|
||||||
return idxs
|
|
||||||
|
|
||||||
|
|
||||||
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
|
|
||||||
"""Row-wise cosine similarity between two equal-width matrices.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the number of columns in X and Y are not the same.
|
|
||||||
"""
|
|
||||||
if len(X) == 0 or len(Y) == 0:
|
|
||||||
return np.array([])
|
|
||||||
|
|
||||||
X = np.array(X)
|
|
||||||
Y = np.array(Y)
|
|
||||||
if X.shape[1] != Y.shape[1]:
|
|
||||||
raise ValueError(
|
|
||||||
"Number of columns in X and Y must be the same. X has shape"
|
|
||||||
f"{X.shape} "
|
|
||||||
f"and Y has shape {Y.shape}."
|
|
||||||
)
|
|
||||||
|
|
||||||
X_norm = np.linalg.norm(X, axis=1)
|
|
||||||
Y_norm = np.linalg.norm(Y, axis=1)
|
|
||||||
# Ignore divide by zero errors run time warnings as those are handled below.
|
|
||||||
with np.errstate(divide="ignore", invalid="ignore"):
|
|
||||||
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
|
|
||||||
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
|
|
||||||
return similarity
|
|
@ -1,837 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from enum import Enum
|
|
||||||
from functools import partial
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
Callable,
|
|
||||||
Dict,
|
|
||||||
Iterable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from langchain_core.documents import Document
|
|
||||||
from langchain_core.embeddings import Embeddings
|
|
||||||
from langchain_core.vectorstores import VST, VectorStore
|
|
||||||
|
|
||||||
from langchain_databricks.utils import maximal_marginal_relevance
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class IndexType(str, Enum):
|
|
||||||
DIRECT_ACCESS = "DIRECT_ACCESS"
|
|
||||||
DELTA_SYNC = "DELTA_SYNC"
|
|
||||||
|
|
||||||
|
|
||||||
_DIRECT_ACCESS_ONLY_MSG = "`%s` is only supported for direct-access index."
|
|
||||||
_NON_MANAGED_EMB_ONLY_MSG = (
|
|
||||||
"`%s` is not supported for index with Databricks-managed embeddings."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DatabricksVectorSearch(VectorStore):
|
|
||||||
"""Databricks vector store integration.
|
|
||||||
|
|
||||||
Setup:
|
|
||||||
Install ``langchain-databricks`` and ``databricks-vectorsearch`` python packages.
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
pip install -U langchain-databricks databricks-vectorsearch
|
|
||||||
|
|
||||||
If you don't have a Databricks Vector Search endpoint already, you can create one by following the instructions here: https://docs.databricks.com/en/generative-ai/create-query-vector-search.html
|
|
||||||
|
|
||||||
If you are outside Databricks, set the Databricks workspace
|
|
||||||
hostname and personal access token to environment variables:
|
|
||||||
|
|
||||||
.. code-block:: bash
|
|
||||||
|
|
||||||
export DATABRICKS_HOSTNAME="https://your-databricks-workspace"
|
|
||||||
export DATABRICKS_TOKEN="your-personal-access-token"
|
|
||||||
|
|
||||||
Key init args — indexing params:
|
|
||||||
|
|
||||||
endpoint: The name of the Databricks Vector Search endpoint.
|
|
||||||
index_name: The name of the index to use. Format: "catalog.schema.index".
|
|
||||||
embedding: The embedding model.
|
|
||||||
Required for direct-access index or delta-sync index
|
|
||||||
with self-managed embeddings.
|
|
||||||
text_column: The name of the text column to use for the embeddings.
|
|
||||||
Required for direct-access index or delta-sync index
|
|
||||||
with self-managed embeddings.
|
|
||||||
Make sure the text column specified is in the index.
|
|
||||||
columns: The list of column names to get when doing the search.
|
|
||||||
Defaults to ``[primary_key, text_column]``.
|
|
||||||
|
|
||||||
Instantiate:
|
|
||||||
|
|
||||||
`DatabricksVectorSearch` supports two types of indexes:
|
|
||||||
|
|
||||||
* **Delta Sync Index** automatically syncs with a source Delta Table, automatically and incrementally updating the index as the underlying data in the Delta Table changes.
|
|
||||||
|
|
||||||
* **Direct Vector Access Index** supports direct read and write of vectors and metadata. The user is responsible for updating this table using the REST API or the Python SDK.
|
|
||||||
|
|
||||||
Also for delta-sync index, you can choose to use Databricks-managed embeddings or self-managed embeddings (via LangChain embeddings classes).
|
|
||||||
|
|
||||||
If you are using a delta-sync index with Databricks-managed embeddings:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from langchain_databricks.vectorstores import DatabricksVectorSearch
|
|
||||||
|
|
||||||
vector_store = DatabricksVectorSearch(
|
|
||||||
endpoint="<your-endpoint-name>",
|
|
||||||
index_name="<your-index-name>"
|
|
||||||
)
|
|
||||||
|
|
||||||
If you are using a direct-access index or a delta-sync index with self-managed embeddings,
|
|
||||||
you also need to provide the embedding model and text column in your source table to
|
|
||||||
use for the embeddings:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
from langchain_openai import OpenAIEmbeddings
|
|
||||||
|
|
||||||
vector_store = DatabricksVectorSearch(
|
|
||||||
endpoint="<your-endpoint-name>",
|
|
||||||
index_name="<your-index-name>",
|
|
||||||
embedding=OpenAIEmbeddings(),
|
|
||||||
text_column="document_content"
|
|
||||||
)
|
|
||||||
|
|
||||||
Add Documents:
|
|
||||||
.. code-block:: python
|
|
||||||
from langchain_core.documents import Document
|
|
||||||
|
|
||||||
document_1 = Document(page_content="foo", metadata={"baz": "bar"})
|
|
||||||
document_2 = Document(page_content="thud", metadata={"bar": "baz"})
|
|
||||||
document_3 = Document(page_content="i will be deleted :(")
|
|
||||||
documents = [document_1, document_2, document_3]
|
|
||||||
ids = ["1", "2", "3"]
|
|
||||||
vector_store.add_documents(documents=documents, ids=ids)
|
|
||||||
|
|
||||||
Delete Documents:
|
|
||||||
.. code-block:: python
|
|
||||||
vector_store.delete(ids=["3"])
|
|
||||||
|
|
||||||
.. note::
|
|
||||||
|
|
||||||
The `delete` method is only supported for direct-access index.
|
|
||||||
|
|
||||||
Search:
|
|
||||||
.. code-block:: python
|
|
||||||
results = vector_store.similarity_search(query="thud",k=1)
|
|
||||||
for doc in results:
|
|
||||||
print(f"* {doc.page_content} [{doc.metadata}]")
|
|
||||||
.. code-block:: python
|
|
||||||
* thud [{'id': '2'}]
|
|
||||||
|
|
||||||
.. note:
|
|
||||||
|
|
||||||
By default, similarity search only returns the primary key and text column.
|
|
||||||
If you want to retrieve the custom metadata associated with the document,
|
|
||||||
pass the additional columns in the `columns` parameter when initializing the vector store.
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
vector_store = DatabricksVectorSearch(
|
|
||||||
endpoint="<your-endpoint-name>",
|
|
||||||
index_name="<your-index-name>",
|
|
||||||
columns=["baz", "bar"],
|
|
||||||
)
|
|
||||||
|
|
||||||
vector_store.similarity_search(query="thud",k=1)
|
|
||||||
# Output: * thud [{'bar': 'baz', 'baz': None, 'id': '2'}]
|
|
||||||
|
|
||||||
Search with filter:
|
|
||||||
.. code-block:: python
|
|
||||||
results = vector_store.similarity_search(query="thud",k=1,filter={"bar": "baz"})
|
|
||||||
for doc in results:
|
|
||||||
print(f"* {doc.page_content} [{doc.metadata}]")
|
|
||||||
.. code-block:: python
|
|
||||||
* thud [{'id': '2'}]
|
|
||||||
|
|
||||||
Search with score:
|
|
||||||
.. code-block:: python
|
|
||||||
results = vector_store.similarity_search_with_score(query="qux",k=1)
|
|
||||||
for doc, score in results:
|
|
||||||
print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")
|
|
||||||
.. code-block:: python
|
|
||||||
* [SIM=0.748804] foo [{'id': '1'}]
|
|
||||||
|
|
||||||
Async:
|
|
||||||
.. code-block:: python
|
|
||||||
# add documents
|
|
||||||
await vector_store.aadd_documents(documents=documents, ids=ids)
|
|
||||||
# delete documents
|
|
||||||
await vector_store.adelete(ids=["3"])
|
|
||||||
# search
|
|
||||||
results = vector_store.asimilarity_search(query="thud",k=1)
|
|
||||||
# search with score
|
|
||||||
results = await vector_store.asimilarity_search_with_score(query="qux",k=1)
|
|
||||||
for doc,score in results:
|
|
||||||
print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")
|
|
||||||
.. code-block:: python
|
|
||||||
* [SIM=0.748807] foo [{'id': '1'}]
|
|
||||||
|
|
||||||
Use as Retriever:
|
|
||||||
.. code-block:: python
|
|
||||||
retriever = vector_store.as_retriever(
|
|
||||||
search_type="mmr",
|
|
||||||
search_kwargs={"k": 1, "fetch_k": 2, "lambda_mult": 0.5},
|
|
||||||
)
|
|
||||||
retriever.invoke("thud")
|
|
||||||
.. code-block:: python
|
|
||||||
[Document(metadata={'id': '2'}, page_content='thud')]
|
|
||||||
""" # noqa: E501
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
endpoint: str,
|
|
||||||
index_name: str,
|
|
||||||
embedding: Optional[Embeddings] = None,
|
|
||||||
text_column: Optional[str] = None,
|
|
||||||
columns: Optional[List[str]] = None,
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
from databricks.vector_search.client import ( # type: ignore[import]
|
|
||||||
VectorSearchClient,
|
|
||||||
)
|
|
||||||
except ImportError as e:
|
|
||||||
raise ImportError(
|
|
||||||
"Could not import databricks-vectorsearch python package. "
|
|
||||||
"Please install it with `pip install databricks-vectorsearch`."
|
|
||||||
) from e
|
|
||||||
|
|
||||||
self.index = VectorSearchClient().get_index(endpoint, index_name)
|
|
||||||
self._index_details = IndexDetails(self.index)
|
|
||||||
|
|
||||||
_validate_embedding(embedding, self._index_details)
|
|
||||||
self._embeddings = embedding
|
|
||||||
self._text_column = _validate_and_get_text_column(
|
|
||||||
text_column, self._index_details
|
|
||||||
)
|
|
||||||
self._columns = _validate_and_get_return_columns(
|
|
||||||
columns or [], self._text_column, self._index_details
|
|
||||||
)
|
|
||||||
self._primary_key = self._index_details.primary_key
|
|
||||||
|
|
||||||
@property
|
|
||||||
def embeddings(self) -> Optional[Embeddings]:
|
|
||||||
"""Access the query embedding object if available."""
|
|
||||||
return self._embeddings
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_texts(
|
|
||||||
cls: Type[VST],
|
|
||||||
texts: List[str],
|
|
||||||
embedding: Embeddings,
|
|
||||||
metadatas: Optional[List[Dict]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> VST:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"`from_texts` is not supported. "
|
|
||||||
"Use `add_texts` to add to existing direct-access index."
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_texts(
|
|
||||||
self,
|
|
||||||
texts: Iterable[str],
|
|
||||||
metadatas: Optional[List[Dict]] = None,
|
|
||||||
ids: Optional[List[Any]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[str]:
|
|
||||||
"""Add texts to the index.
|
|
||||||
|
|
||||||
.. note::
|
|
||||||
|
|
||||||
This method is only supported for a direct-access index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: List of texts to add.
|
|
||||||
metadatas: List of metadata for each text. Defaults to None.
|
|
||||||
ids: List of ids for each text. Defaults to None.
|
|
||||||
If not provided, a random uuid will be generated for each text.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of ids from adding the texts into the index.
|
|
||||||
"""
|
|
||||||
if self._index_details.is_delta_sync_index():
|
|
||||||
raise NotImplementedError(_DIRECT_ACCESS_ONLY_MSG % "add_texts")
|
|
||||||
|
|
||||||
# Wrap to list if input texts is a single string
|
|
||||||
if isinstance(texts, str):
|
|
||||||
texts = [texts]
|
|
||||||
texts = list(texts)
|
|
||||||
vectors = self._embeddings.embed_documents(texts) # type: ignore[union-attr]
|
|
||||||
ids = ids or [str(uuid.uuid4()) for _ in texts]
|
|
||||||
metadatas = metadatas or [{} for _ in texts]
|
|
||||||
|
|
||||||
updates = [
|
|
||||||
{
|
|
||||||
self._primary_key: id_,
|
|
||||||
self._text_column: text,
|
|
||||||
self._index_details.embedding_vector_column["name"]: vector,
|
|
||||||
**metadata,
|
|
||||||
}
|
|
||||||
for text, vector, id_, metadata in zip(texts, vectors, ids, metadatas)
|
|
||||||
]
|
|
||||||
|
|
||||||
upsert_resp = self.index.upsert(updates)
|
|
||||||
if upsert_resp.get("status") in ("PARTIAL_SUCCESS", "FAILURE"):
|
|
||||||
failed_ids = upsert_resp.get("result", dict()).get(
|
|
||||||
"failed_primary_keys", []
|
|
||||||
)
|
|
||||||
if upsert_resp.get("status") == "FAILURE":
|
|
||||||
logger.error("Failed to add texts to the index.")
|
|
||||||
else:
|
|
||||||
logger.warning("Some texts failed to be added to the index.")
|
|
||||||
return [id_ for id_ in ids if id_ not in failed_ids]
|
|
||||||
|
|
||||||
return ids
|
|
||||||
|
|
||||||
async def aadd_texts(
|
|
||||||
self,
|
|
||||||
texts: Iterable[str],
|
|
||||||
metadatas: Optional[List[dict]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[str]:
|
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
|
||||||
None, partial(self.add_texts, **kwargs), texts, metadatas
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete(self, ids: Optional[List[Any]] = None, **kwargs: Any) -> Optional[bool]:
|
|
||||||
"""Delete documents from the index.
|
|
||||||
|
|
||||||
.. note::
|
|
||||||
|
|
||||||
This method is only supported for a direct-access index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ids: List of ids of documents to delete.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if successful.
|
|
||||||
"""
|
|
||||||
if self._index_details.is_delta_sync_index():
|
|
||||||
raise NotImplementedError(_DIRECT_ACCESS_ONLY_MSG % "delete")
|
|
||||||
|
|
||||||
if ids is None:
|
|
||||||
raise ValueError("ids must be provided.")
|
|
||||||
self.index.delete(ids)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def similarity_search(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
k: int = 4,
|
|
||||||
filter: Optional[Dict[str, Any]] = None,
|
|
||||||
*,
|
|
||||||
query_type: Optional[str] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
|
||||||
"""Return docs most similar to query.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Text to look up documents similar to.
|
|
||||||
k: Number of Documents to return. Defaults to 4.
|
|
||||||
filter: Filters to apply to the query. Defaults to None.
|
|
||||||
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of Documents most similar to the embedding.
|
|
||||||
"""
|
|
||||||
docs_with_score = self.similarity_search_with_score(
|
|
||||||
query=query,
|
|
||||||
k=k,
|
|
||||||
filter=filter,
|
|
||||||
query_type=query_type,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
return [doc for doc, _ in docs_with_score]
|
|
||||||
|
|
||||||
async def asimilarity_search(
|
|
||||||
self, query: str, k: int = 4, **kwargs: Any
|
|
||||||
) -> List[Document]:
|
|
||||||
# This is a temporary workaround to make the similarity search
|
|
||||||
# asynchronous. The proper solution is to make the similarity search
|
|
||||||
# asynchronous in the vector store implementations.
|
|
||||||
func = partial(self.similarity_search, query, k=k, **kwargs)
|
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
|
||||||
|
|
||||||
def similarity_search_with_score(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
k: int = 4,
|
|
||||||
filter: Optional[Dict[str, Any]] = None,
|
|
||||||
*,
|
|
||||||
query_type: Optional[str] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Tuple[Document, float]]:
|
|
||||||
"""Return docs most similar to query, along with scores.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Text to look up documents similar to.
|
|
||||||
k: Number of Documents to return. Defaults to 4.
|
|
||||||
filter: Filters to apply to the query. Defaults to None.
|
|
||||||
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of Documents most similar to the embedding and score for each.
|
|
||||||
"""
|
|
||||||
if self._index_details.is_databricks_managed_embeddings():
|
|
||||||
query_text = query
|
|
||||||
query_vector = None
|
|
||||||
else:
|
|
||||||
# The value for `query_text` needs to be specified only for hybrid search.
|
|
||||||
if query_type is not None and query_type.upper() == "HYBRID":
|
|
||||||
query_text = query
|
|
||||||
else:
|
|
||||||
query_text = None
|
|
||||||
query_vector = self._embeddings.embed_query(query) # type: ignore[union-attr]
|
|
||||||
|
|
||||||
search_resp = self.index.similarity_search(
|
|
||||||
columns=self._columns,
|
|
||||||
query_text=query_text,
|
|
||||||
query_vector=query_vector,
|
|
||||||
filters=filter,
|
|
||||||
num_results=k,
|
|
||||||
query_type=query_type,
|
|
||||||
)
|
|
||||||
return self._parse_search_response(search_resp)
|
|
||||||
|
|
||||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
|
||||||
"""
|
|
||||||
Databricks Vector search uses a normalized score 1/(1+d) where d
|
|
||||||
is the L2 distance. Hence, we simply return the identity function.
|
|
||||||
"""
|
|
||||||
return lambda score: score
|
|
||||||
|
|
||||||
async def asimilarity_search_with_score(
|
|
||||||
self, *args: Any, **kwargs: Any
|
|
||||||
) -> List[Tuple[Document, float]]:
|
|
||||||
# This is a temporary workaround to make the similarity search
|
|
||||||
# asynchronous. The proper solution is to make the similarity search
|
|
||||||
# asynchronous in the vector store implementations.
|
|
||||||
func = partial(self.similarity_search_with_score, *args, **kwargs)
|
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
|
||||||
|
|
||||||
def similarity_search_by_vector(
|
|
||||||
self,
|
|
||||||
embedding: List[float],
|
|
||||||
k: int = 4,
|
|
||||||
filter: Optional[Any] = None,
|
|
||||||
*,
|
|
||||||
query_type: Optional[str] = None,
|
|
||||||
query: Optional[str] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
|
||||||
"""Return docs most similar to embedding vector.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
embedding: Embedding to look up documents similar to.
|
|
||||||
k: Number of Documents to return. Defaults to 4.
|
|
||||||
filter: Filters to apply to the query. Defaults to None.
|
|
||||||
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of Documents most similar to the embedding.
|
|
||||||
"""
|
|
||||||
if self._index_details.is_databricks_managed_embeddings():
|
|
||||||
raise NotImplementedError(
|
|
||||||
_NON_MANAGED_EMB_ONLY_MSG % "similarity_search_by_vector"
|
|
||||||
)
|
|
||||||
|
|
||||||
docs_with_score = self.similarity_search_by_vector_with_score(
|
|
||||||
embedding=embedding,
|
|
||||||
k=k,
|
|
||||||
filter=filter,
|
|
||||||
query_type=query_type,
|
|
||||||
query=query,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
return [doc for doc, _ in docs_with_score]
|
|
||||||
|
|
||||||
async def asimilarity_search_by_vector(
|
|
||||||
self, embedding: List[float], k: int = 4, **kwargs: Any
|
|
||||||
) -> List[Document]:
|
|
||||||
# This is a temporary workaround to make the similarity search
|
|
||||||
# asynchronous. The proper solution is to make the similarity search
|
|
||||||
# asynchronous in the vector store implementations.
|
|
||||||
func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs)
|
|
||||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
|
||||||
|
|
||||||
def similarity_search_by_vector_with_score(
|
|
||||||
self,
|
|
||||||
embedding: List[float],
|
|
||||||
k: int = 4,
|
|
||||||
filter: Optional[Any] = None,
|
|
||||||
*,
|
|
||||||
query_type: Optional[str] = None,
|
|
||||||
query: Optional[str] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Tuple[Document, float]]:
|
|
||||||
"""Return docs most similar to embedding vector, along with scores.
|
|
||||||
|
|
||||||
.. note::
|
|
||||||
|
|
||||||
This method is not supported for index with Databricks-managed embeddings.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
embedding: Embedding to look up documents similar to.
|
|
||||||
k: Number of Documents to return. Defaults to 4.
|
|
||||||
filter: Filters to apply to the query. Defaults to None.
|
|
||||||
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of Documents most similar to the embedding and score for each.
|
|
||||||
"""
|
|
||||||
if self._index_details.is_databricks_managed_embeddings():
|
|
||||||
raise NotImplementedError(
|
|
||||||
_NON_MANAGED_EMB_ONLY_MSG % "similarity_search_by_vector_with_score"
|
|
||||||
)
|
|
||||||
|
|
||||||
if query_type is not None and query_type.upper() == "HYBRID":
|
|
||||||
if query is None:
|
|
||||||
raise ValueError(
|
|
||||||
"A value for `query` must be specified for hybrid search."
|
|
||||||
)
|
|
||||||
query_text = query
|
|
||||||
else:
|
|
||||||
if query is not None:
|
|
||||||
raise ValueError(
|
|
||||||
(
|
|
||||||
"Cannot specify both `embedding` and "
|
|
||||||
'`query` unless `query_type="HYBRID"'
|
|
||||||
)
|
|
||||||
)
|
|
||||||
query_text = None
|
|
||||||
|
|
||||||
search_resp = self.index.similarity_search(
|
|
||||||
columns=self._columns,
|
|
||||||
query_vector=embedding,
|
|
||||||
query_text=query_text,
|
|
||||||
filters=filter,
|
|
||||||
num_results=k,
|
|
||||||
query_type=query_type,
|
|
||||||
)
|
|
||||||
return self._parse_search_response(search_resp)
|
|
||||||
|
|
||||||
def max_marginal_relevance_search(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
k: int = 4,
|
|
||||||
fetch_k: int = 20,
|
|
||||||
lambda_mult: float = 0.5,
|
|
||||||
filter: Optional[Dict[str, Any]] = None,
|
|
||||||
*,
|
|
||||||
query_type: Optional[str] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> List[Document]:
|
|
||||||
"""Return docs selected using the maximal marginal relevance.
|
|
||||||
|
|
||||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
|
||||||
among selected documents.
|
|
||||||
|
|
||||||
.. note::
|
|
||||||
|
|
||||||
This method is not supported for index with Databricks-managed embeddings.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Text to look up documents similar to.
|
|
||||||
k: Number of Documents to return. Defaults to 4.
|
|
||||||
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
|
||||||
lambda_mult: Number between 0 and 1 that determines the degree
|
|
||||||
of diversity among the results with 0 corresponding
|
|
||||||
to maximum diversity and 1 to minimum diversity.
|
|
||||||
Defaults to 0.5.
|
|
||||||
filter: Filters to apply to the query. Defaults to None.
|
|
||||||
query_type: The type of this query. Supported values are "ANN" and "HYBRID".
|
|
||||||
Returns:
|
|
||||||
List of Documents selected by maximal marginal relevance.
|
|
||||||
"""
|
|
||||||
if self._index_details.is_databricks_managed_embeddings():
|
|
||||||
raise NotImplementedError(
|
|
||||||
_NON_MANAGED_EMB_ONLY_MSG % "max_marginal_relevance_search"
|
|
||||||
)
|
|
||||||
|
|
||||||
query_vector = self._embeddings.embed_query(query) # type: ignore[union-attr]
|
|
||||||
docs = self.max_marginal_relevance_search_by_vector(
|
|
||||||
query_vector,
|
|
||||||
k,
|
|
||||||
fetch_k,
|
|
||||||
lambda_mult=lambda_mult,
|
|
||||||
filter=filter,
|
|
||||||
query_type=query_type,
|
|
||||||
)
|
|
||||||
return docs
|
|
||||||
|
|
||||||
    async def amax_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        # This is a temporary workaround to make the similarity search
        # asynchronous. The proper solution is to make the similarity search
        # asynchronous in the vector store implementations.
        func = partial(
            self.max_marginal_relevance_search,
            query,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            **kwargs,
        )
        return await asyncio.get_event_loop().run_in_executor(None, func)

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Any] = None,
        *,
        query_type: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        .. note::

            This method is not supported for index with Databricks-managed embeddings.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter: Filters to apply to the query. Defaults to None.
            query_type: The type of this query. Supported values are "ANN" and "HYBRID".
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        if self._index_details.is_databricks_managed_embeddings():
            raise NotImplementedError(
                _NON_MANAGED_EMB_ONLY_MSG % "max_marginal_relevance_search_by_vector"
            )

        embedding_column = self._index_details.embedding_vector_column["name"]
        search_resp = self.index.similarity_search(
            columns=list(set(self._columns + [embedding_column])),
            query_text=None,
            query_vector=embedding,
            filters=filter,
            num_results=fetch_k,
            query_type=query_type,
        )

        embeddings_result_index = (
            search_resp.get("manifest").get("columns").index({"name": embedding_column})
        )
        embeddings = [
            doc[embeddings_result_index]
            for doc in search_resp.get("result").get("data_array")
        ]

        mmr_selected = maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            embeddings,
            k=k,
            lambda_mult=lambda_mult,
        )

        ignore_cols: List = (
            [embedding_column] if embedding_column not in self._columns else []
        )
        candidates = self._parse_search_response(search_resp, ignore_cols=ignore_cols)
        selected_results = [r[0] for i, r in enumerate(candidates) if i in mmr_selected]
        return selected_results

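Editor's note: the `maximal_marginal_relevance` helper used above selects documents greedily; at each step it picks the candidate that maximizes roughly lambda_mult * sim(query, d) - (1 - lambda_mult) * max over already-selected s of sim(d, s), using cosine similarity. The following is a simplified sketch of that selection rule, not the library implementation:

    import numpy as np

    def mmr_sketch(query_vec, candidate_vecs, k=4, lambda_mult=0.5):
        # Greedy MMR selection over candidate embeddings (sketch only).
        def cos(a, b):
            return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

        selected = []
        while len(selected) < min(k, len(candidate_vecs)):
            best_i, best_score = None, float("-inf")
            for i, c in enumerate(candidate_vecs):
                if i in selected:
                    continue
                relevance = cos(query_vec, c)
                redundancy = max(
                    (cos(c, candidate_vecs[j]) for j in selected), default=0.0
                )
                score = lambda_mult * relevance - (1 - lambda_mult) * redundancy
                if score > best_score:
                    best_i, best_score = i, score
            selected.append(best_i)
        # Indices into candidate_vecs, analogous to `mmr_selected` above.
        return selected
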
    async def amax_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError

    def _parse_search_response(
        self, search_resp: Dict, ignore_cols: Optional[List[str]] = None
    ) -> List[Tuple[Document, float]]:
        """Parse the search response into a list of Documents with score."""
        if ignore_cols is None:
            ignore_cols = []

        columns = [
            col["name"]
            for col in search_resp.get("manifest", dict()).get("columns", [])
        ]
        docs_with_score = []
        for result in search_resp.get("result", dict()).get("data_array", []):
            doc_id = result[columns.index(self._primary_key)]
            text_content = result[columns.index(self._text_column)]
            ignore_cols = [self._primary_key, self._text_column] + ignore_cols
            metadata = {
                col: value
                for col, value in zip(columns[:-1], result[:-1])
                if col not in ignore_cols
            }
            metadata[self._primary_key] = doc_id
            score = result[-1]
            doc = Document(page_content=text_content, metadata=metadata)
            docs_with_score.append((doc, score))
        return docs_with_score


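Editor's note: a concrete illustration of the parsing above; the values are invented, and the shape mirrors the mocked response used in the unit tests later in this diff.

    # Hypothetical response shape (values invented for illustration):
    search_resp = {
        "manifest": {"columns": [{"name": "id"}, {"name": "text"}, {"name": "score"}]},
        "result": {"data_array": [["1", "foo", 0.87]]},
    }
    # _parse_search_response(search_resp) would yield roughly:
    # [(Document(page_content="foo", metadata={"id": "1"}), 0.87)]
    # The primary key and text column are dropped from the metadata dict,
    # and the trailing column is taken as the score.
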
def _validate_and_get_text_column(
    text_column: Optional[str], index_details: IndexDetails
) -> str:
    if index_details.is_databricks_managed_embeddings():
        index_source_column: str = index_details.embedding_source_column["name"]
        # check if input text column matches the source column of the index
        if text_column is not None:
            raise ValueError(
                f"The index '{index_details.name}' has the source column configured as "
                f"'{index_source_column}'. Do not pass the `text_column` parameter."
            )
        return index_source_column
    else:
        if text_column is None:
            raise ValueError("The `text_column` parameter is required for this index.")
        return text_column


def _validate_and_get_return_columns(
    columns: List[str], text_column: str, index_details: IndexDetails
) -> List[str]:
    """
    Get a list of columns to retrieve from the index.

    If the index is direct-access index, validate the given columns against the schema.
    """
    # add primary key column and source column if not in columns
    if index_details.primary_key not in columns:
        columns.append(index_details.primary_key)
    if text_column and text_column not in columns:
        columns.append(text_column)

    # Validate specified columns are in the index
    if index_details.is_direct_access_index() and (
        index_schema := index_details.schema
    ):
        if missing_columns := [c for c in columns if c not in index_schema]:
            raise ValueError(
                "Some columns specified in `columns` are not "
                f"in the index schema: {missing_columns}"
            )
    return columns


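Editor's note: a small worked example of the column expansion above, with illustrative values; `index_details` is assumed to describe a direct-access index whose primary key is "id" and whose schema contains "id", "text", and "feat1".

    # Hypothetical: the requested columns are expanded with the primary key and
    # text column, then checked against the index schema.
    cols = _validate_and_get_return_columns(["feat1"], "text", index_details)
    # cols == ["feat1", "id", "text"]; a column missing from schema_json would
    # raise the ValueError shown above.
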
def _validate_embedding(
    embedding: Optional[Embeddings], index_details: IndexDetails
) -> None:
    if index_details.is_databricks_managed_embeddings():
        if embedding is not None:
            raise ValueError(
                f"The index '{index_details.name}' uses Databricks-managed embeddings. "
                "Do not pass the `embedding` parameter when initializing vector store."
            )
    else:
        if not embedding:
            raise ValueError(
                "The `embedding` parameter is required for a direct-access index "
                "or delta-sync index with self-managed embedding."
            )
        _validate_embedding_dimension(embedding, index_details)


def _validate_embedding_dimension(
    embeddings: Embeddings, index_details: IndexDetails
) -> None:
    """validate if the embedding dimension matches with the index's configuration."""
    if index_embedding_dimension := index_details.embedding_vector_column.get(
        "embedding_dimension"
    ):
        # Infer the embedding dimension from the embedding function.
        actual_dimension = len(embeddings.embed_query("test"))
        if actual_dimension != index_embedding_dimension:
            raise ValueError(
                f"The specified embedding model's dimension '{actual_dimension}' does "
                f"not match with the index configuration '{index_embedding_dimension}'."
            )


class IndexDetails:
    """An utility class to store the configuration details of an index."""

    def __init__(self, index: Any):
        self._index_details = index.describe()

    @property
    def name(self) -> str:
        return self._index_details["name"]

    @property
    def schema(self) -> Optional[Dict]:
        if self.is_direct_access_index():
            schema_json = self.index_spec.get("schema_json")
            if schema_json is not None:
                return json.loads(schema_json)
        return None

    @property
    def primary_key(self) -> str:
        return self._index_details["primary_key"]

    @property
    def index_spec(self) -> Dict:
        return (
            self._index_details.get("delta_sync_index_spec", {})
            if self.is_delta_sync_index()
            else self._index_details.get("direct_access_index_spec", {})
        )

    @property
    def embedding_vector_column(self) -> Dict:
        if vector_columns := self.index_spec.get("embedding_vector_columns"):
            return vector_columns[0]
        return {}

    @property
    def embedding_source_column(self) -> Dict:
        if source_columns := self.index_spec.get("embedding_source_columns"):
            return source_columns[0]
        return {}

    def is_delta_sync_index(self) -> bool:
        return self._index_details["index_type"] == IndexType.DELTA_SYNC.value

    def is_direct_access_index(self) -> bool:
        return self._index_details["index_type"] == IndexType.DIRECT_ACCESS.value

    def is_databricks_managed_embeddings(self) -> bool:
        return (
            self.is_delta_sync_index()
            and self.embedding_source_column.get("name") is not None
        )
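Editor's note: a sketch of how `IndexDetails` classifies indexes from a `describe()` payload. The payload shapes mirror the fixtures in the unit tests below; the `_FakeIndex` wrapper and other values are illustrative assumptions.

    class _FakeIndex:
        """Hypothetical stand-in whose describe() returns a fixed payload."""

        def __init__(self, payload):
            self._payload = payload

        def describe(self):
            return self._payload

    managed = IndexDetails(
        _FakeIndex(
            {
                "name": "demo",
                "primary_key": "id",
                "index_type": "DELTA_SYNC",
                "delta_sync_index_spec": {
                    "embedding_source_columns": [{"name": "text"}]
                },
            }
        )
    )
    # managed.is_databricks_managed_embeddings() -> True: delta-sync plus a source column.

    direct = IndexDetails(
        _FakeIndex(
            {
                "name": "demo2",
                "primary_key": "id",
                "index_type": "DIRECT_ACCESS",
                "direct_access_index_spec": {
                    "embedding_vector_columns": [
                        {"name": "text_vector", "embedding_dimension": 4}
                    ]
                },
            }
        )
    )
    # direct.is_databricks_managed_embeddings() -> False; embedding_vector_column
    # exposes {"name": "text_vector", "embedding_dimension": 4}.
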
2525 libs/partners/databricks/poetry.lock generated
File diff suppressed because it is too large
@@ -1,100 +0,0 @@
[tool.poetry]
name = "langchain-databricks"
version = "0.1.0"
description = "An integration package connecting Databricks and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"

[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/databricks"
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22databricks%3D%3D0%22&expanded=true"

[tool.poetry.dependencies]
# TODO: Replace <3.12 to <4.0 once https://github.com/mlflow/mlflow/commit/04370119fcc1b2ccdbcd9a50198ab00566d58cd2 is released
python = ">=3.8.1,<3.12"
langchain-core = "^0.2.0"
mlflow = ">=2.9"

# MLflow depends on following libraries, which require different version for Python 3.8 vs 3.12
numpy = [
    {version = ">=1.26.0", python = ">=3.12"},
    {version = ">=1.24.0", python = "<3.12"},
]
scipy = [
    {version = ">=1.11", python = ">=3.12"},
    {version = "<2", python = "<3.12"}
]
databricks-vectorsearch = "^0.40"

[tool.poetry.group.test]
optional = true

[tool.poetry.group.test.dependencies]
pytest = "^7.4.3"
pytest-asyncio = "^0.23.2"
pytest-socket = "^0.7.0"
langchain-core = { path = "../../core", develop = true }

[tool.poetry.group.codespell]
optional = true

[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.6"

[tool.poetry.group.test_integration]
optional = true

[tool.poetry.group.test_integration.dependencies]

[tool.poetry.group.lint]
optional = true

[tool.poetry.group.lint.dependencies]
ruff = "^0.5"

[tool.poetry.group.typing.dependencies]
mypy = "^1.10"
langchain-core = { path = "../../core", develop = true }

[tool.poetry.group.dev]
optional = true

[tool.poetry.group.dev.dependencies]
langchain-core = { path = "../../core", develop = true }

[tool.ruff.lint]
select = [
    "E",  # pycodestyle
    "F",  # pyflakes
    "I",  # isort
    "T201",  # print
]

[tool.mypy]
disallow_untyped_defs = "True"

[tool.coverage.run]
omit = ["tests/*"]

[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config       any warnings encountered while parsing the `pytest`
#                       section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
addopts = "--strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
    "compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"
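Editor's note: for local development, the optional dependency groups declared above (test, codespell, test_integration, lint, dev) would typically be pulled in with Poetry's standard group syntax; the exact invocation below is an assumption about the usual workflow, not something stated in this diff.

    poetry install --with test,lint,dev
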
@@ -1,17 +0,0 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader

if __name__ == "__main__":
    files = sys.argv[1:]
    has_failure = False
    for file in files:
        try:
            SourceFileLoader("x", file).load_module()
        except Exception:
            has_failure = True
            print(file)  # noqa: T201
            traceback.print_exc()
            print()  # noqa: T201

    sys.exit(1 if has_failure else 0)
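Editor's note: this import-check script takes a list of Python files as arguments and exits non-zero if any of them fail to import. Assuming it lives at scripts/check_imports.py as in other partner packages (the path and file list below are illustrative), it might be run as:

    poetry run python scripts/check_imports.py langchain_databricks/*.py
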
@@ -1,27 +0,0 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository

# Check if a path argument is provided
if [ $# -ne 1 ]; then
  echo "Usage: $0 /path/to/repository"
  exit 1
fi

repository_path="$1"

# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')

# Check if any matching lines were found
if [ -n "$result" ]; then
  echo "ERROR: The following lines need to be updated:"
  echo "$result"
  echo "Please replace the code with an import from langchain_core.pydantic_v1."
  echo "For example, replace 'from pydantic import BaseModel'"
  echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
  exit 1
fi
@@ -1,18 +0,0 @@
#!/bin/bash

set -eu

# Initialize a variable to keep track of errors
errors=0

# make sure not importing from langchain, langchain_experimental, or langchain_community
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1))

# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
  exit 1
else
  exit 0
fi
@@ -1,7 +0,0 @@
import pytest


@pytest.mark.compile
def test_placeholder() -> None:
    """Used for compiling integration tests without running any real tests."""
    pass
@@ -1,321 +0,0 @@
"""Test chat model integration."""

import json
from typing import Generator
from unittest import mock

import mlflow  # type: ignore # noqa: F401
import pytest
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessage,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
    ToolMessageChunk,
)
from langchain_core.messages.tool import ToolCallChunk
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain_databricks.chat_models import (
    ChatDatabricks,
    _convert_dict_to_message,
    _convert_dict_to_message_chunk,
    _convert_message_to_dict,
)

_MOCK_CHAT_RESPONSE = {
    "id": "chatcmpl_id",
    "object": "chat.completion",
    "created": 1721875529,
    "model": "meta-llama-3.1-70b-instruct-072424",
    "choices": [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": "To calculate the result of 36939 multiplied by 8922.4, "
                "I get:\n\n36939 x 8922.4 = 329,511,111.6",
            },
            "finish_reason": "stop",
            "logprobs": None,
        }
    ],
    "usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
}

_MOCK_STREAM_RESPONSE = [
    {
        "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
        "object": "chat.completion.chunk",
        "created": 1721877054,
        "model": "meta-llama-3.1-70b-instruct-072424",
        "choices": [
            {
                "index": 0,
                "delta": {"role": "assistant", "content": "36939"},
                "finish_reason": None,
                "logprobs": None,
            }
        ],
        "usage": {"prompt_tokens": 30, "completion_tokens": 20, "total_tokens": 50},
    },
    {
        "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
        "object": "chat.completion.chunk",
        "created": 1721877054,
        "model": "meta-llama-3.1-70b-instruct-072424",
        "choices": [
            {
                "index": 0,
                "delta": {"role": "assistant", "content": "x"},
                "finish_reason": None,
                "logprobs": None,
            }
        ],
        "usage": {"prompt_tokens": 30, "completion_tokens": 22, "total_tokens": 52},
    },
    {
        "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
        "object": "chat.completion.chunk",
        "created": 1721877054,
        "model": "meta-llama-3.1-70b-instruct-072424",
        "choices": [
            {
                "index": 0,
                "delta": {"role": "assistant", "content": "8922.4"},
                "finish_reason": None,
                "logprobs": None,
            }
        ],
        "usage": {"prompt_tokens": 30, "completion_tokens": 24, "total_tokens": 54},
    },
    {
        "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
        "object": "chat.completion.chunk",
        "created": 1721877054,
        "model": "meta-llama-3.1-70b-instruct-072424",
        "choices": [
            {
                "index": 0,
                "delta": {"role": "assistant", "content": " = "},
                "finish_reason": None,
                "logprobs": None,
            }
        ],
        "usage": {"prompt_tokens": 30, "completion_tokens": 28, "total_tokens": 58},
    },
    {
        "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
        "object": "chat.completion.chunk",
        "created": 1721877054,
        "model": "meta-llama-3.1-70b-instruct-072424",
        "choices": [
            {
                "index": 0,
                "delta": {"role": "assistant", "content": "329,511,111.6"},
                "finish_reason": None,
                "logprobs": None,
            }
        ],
        "usage": {"prompt_tokens": 30, "completion_tokens": 30, "total_tokens": 60},
    },
    {
        "id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
        "object": "chat.completion.chunk",
        "created": 1721877054,
        "model": "meta-llama-3.1-70b-instruct-072424",
        "choices": [
            {
                "index": 0,
                "delta": {"role": "assistant", "content": ""},
                "finish_reason": "stop",
                "logprobs": None,
            }
        ],
        "usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
    },
]

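Editor's note: joining the delta contents of the mocked stream chunks above yields "36939x8922.4 = 329,511,111.6", one piece at a time; the streaming test below compares each chunk's content individually rather than the joined string. A quick sketch of the accumulation:

    streamed = "".join(
        chunk["choices"][0]["delta"]["content"] for chunk in _MOCK_STREAM_RESPONSE
    )
    # streamed == "36939x8922.4 = 329,511,111.6"
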
@pytest.fixture(autouse=True)
def mock_client() -> Generator:
    client = mock.MagicMock()
    client.predict.return_value = _MOCK_CHAT_RESPONSE
    client.predict_stream.return_value = _MOCK_STREAM_RESPONSE
    with mock.patch("mlflow.deployments.get_deploy_client", return_value=client):
        yield


@pytest.fixture
def llm() -> ChatDatabricks:
    return ChatDatabricks(
        endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks"
    )


def test_chat_mlflow_predict(llm: ChatDatabricks) -> None:
    res = llm.invoke(
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "36939 * 8922.4"},
        ]
    )
    assert res.content == _MOCK_CHAT_RESPONSE["choices"][0]["message"]["content"]  # type: ignore[index]


def test_chat_mlflow_stream(llm: ChatDatabricks) -> None:
    res = llm.stream(
        [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "36939 * 8922.4"},
        ]
    )
    for chunk, expected in zip(res, _MOCK_STREAM_RESPONSE):
        assert chunk.content == expected["choices"][0]["delta"]["content"]  # type: ignore[index]


def test_chat_mlflow_bind_tools(llm: ChatDatabricks) -> None:
    class GetWeather(BaseModel):
        """Get the current weather in a given location"""

        location: str = Field(
            ..., description="The city and state, e.g. San Francisco, CA"
        )

    class GetPopulation(BaseModel):
        """Get the current population in a given location"""

        location: str = Field(
            ..., description="The city and state, e.g. San Francisco, CA"
        )

    llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
    response = llm_with_tools.invoke(
        "Which city is hotter today and which is bigger: LA or NY?"
    )
    assert isinstance(response, AIMessage)


### Test data conversion functions ###


@pytest.mark.parametrize(
    ("role", "expected_output"),
    [
        ("user", HumanMessage("foo")),
        ("system", SystemMessage("foo")),
        ("assistant", AIMessage("foo")),
        ("any_role", ChatMessage(content="foo", role="any_role")),
    ],
)
def test_convert_message(role: str, expected_output: BaseMessage) -> None:
    message = {"role": role, "content": "foo"}
    result = _convert_dict_to_message(message)
    assert result == expected_output

    # convert back
    dict_result = _convert_message_to_dict(result)
    assert dict_result == message


def test_convert_message_with_tool_calls() -> None:
    ID = "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12"
    tool_calls = [
        {
            "id": ID,
            "type": "function",
            "function": {
                "name": "main__test__python_exec",
                "arguments": '{"code": "result = 36939 * 8922.4"}',
            },
        }
    ]
    message_with_tools = {
        "role": "assistant",
        "content": None,
        "tool_calls": tool_calls,
        "id": ID,
    }
    result = _convert_dict_to_message(message_with_tools)
    expected_output = AIMessage(
        content="",
        additional_kwargs={"tool_calls": tool_calls},
        id=ID,
        tool_calls=[
            {
                "name": tool_calls[0]["function"]["name"],  # type: ignore[index]
                "args": json.loads(tool_calls[0]["function"]["arguments"]),  # type: ignore[index]
                "id": ID,
                "type": "tool_call",
            }
        ],
    )
    assert result == expected_output

    # convert back
    dict_result = _convert_message_to_dict(result)
    assert dict_result == message_with_tools


@pytest.mark.parametrize(
    ("role", "expected_output"),
    [
        ("user", HumanMessageChunk(content="foo")),
        ("system", SystemMessageChunk(content="foo")),
        ("assistant", AIMessageChunk(content="foo")),
        ("any_role", ChatMessageChunk(content="foo", role="any_role")),
    ],
)
def test_convert_message_chunk(role: str, expected_output: BaseMessage) -> None:
    delta = {"role": role, "content": "foo"}
    result = _convert_dict_to_message_chunk(delta, "default_role")
    assert result == expected_output

    # convert back
    dict_result = _convert_message_to_dict(result)
    assert dict_result == delta


def test_convert_message_chunk_with_tool_calls() -> None:
    delta_with_tools = {
        "role": "assistant",
        "content": None,
        "tool_calls": [{"index": 0, "function": {"arguments": " }"}}],
    }
    result = _convert_dict_to_message_chunk(delta_with_tools, "role")
    expected_output = AIMessageChunk(
        content="",
        additional_kwargs={"tool_calls": delta_with_tools["tool_calls"]},
        id=None,
        tool_call_chunks=[ToolCallChunk(name=None, args=" }", id=None, index=0)],
    )
    assert result == expected_output


def test_convert_tool_message_chunk() -> None:
    delta = {
        "role": "tool",
        "content": "foo",
        "tool_call_id": "tool_call_id",
        "id": "some_id",
    }
    result = _convert_dict_to_message_chunk(delta, "default_role")
    expected_output = ToolMessageChunk(
        content="foo", id="some_id", tool_call_id="tool_call_id"
    )
    assert result == expected_output

    # convert back
    dict_result = _convert_message_to_dict(result)
    assert dict_result == delta


def test_convert_message_to_dict_function() -> None:
    with pytest.raises(ValueError, match="Function messages are not supported"):
        _convert_message_to_dict(FunctionMessage(content="", name="name"))
@@ -1,69 +0,0 @@
"""Test Databricks embeddings."""

from typing import Any, Dict, Generator
from unittest import mock

import pytest
from mlflow.deployments import BaseDeploymentClient  # type: ignore[import-untyped]

from langchain_databricks import DatabricksEmbeddings


def _mock_embeddings(endpoint: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
    return {
        "object": "list",
        "data": [
            {
                "object": "embedding",
                "embedding": list(range(1536)),
                "index": 0,
            }
            for _ in inputs["input"]
        ],
        "model": "text-embedding-3-small",
        "usage": {"prompt_tokens": 8, "total_tokens": 8},
    }


@pytest.fixture
def mock_client() -> Generator:
    client = mock.MagicMock()
    client.predict.side_effect = _mock_embeddings
    with mock.patch("mlflow.deployments.get_deploy_client", return_value=client):
        yield client


@pytest.fixture
def embeddings() -> DatabricksEmbeddings:
    return DatabricksEmbeddings(
        endpoint="text-embedding-3-small",
        documents_params={"fruit": "apple"},
        query_params={"fruit": "banana"},
    )


def test_embed_documents(
    mock_client: BaseDeploymentClient, embeddings: DatabricksEmbeddings
) -> None:
    documents = ["foo"] * 30
    output = embeddings.embed_documents(documents)
    assert len(output) == 30
    assert len(output[0]) == 1536
    assert mock_client.predict.call_count == 2
    assert all(
        call_arg[1]["inputs"]["fruit"] == "apple"
        for call_arg in mock_client().predict.call_args_list
    )


def test_embed_query(
    mock_client: BaseDeploymentClient, embeddings: DatabricksEmbeddings
) -> None:
    query = "foo bar"
    output = embeddings.embed_query(query)
    assert len(output) == 1536
    mock_client.predict.assert_called_once()
    assert mock_client.predict.call_args[1] == {
        "endpoint": "text-embedding-3-small",
        "inputs": {"input": [query], "fruit": "banana"},
    }
@@ -1,12 +0,0 @@
from langchain_databricks import __all__

EXPECTED_ALL = [
    "ChatDatabricks",
    "DatabricksEmbeddings",
    "DatabricksVectorSearch",
    "__version__",
]


def test_all_imports() -> None:
    assert sorted(EXPECTED_ALL) == sorted(__all__)
@@ -1,629 +0,0 @@
import uuid
from typing import Any, Dict, Generator, List, Optional, Set
from unittest import mock
from unittest.mock import MagicMock, patch

import pytest
from langchain_core.embeddings import Embeddings

from langchain_databricks.vectorstores import DatabricksVectorSearch

INPUT_TEXTS = ["foo", "bar", "baz"]
DEFAULT_VECTOR_DIMENSION = 4


class FakeEmbeddings(Embeddings):
    """Fake embeddings functionality for testing."""

    def __init__(self, dimension: int = DEFAULT_VECTOR_DIMENSION):
        super().__init__()
        self.dimension = dimension

    def embed_documents(self, embedding_texts: List[str]) -> List[List[float]]:
        """Return simple embeddings."""
        return [
            [float(1.0)] * (self.dimension - 1) + [float(i)]
            for i in range(len(embedding_texts))
        ]

    def embed_query(self, text: str) -> List[float]:
        """Return simple embeddings."""
        return [float(1.0)] * (self.dimension - 1) + [float(0.0)]


EMBEDDING_MODEL = FakeEmbeddings()

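Editor's note: the fake embeddings are deterministic and tiny, which keeps the expected search responses and upsert payloads in the tests below easy to predict. For example, with the default dimension of 4:

    # EMBEDDING_MODEL.embed_documents(["foo", "bar", "baz"])
    #   -> [[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 2.0]]
    # EMBEDDING_MODEL.embed_query("anything") -> [1.0, 1.0, 1.0, 0.0]
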
### Dummy similarity_search() Response ###
EXAMPLE_SEARCH_RESPONSE = {
    "manifest": {
        "column_count": 3,
        "columns": [
            {"name": "id"},
            {"name": "text"},
            {"name": "text_vector"},
            {"name": "score"},
        ],
    },
    "result": {
        "row_count": len(INPUT_TEXTS),
        "data_array": sorted(
            [
                [str(uuid.uuid4()), s, e, 0.5]
                for s, e in zip(
                    INPUT_TEXTS, EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
                )
            ],
            key=lambda x: x[2],  # type: ignore
            reverse=True,
        ),
    },
    "next_page_token": "",
}


### Dummy Indices ####

ENDPOINT_NAME = "test-endpoint"
DIRECT_ACCESS_INDEX = "test-direct-access-index"
DELTA_SYNC_INDEX = "test-delta-sync-index"
DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX = "test-delta-sync-self-managed-index"
ALL_INDEX_NAMES = {
    DIRECT_ACCESS_INDEX,
    DELTA_SYNC_INDEX,
    DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX,
}

INDEX_DETAILS = {
    DELTA_SYNC_INDEX: {
        "name": DELTA_SYNC_INDEX,
        "endpoint_name": ENDPOINT_NAME,
        "index_type": "DELTA_SYNC",
        "primary_key": "id",
        "delta_sync_index_spec": {
            "source_table": "ml.llm.source_table",
            "pipeline_type": "CONTINUOUS",
            "embedding_source_columns": [
                {
                    "name": "text",
                    "embedding_model_endpoint_name": "openai-text-embedding",
                }
            ],
        },
    },
    DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX: {
        "name": DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX,
        "endpoint_name": ENDPOINT_NAME,
        "index_type": "DELTA_SYNC",
        "primary_key": "id",
        "delta_sync_index_spec": {
            "source_table": "ml.llm.source_table",
            "pipeline_type": "CONTINUOUS",
            "embedding_vector_columns": [
                {
                    "name": "text_vector",
                    "embedding_dimension": DEFAULT_VECTOR_DIMENSION,
                }
            ],
        },
    },
    DIRECT_ACCESS_INDEX: {
        "name": DIRECT_ACCESS_INDEX,
        "endpoint_name": ENDPOINT_NAME,
        "index_type": "DIRECT_ACCESS",
        "primary_key": "id",
        "direct_access_index_spec": {
            "embedding_vector_columns": [
                {
                    "name": "text_vector",
                    "embedding_dimension": DEFAULT_VECTOR_DIMENSION,
                }
            ],
            "schema_json": f"{{"
            f'"{"id"}": "int", '
            f'"feat1": "str", '
            f'"feat2": "float", '
            f'"text": "string", '
            f'"{"text_vector"}": "array<float>"'
            f"}}",
        },
    },
}


@pytest.fixture(autouse=True)
def mock_vs_client() -> Generator:
    def _get_index(endpoint: str, index_name: str) -> MagicMock:
        from databricks.vector_search.client import VectorSearchIndex  # type: ignore

        if endpoint != ENDPOINT_NAME:
            raise ValueError(f"Unknown endpoint: {endpoint}")

        index = MagicMock(spec=VectorSearchIndex)
        index.describe.return_value = INDEX_DETAILS[index_name]
        index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
        return index

    mock_client = MagicMock()
    mock_client.get_index.side_effect = _get_index
    with mock.patch(
        "databricks.vector_search.client.VectorSearchClient",
        return_value=mock_client,
    ):
        yield


def init_vector_search(
    index_name: str, columns: Optional[List[str]] = None
) -> DatabricksVectorSearch:
    kwargs: Dict[str, Any] = {
        "endpoint": ENDPOINT_NAME,
        "index_name": index_name,
        "columns": columns,
    }
    if index_name != DELTA_SYNC_INDEX:
        kwargs.update(
            {
                "embedding": EMBEDDING_MODEL,
                "text_column": "text",
            }
        )
    return DatabricksVectorSearch(**kwargs)  # type: ignore[arg-type]


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
def test_init(index_name: str) -> None:
    vectorsearch = init_vector_search(index_name)
    assert vectorsearch.index.describe() == INDEX_DETAILS[index_name]


def test_init_fail_text_column_mismatch() -> None:
    with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' has"):
        DatabricksVectorSearch(
            endpoint=ENDPOINT_NAME,
            index_name=DELTA_SYNC_INDEX,
            text_column="some_other_column",
        )


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_init_fail_no_text_column(index_name: str) -> None:
    with pytest.raises(ValueError, match="The `text_column` parameter is required"):
        DatabricksVectorSearch(
            endpoint=ENDPOINT_NAME,
            index_name=index_name,
            embedding=EMBEDDING_MODEL,
        )


def test_init_fail_columns_not_in_schema() -> None:
    columns = ["some_random_column"]
    with pytest.raises(ValueError, match="Some columns specified in `columns`"):
        init_vector_search(DIRECT_ACCESS_INDEX, columns=columns)


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_init_fail_no_embedding(index_name: str) -> None:
    with pytest.raises(ValueError, match="The `embedding` parameter is required"):
        DatabricksVectorSearch(
            endpoint=ENDPOINT_NAME,
            index_name=index_name,
            text_column="text",
        )


def test_init_fail_embedding_already_specified_in_source() -> None:
    with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' uses"):
        DatabricksVectorSearch(
            endpoint=ENDPOINT_NAME,
            index_name=DELTA_SYNC_INDEX,
            embedding=EMBEDDING_MODEL,
        )


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_init_fail_embedding_dim_mismatch(index_name: str) -> None:
    with pytest.raises(
        ValueError, match="embedding model's dimension '1000' does not match"
    ):
        DatabricksVectorSearch(
            endpoint=ENDPOINT_NAME,
            index_name=index_name,
            text_column="text",
            embedding=FakeEmbeddings(1000),
        )


def test_from_texts_not_supported() -> None:
    with pytest.raises(NotImplementedError, match="`from_texts` is not supported"):
        DatabricksVectorSearch.from_texts(INPUT_TEXTS, EMBEDDING_MODEL)


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DIRECT_ACCESS_INDEX})
def test_add_texts_not_supported_for_delta_sync_index(index_name: str) -> None:
    vectorsearch = init_vector_search(index_name)
    with pytest.raises(
        NotImplementedError,
        match="`add_texts` is only supported for direct-access index.",
    ):
        vectorsearch.add_texts(INPUT_TEXTS)


def is_valid_uuid(val: str) -> bool:
    try:
        uuid.UUID(str(val))
        return True
    except ValueError:
        return False


def test_add_texts() -> None:
    vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
    ids = [idx for idx, i in enumerate(INPUT_TEXTS)]
    vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)

    added_ids = vectorsearch.add_texts(INPUT_TEXTS, ids=ids)
    vectorsearch.index.upsert.assert_called_once_with(
        [
            {
                "id": id_,
                "text": text,
                "text_vector": vector,
            }
            for text, vector, id_ in zip(INPUT_TEXTS, vectors, ids)
        ]
    )
    assert len(added_ids) == len(INPUT_TEXTS)
    assert added_ids == ids


def test_add_texts_handle_single_text() -> None:
    vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
    vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)

    added_ids = vectorsearch.add_texts(INPUT_TEXTS[0])
    vectorsearch.index.upsert.assert_called_once_with(
        [
            {
                "id": id_,
                "text": text,
                "text_vector": vector,
            }
            for text, vector, id_ in zip(INPUT_TEXTS, vectors, added_ids)
        ]
    )
    assert len(added_ids) == 1
    assert is_valid_uuid(added_ids[0])


def test_add_texts_with_default_id() -> None:
    vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
    vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)

    added_ids = vectorsearch.add_texts(INPUT_TEXTS)
    vectorsearch.index.upsert.assert_called_once_with(
        [
            {
                "id": id_,
                "text": text,
                "text_vector": vector,
            }
            for text, vector, id_ in zip(INPUT_TEXTS, vectors, added_ids)
        ]
    )
    assert len(added_ids) == len(INPUT_TEXTS)
    assert all([is_valid_uuid(id_) for id_ in added_ids])


def test_add_texts_with_metadata() -> None:
    vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
    vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
    metadatas = [{"feat1": str(i), "feat2": i + 1000} for i in range(len(INPUT_TEXTS))]

    added_ids = vectorsearch.add_texts(INPUT_TEXTS, metadatas=metadatas)
    vectorsearch.index.upsert.assert_called_once_with(
        [
            {
                "id": id_,
                "text": text,
                "text_vector": vector,
                **metadata,  # type: ignore[arg-type]
            }
            for text, vector, id_, metadata in zip(
                INPUT_TEXTS, vectors, added_ids, metadatas
            )
        ]
    )
    assert len(added_ids) == len(INPUT_TEXTS)
    assert all([is_valid_uuid(id_) for id_ in added_ids])


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_embeddings_property(index_name: str) -> None:
    vectorsearch = init_vector_search(index_name)
    assert vectorsearch.embeddings == EMBEDDING_MODEL


def test_delete() -> None:
    vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
    vectorsearch.delete(["some id"])
    vectorsearch.index.delete.assert_called_once_with(["some id"])


def test_delete_fail_no_ids() -> None:
    vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
    with pytest.raises(ValueError, match="ids must be provided."):
        vectorsearch.delete()


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DIRECT_ACCESS_INDEX})
def test_delete_not_supported_for_delta_sync_index(index_name: str) -> None:
    vectorsearch = init_vector_search(index_name)
    with pytest.raises(
        NotImplementedError, match="`delete` is only supported for direct-access"
    ):
        vectorsearch.delete(["some id"])


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
@pytest.mark.parametrize("query_type", [None, "ANN"])
def test_similarity_search(index_name: str, query_type: Optional[str]) -> None:
    vectorsearch = init_vector_search(index_name)
    query = "foo"
    filters = {"some filter": True}
    limit = 7

    search_result = vectorsearch.similarity_search(
        query, k=limit, filter=filters, query_type=query_type
    )
    if index_name == DELTA_SYNC_INDEX:
        vectorsearch.index.similarity_search.assert_called_once_with(
            columns=["id", "text"],
            query_text=query,
            query_vector=None,
            filters=filters,
            num_results=limit,
            query_type=query_type,
        )
    else:
        vectorsearch.index.similarity_search.assert_called_once_with(
            columns=["id", "text"],
            query_text=None,
            query_vector=EMBEDDING_MODEL.embed_query(query),
            filters=filters,
            num_results=limit,
            query_type=query_type,
        )
    assert len(search_result) == len(INPUT_TEXTS)
    assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
    assert all(["id" in d.metadata for d in search_result])


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
def test_similarity_search_hybrid(index_name: str) -> None:
    vectorsearch = init_vector_search(index_name)
    query = "foo"
    filters = {"some filter": True}
    limit = 7

    search_result = vectorsearch.similarity_search(
        query, k=limit, filter=filters, query_type="HYBRID"
    )
    if index_name == DELTA_SYNC_INDEX:
        vectorsearch.index.similarity_search.assert_called_once_with(
            columns=["id", "text"],
            query_text=query,
            query_vector=None,
            filters=filters,
            num_results=limit,
            query_type="HYBRID",
        )
    else:
        vectorsearch.index.similarity_search.assert_called_once_with(
            columns=["id", "text"],
            query_text=query,
            query_vector=EMBEDDING_MODEL.embed_query(query),
            filters=filters,
            num_results=limit,
            query_type="HYBRID",
        )
    assert len(search_result) == len(INPUT_TEXTS)
    assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
    assert all(["id" in d.metadata for d in search_result])


def test_similarity_search_both_filter_and_filters_passed() -> None:
    vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
    query = "foo"
    filter = {"some filter": True}
    filters = {"some other filter": False}

    vectorsearch.similarity_search(query, filter=filter, filters=filters)
    vectorsearch.index.similarity_search.assert_called_once_with(
        columns=["id", "text"],
        query_vector=EMBEDDING_MODEL.embed_query(query),
        # `filter` should prevail over `filters`
        filters=filter,
        num_results=4,
        query_text=None,
        query_type=None,
    )


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
@pytest.mark.parametrize(
    "columns, expected_columns",
    [
        (None, {"id"}),
        (["id", "text", "text_vector"], {"text_vector", "id"}),
    ],
)
def test_mmr_search(
    index_name: str, columns: Optional[List[str]], expected_columns: Set[str]
) -> None:
    vectorsearch = init_vector_search(index_name, columns=columns)

    query = INPUT_TEXTS[0]
    filters = {"some filter": True}
    limit = 1

    search_result = vectorsearch.max_marginal_relevance_search(
        query, k=limit, filters=filters
    )
    assert [doc.page_content for doc in search_result] == [INPUT_TEXTS[0]]
    assert [set(doc.metadata.keys()) for doc in search_result] == [expected_columns]


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_mmr_parameters(index_name: str) -> None:
    vectorsearch = init_vector_search(index_name)

    query = INPUT_TEXTS[0]
    limit = 1
    fetch_k = 3
    lambda_mult = 0.25
    filters = {"some filter": True}

    with patch(
        "langchain_databricks.vectorstores.maximal_marginal_relevance"
    ) as mock_mmr:
        mock_mmr.return_value = [2]
        retriever = vectorsearch.as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": limit,
                "fetch_k": fetch_k,
                "lambda_mult": lambda_mult,
                "filter": filters,
            },
        )
        search_result = retriever.invoke(query)

    mock_mmr.assert_called_once()
    assert mock_mmr.call_args[1]["lambda_mult"] == lambda_mult
    assert vectorsearch.index.similarity_search.call_args[1]["num_results"] == fetch_k
    assert vectorsearch.index.similarity_search.call_args[1]["filters"] == filters
    assert len(search_result) == limit


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
@pytest.mark.parametrize("threshold", [0.4, 0.5, 0.8])
def test_similarity_score_threshold(index_name: str, threshold: float) -> None:
    query = INPUT_TEXTS[0]
    limit = len(INPUT_TEXTS)

    vectorsearch = init_vector_search(index_name)
    retriever = vectorsearch.as_retriever(
        search_type="similarity_score_threshold",
        search_kwargs={"k": limit, "score_threshold": threshold},
    )
    search_result = retriever.invoke(query)
    if threshold <= 0.5:
        assert len(search_result) == len(INPUT_TEXTS)
    else:
        assert len(search_result) == 0


def test_standard_params() -> None:
    vectorstore = init_vector_search(DIRECT_ACCESS_INDEX)
    retriever = vectorstore.as_retriever()
    ls_params = retriever._get_ls_params()
    assert ls_params == {
        "ls_retriever_name": "vectorstore",
        "ls_vector_store_provider": "DatabricksVectorSearch",
        "ls_embedding_provider": "FakeEmbeddings",
    }

    vectorstore = init_vector_search(DELTA_SYNC_INDEX)
    retriever = vectorstore.as_retriever()
    ls_params = retriever._get_ls_params()
    assert ls_params == {
        "ls_retriever_name": "vectorstore",
        "ls_vector_store_provider": "DatabricksVectorSearch",
    }


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
@pytest.mark.parametrize("query_type", [None, "ANN"])
def test_similarity_search_by_vector(
    index_name: str, query_type: Optional[str]
) -> None:
    vectorsearch = init_vector_search(index_name)
    query_embedding = EMBEDDING_MODEL.embed_query("foo")
    filters = {"some filter": True}
    limit = 7

    search_result = vectorsearch.similarity_search_by_vector(
        query_embedding, k=limit, filter=filters, query_type=query_type
    )
    vectorsearch.index.similarity_search.assert_called_once_with(
        columns=["id", "text"],
        query_vector=query_embedding,
        filters=filters,
        num_results=limit,
        query_type=query_type,
        query_text=None,
    )
    assert len(search_result) == len(INPUT_TEXTS)
    assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
    assert all(["id" in d.metadata for d in search_result])


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
def test_similarity_search_by_vector_hybrid(index_name: str) -> None:
    vectorsearch = init_vector_search(index_name)
    query_embedding = EMBEDDING_MODEL.embed_query("foo")
    filters = {"some filter": True}
    limit = 7

    search_result = vectorsearch.similarity_search_by_vector(
        query_embedding, k=limit, filter=filters, query_type="HYBRID", query="foo"
    )
    vectorsearch.index.similarity_search.assert_called_once_with(
        columns=["id", "text"],
        query_vector=query_embedding,
        filters=filters,
        num_results=limit,
        query_type="HYBRID",
        query_text="foo",
    )
    assert len(search_result) == len(INPUT_TEXTS)
    assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
    assert all(["id" in d.metadata for d in search_result])


@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
def test_similarity_search_empty_result(index_name: str) -> None:
    vectorsearch = init_vector_search(index_name)
    vectorsearch.index.similarity_search.return_value = {
        "manifest": {
            "column_count": 3,
            "columns": [
                {"name": "id"},
                {"name": "text"},
                {"name": "score"},
            ],
        },
        "result": {
            "row_count": 0,
            "data_array": [],
        },
        "next_page_token": "",
    }

    search_result = vectorsearch.similarity_search("foo")
    assert len(search_result) == 0


def test_similarity_search_by_vector_not_supported_for_managed_embedding() -> None:
    vectorsearch = init_vector_search(DELTA_SYNC_INDEX)
    query_embedding = EMBEDDING_MODEL.embed_query("foo")
    filters = {"some filter": True}
    limit = 7

    with pytest.raises(
        NotImplementedError, match="`similarity_search_by_vector` is not supported"
    ):
        vectorsearch.similarity_search_by_vector(
            query_embedding, k=limit, filters=filters
        )