Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-09 04:50:37 +00:00

databricks: mv to partner repo (#25788)

This commit is contained in:
parent 2e5c379632
commit 1023fbc98a
libs/partners/databricks/.gitignore (vendored, 1 line changed)

@@ -1 +0,0 @@
__pycache__

libs/partners/databricks/LICENSE

@@ -1,21 +0,0 @@
MIT License

Copyright (c) 2024 LangChain, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

libs/partners/databricks/Makefile

@@ -1,62 +0,0 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests

# Default target executed when no arguments are given to make.
all: help

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
integration_test integration_tests: TEST_FILE = tests/integration_tests/


# unit tests are run with the --disable-socket flag to prevent network calls
test tests:
	poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)

# integration tests are run without the --disable-socket flag to allow network calls
integration_test integration_tests:
	poetry run pytest $(TEST_FILE)

######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/databricks --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_databricks
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	poetry run ruff check .
	poetry run ruff format $(PYTHON_FILES) --diff
	poetry run ruff check --select I $(PYTHON_FILES)
	mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	poetry run ruff format $(PYTHON_FILES)
	poetry run ruff check --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

check_imports: $(shell find langchain_databricks -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports - check imports'
	@echo 'format - run code formatters'
	@echo 'lint - run linters'
	@echo 'test - run unit tests'
	@echo 'tests - run unit tests'
	@echo 'test TEST_FILE=<test_file> - run all tests in file'

libs/partners/databricks/README.md

@@ -1,24 +0,0 @@
# langchain-databricks

This package contains the LangChain integration with Databricks.

## Installation

```bash
pip install -U langchain-databricks
```

You should also configure credentials by setting the following environment variables:

* TODO: fill this out

## Chat Models

The `ChatDatabricks` class exposes chat models from Databricks.

```python
from langchain_databricks import ChatDatabricks

llm = ChatDatabricks()
llm.invoke("Sing a ballad of LangChain.")
```

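The package also ships embeddings and vector search integrations, removed later in this diff. A minimal embeddings sketch, mirroring the `DatabricksEmbeddings` docstring below and assuming a `databricks-bge-large-en` serving endpoint plus Databricks credentials are configured:

```python
from langchain_databricks import DatabricksEmbeddings

# Query a Databricks Model Serving embeddings endpoint (assumed to exist).
embed = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
vector = embed.embed_query("The meaning of life is 42")
print(len(vector))
```
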
@ -1,19 +0,0 @@
|
||||
from importlib import metadata
|
||||
|
||||
from langchain_databricks.chat_models import ChatDatabricks
|
||||
from langchain_databricks.embeddings import DatabricksEmbeddings
|
||||
from langchain_databricks.vectorstores import DatabricksVectorSearch
|
||||
|
||||
try:
|
||||
__version__ = metadata.version(__package__)
|
||||
except metadata.PackageNotFoundError:
|
||||
# Case where package metadata is not available.
|
||||
__version__ = ""
|
||||
del metadata # optional, avoids polluting the results of dir(__package__)
|
||||
|
||||
__all__ = [
|
||||
"ChatDatabricks",
|
||||
"DatabricksEmbeddings",
|
||||
"DatabricksVectorSearch",
|
||||
"__version__",
|
||||
]
|
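The public surface re-exported above is the three integration classes plus `__version__`. A minimal import sketch, assuming the relocated `langchain-databricks` package is installed:

```python
from langchain_databricks import (
    ChatDatabricks,
    DatabricksEmbeddings,
    DatabricksVectorSearch,
    __version__,
)

print(__version__)  # "" when package metadata is unavailable
```
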
libs/partners/databricks/langchain_databricks/chat_models.py

@@ -1,556 +0,0 @@
"""Databricks chat models."""

import json
import logging
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Type,
    Union,
)

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    ChatMessageChunk,
    FunctionMessage,
    HumanMessage,
    HumanMessageChunk,
    SystemMessage,
    SystemMessageChunk,
    ToolMessage,
    ToolMessageChunk,
)
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers.openai_tools import (
    make_invalid_tool_call,
    parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
    BaseModel,
    Field,
    PrivateAttr,
)
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool

from langchain_databricks.utils import get_deployment_client

logger = logging.getLogger(__name__)


class ChatDatabricks(BaseChatModel):
    """Databricks chat model integration.

    Setup:
        Install ``langchain-databricks``.

        .. code-block:: bash

            pip install -U langchain-databricks

        If you are outside Databricks, set the Databricks workspace hostname and personal access token to environment variables:

        .. code-block:: bash

            export DATABRICKS_HOSTNAME="https://your-databricks-workspace"
            export DATABRICKS_TOKEN="your-personal-access-token"

    Key init args — completion params:
        endpoint: str
            Name of Databricks Model Serving endpoint to query.
        target_uri: str
            The target URI to use. Defaults to ``databricks``.
        temperature: float
            Sampling temperature. Higher values make the model more creative.
        n: Optional[int]
            The number of completion choices to generate.
        stop: Optional[List[str]]
            List of strings to stop generation at.
        max_tokens: Optional[int]
            Max number of tokens to generate.
        extra_params: Optional[Dict[str, Any]]
            Any extra parameters to pass to the endpoint.

    Instantiate:
        .. code-block:: python

            from langchain_databricks import ChatDatabricks
            llm = ChatDatabricks(
                endpoint="databricks-meta-llama-3-1-405b-instruct",
                temperature=0,
                max_tokens=500,
            )

    Invoke:
        .. code-block:: python

            messages = [
                ("system", "You are a helpful translator. Translate the user sentence to French."),
                ("human", "I love programming."),
            ]
            llm.invoke(messages)

        .. code-block:: python

            AIMessage(
                content="J'adore la programmation.",
                response_metadata={
                    'prompt_tokens': 32,
                    'completion_tokens': 9,
                    'total_tokens': 41
                },
                id='run-64eebbdd-88a8-4a25-b508-21e9a5f146c5-0'
            )

    Stream:
        .. code-block:: python

            for chunk in llm.stream(messages):
                print(chunk)

        .. code-block:: python

            content='J' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
            content="'" id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
            content='ad' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
            content='ore' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
            content=' la' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
            content=' programm' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
            content='ation' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
            content='.' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
            content='' response_metadata={'finish_reason': 'stop'} id='run-609b8f47-e580-4691-9ee4-e2109f53155e'

        .. code-block:: python

            stream = llm.stream(messages)
            full = next(stream)
            for chunk in stream:
                full += chunk
            full

        .. code-block:: python

            AIMessageChunk(
                content="J'adore la programmation.",
                response_metadata={
                    'finish_reason': 'stop'
                },
                id='run-4cef851f-6223-424f-ad26-4a54e5852aa5'
            )

    Async:
        .. code-block:: python

            await llm.ainvoke(messages)

            # stream:
            # async for chunk in llm.astream(messages)

            # batch:
            # await llm.abatch([messages])

        .. code-block:: python

            AIMessage(
                content="J'adore la programmation.",
                response_metadata={
                    'prompt_tokens': 32,
                    'completion_tokens': 9,
                    'total_tokens': 41
                },
                id='run-e4bb043e-772b-4e1d-9f98-77ccc00c0271-0'
            )

    Tool calling:
        .. code-block:: python

            from langchain_core.pydantic_v1 import BaseModel, Field

            class GetWeather(BaseModel):
                '''Get the current weather in a given location'''

                location: str = Field(..., description="The city and state, e.g. San Francisco, CA")

            class GetPopulation(BaseModel):
                '''Get the current population in a given location'''

                location: str = Field(..., description="The city and state, e.g. San Francisco, CA")

            llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
            ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?")
            ai_msg.tool_calls

        .. code-block:: python

            [
                {
                    'name': 'GetWeather',
                    'args': {
                        'location': 'Los Angeles, CA'
                    },
                    'id': 'call_ea0a6004-8e64-4ae8-a192-a40e295bfa24',
                    'type': 'tool_call'
                }
            ]

        To use tool calls, your model endpoint must support ``tools`` parameter. See [Function calling on Databricks](https://python.langchain.com/v0.2/docs/integrations/chat/databricks/#function-calling-on-databricks) for more information.

    """  # noqa: E501

    endpoint: str
    """Name of Databricks Model Serving endpoint to query."""
    target_uri: str = "databricks"
    """The target URI to use. Defaults to ``databricks``."""
    temperature: float = 0.0
    """Sampling temperature. Higher values make the model more creative."""
    n: int = 1
    """The number of completion choices to generate."""
    stop: Optional[List[str]] = None
    """List of strings to stop generation at."""
    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""
    extra_params: dict = Field(default_factory=dict)
    """Any extra parameters to pass to the endpoint."""
    _client: Any = PrivateAttr()

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self._client = get_deployment_client(self.target_uri)

    @property
    def _default_params(self) -> Dict[str, Any]:
        params: Dict[str, Any] = {
            "target_uri": self.target_uri,
            "endpoint": self.endpoint,
            "temperature": self.temperature,
            "n": self.n,
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "extra_params": self.extra_params,
        }
        return params

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        data = self._prepare_inputs(messages, stop, **kwargs)
        resp = self._client.predict(endpoint=self.endpoint, inputs=data)
        return self._convert_response_to_chat_result(resp)

    def _prepare_inputs(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        data: Dict[str, Any] = {
            "messages": [_convert_message_to_dict(msg) for msg in messages],
            "temperature": self.temperature,
            "n": self.n,
            **self.extra_params,
            **kwargs,
        }
        if stop := self.stop or stop:
            data["stop"] = stop
        if self.max_tokens is not None:
            data["max_tokens"] = self.max_tokens

        return data

    def _convert_response_to_chat_result(
        self, response: Mapping[str, Any]
    ) -> ChatResult:
        generations = [
            ChatGeneration(
                message=_convert_dict_to_message(choice["message"]),
                generation_info=choice.get("usage", {}),
            )
            for choice in response["choices"]
        ]
        usage = response.get("usage", {})
        return ChatResult(generations=generations, llm_output=usage)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        data = self._prepare_inputs(messages, stop, **kwargs)
        first_chunk_role = None
        for chunk in self._client.predict_stream(endpoint=self.endpoint, inputs=data):
            if chunk["choices"]:
                choice = chunk["choices"][0]

                chunk_delta = choice["delta"]
                if first_chunk_role is None:
                    first_chunk_role = chunk_delta.get("role")

                chunk_message = _convert_dict_to_message_chunk(
                    chunk_delta, first_chunk_role
                )

                generation_info = {}
                if finish_reason := choice.get("finish_reason"):
                    generation_info["finish_reason"] = finish_reason
                if logprobs := choice.get("logprobs"):
                    generation_info["logprobs"] = logprobs

                chunk = ChatGenerationChunk(
                    message=chunk_message, generation_info=generation_info or None
                )

                if run_manager:
                    run_manager.on_llm_new_token(
                        chunk.text, chunk=chunk, logprobs=logprobs
                    )

                yield chunk
            else:
                # Handle the case where choices are empty if needed
                continue

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        *,
        tool_choice: Optional[
            Union[dict, str, Literal["auto", "none", "required", "any"], bool]
        ] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        """Bind tool-like objects to this chat model.

        Assumes model is compatible with OpenAI tool-calling API.

        Args:
            tools: A list of tool definitions to bind to this chat model.
                Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
                models, callables, and BaseTools will be automatically converted to
                their schema dictionary representation.
            tool_choice: Which tool to require the model to call.
                Options are:
                name of the tool (str): calls corresponding tool;
                "auto": automatically selects a tool (including no tool);
                "none": model does not generate any tool calls and instead must
                generate a standard assistant message;
                "required": the model picks the most relevant tool in tools and
                must generate a tool call;

                or a dict of the form:
                {"type": "function", "function": {"name": <<tool_name>>}}.
            **kwargs: Any additional parameters to pass to the
                :class:`~langchain.runnable.Runnable` constructor.
        """
        formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
        if tool_choice:
            if isinstance(tool_choice, str):
                # tool_choice is a tool/function name
                if tool_choice not in ("auto", "none", "required"):
                    tool_choice = {
                        "type": "function",
                        "function": {"name": tool_choice},
                    }
            elif isinstance(tool_choice, dict):
                tool_names = [
                    formatted_tool["function"]["name"]
                    for formatted_tool in formatted_tools
                ]
                if not any(
                    tool_name == tool_choice["function"]["name"]
                    for tool_name in tool_names
                ):
                    raise ValueError(
                        f"Tool choice {tool_choice} was specified, but the only "
                        f"provided tools were {tool_names}."
                    )
            else:
                raise ValueError(
                    f"Unrecognized tool_choice type. Expected str, bool or dict. "
                    f"Received: {tool_choice}"
                )
            kwargs["tool_choice"] = tool_choice
        return super().bind(tools=formatted_tools, **kwargs)

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "chat-databricks"


### Conversion function to convert Pydantic models to dictionaries and vice versa. ###


def _convert_message_to_dict(message: BaseMessage) -> dict:
    message_dict = {"content": message.content}

    # OpenAI supports "name" field in messages.
    if (name := message.name or message.additional_kwargs.get("name")) is not None:
        message_dict["name"] = name

    if id := message.id:
        message_dict["id"] = id

    if isinstance(message, ChatMessage):
        return {"role": message.role, **message_dict}
    elif isinstance(message, HumanMessage):
        return {"role": "user", **message_dict}
    elif isinstance(message, AIMessage):
        if tool_calls := _get_tool_calls_from_ai_message(message):
            message_dict["tool_calls"] = tool_calls  # type: ignore[assignment]
            # If tool calls present, content null value should be None not empty string.
            message_dict["content"] = message_dict["content"] or None  # type: ignore[assignment]
        return {"role": "assistant", **message_dict}
    elif isinstance(message, SystemMessage):
        return {"role": "system", **message_dict}
    elif isinstance(message, ToolMessage):
        return {
            "role": "tool",
            "tool_call_id": message.tool_call_id,
            **message_dict,
        }
    elif (
        isinstance(message, FunctionMessage)
        or "function_call" in message.additional_kwargs
    ):
        raise ValueError(
            "Function messages are not supported by Databricks. Please"
            " create a feature request at https://github.com/mlflow/mlflow/issues."
        )
    else:
        raise ValueError(f"Got unknown message type: {type(message)}")


def _get_tool_calls_from_ai_message(message: AIMessage) -> List[Dict]:
    tool_calls = [
        {
            "type": "function",
            "id": tc["id"],
            "function": {
                "name": tc["name"],
                "arguments": json.dumps(tc["args"]),
            },
        }
        for tc in message.tool_calls
    ]

    invalid_tool_calls = [
        {
            "type": "function",
            "id": tc["id"],
            "function": {
                "name": tc["name"],
                "arguments": tc["args"],
            },
        }
        for tc in message.invalid_tool_calls
    ]

    if tool_calls or invalid_tool_calls:
        return tool_calls + invalid_tool_calls

    # Get tool calls from additional kwargs if present.
    return [
        {
            k: v
            for k, v in tool_call.items()  # type: ignore[union-attr]
            if k in {"id", "type", "function"}
        }
        for tool_call in message.additional_kwargs.get("tool_calls", [])
    ]


def _convert_dict_to_message(_dict: Dict) -> BaseMessage:
    role = _dict["role"]
    content = _dict.get("content")
    content = content if content is not None else ""

    if role == "user":
        return HumanMessage(content=content)
    elif role == "system":
        return SystemMessage(content=content)
    elif role == "assistant":
        additional_kwargs: Dict = {}
        tool_calls = []
        invalid_tool_calls = []
        if raw_tool_calls := _dict.get("tool_calls"):
            additional_kwargs["tool_calls"] = raw_tool_calls
            for raw_tool_call in raw_tool_calls:
                try:
                    tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
                except Exception as e:
                    invalid_tool_calls.append(
                        make_invalid_tool_call(raw_tool_call, str(e))
                    )
        return AIMessage(
            content=content,
            additional_kwargs=additional_kwargs,
            id=_dict.get("id"),
            tool_calls=tool_calls,
            invalid_tool_calls=invalid_tool_calls,
        )
    else:
        return ChatMessage(content=content, role=role)


def _convert_dict_to_message_chunk(
    _dict: Mapping[str, Any], default_role: str
) -> BaseMessageChunk:
    role = _dict.get("role", default_role)
    content = _dict.get("content")
    content = content if content is not None else ""

    if role == "user":
        return HumanMessageChunk(content=content)
    elif role == "system":
        return SystemMessageChunk(content=content)
    elif role == "tool":
        return ToolMessageChunk(
            content=content, tool_call_id=_dict["tool_call_id"], id=_dict.get("id")
        )
    elif role == "assistant":
        additional_kwargs: Dict = {}
        tool_call_chunks = []
        if raw_tool_calls := _dict.get("tool_calls"):
            additional_kwargs["tool_calls"] = raw_tool_calls
            try:
                tool_call_chunks = [
                    tool_call_chunk(
                        name=tc["function"].get("name"),
                        args=tc["function"].get("arguments"),
                        id=tc.get("id"),
                        index=tc["index"],
                    )
                    for tc in raw_tool_calls
                ]
            except KeyError:
                pass
        return AIMessageChunk(
            content=content,
            additional_kwargs=additional_kwargs,
            id=_dict.get("id"),
            tool_call_chunks=tool_call_chunks,
        )
    else:
        return ChatMessageChunk(content=content, role=role)

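The private helpers at the bottom of `chat_models.py` translate between LangChain message objects and the OpenAI-style role/content dicts exchanged with the serving endpoint. A minimal offline sketch, assuming the relocated `langchain-databricks` package (and `langchain-core`) is installed:

```python
from langchain_core.messages import HumanMessage

from langchain_databricks.chat_models import (
    _convert_dict_to_message,
    _convert_message_to_dict,
)

# LangChain message -> endpoint payload dict
print(_convert_message_to_dict(HumanMessage(content="Hello!")))
# {'role': 'user', 'content': 'Hello!'}

# endpoint response dict -> LangChain message
print(_convert_dict_to_message({"role": "assistant", "content": "Bonjour!"}))
# AIMessage(content='Bonjour!', ...)
```
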
libs/partners/databricks/langchain_databricks/embeddings.py

@@ -1,91 +0,0 @@
from typing import Any, Dict, Iterator, List

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr

from langchain_databricks.utils import get_deployment_client


class DatabricksEmbeddings(Embeddings, BaseModel):
    """Databricks embedding model integration.

    Setup:
        Install ``langchain-databricks``.

        .. code-block:: bash

            pip install -U langchain-databricks

        If you are outside Databricks, set the Databricks workspace
        hostname and personal access token to environment variables:

        .. code-block:: bash

            export DATABRICKS_HOSTNAME="https://your-databricks-workspace"
            export DATABRICKS_TOKEN="your-personal-access-token"

    Key init args — completion params:
        endpoint: str
            Name of Databricks Model Serving endpoint to query.
        target_uri: str
            The target URI to use. Defaults to ``databricks``.
        query_params: Dict[str, str]
            The parameters to use for queries.
        documents_params: Dict[str, str]
            The parameters to use for documents.

    Instantiate:
        .. code-block:: python

            from langchain_databricks import DatabricksEmbeddings
            embed = DatabricksEmbeddings(
                endpoint="databricks-bge-large-en",
            )

    Embed single text:
        .. code-block:: python

            input_text = "The meaning of life is 42"
            embed.embed_query(input_text)

        .. code-block:: python

            [
                0.01605224609375,
                -0.0298309326171875,
                ...
            ]

    """

    endpoint: str
    """The endpoint to use."""
    target_uri: str = "databricks"
    """The target URI to use."""
    query_params: Dict[str, Any] = {}
    """The parameters to use for queries."""
    documents_params: Dict[str, Any] = {}
    """The parameters to use for documents."""
    _client: Any = PrivateAttr()

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self._client = get_deployment_client(self.target_uri)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return self._embed(texts, params=self.documents_params)

    def embed_query(self, text: str) -> List[float]:
        return self._embed([text], params=self.query_params)[0]

    def _embed(self, texts: List[str], params: Dict[str, str]) -> List[List[float]]:
        embeddings: List[List[float]] = []
        for txt in _chunk(texts, 20):
            resp = self._client.predict(
                endpoint=self.endpoint,
                inputs={"input": txt, **params},  # type: ignore[arg-type]
            )
            embeddings.extend(r["embedding"] for r in resp["data"])
        return embeddings


def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
    for i in range(0, len(texts), size):
        yield texts[i : i + size]

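`DatabricksEmbeddings._embed` batches requests in chunks of 20 texts per endpoint call via the `_chunk` helper above. A self-contained sketch of that batching behaviour (no Databricks call is made):

```python
from typing import Iterator, List


def chunk(texts: List[str], size: int) -> Iterator[List[str]]:
    # Same logic as the private _chunk helper in embeddings.py.
    for i in range(0, len(texts), size):
        yield texts[i : i + size]


texts = [f"doc-{i}" for i in range(45)]
print([len(batch) for batch in chunk(texts, 20)])  # [20, 20, 5]
```
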
libs/partners/databricks/langchain_databricks/utils.py

@@ -1,101 +0,0 @@
from typing import Any, List, Union
from urllib.parse import urlparse

import numpy as np


def get_deployment_client(target_uri: str) -> Any:
    if (target_uri != "databricks") and (urlparse(target_uri).scheme != "databricks"):
        raise ValueError(
            "Invalid target URI. The target URI must be a valid databricks URI."
        )

    try:
        from mlflow.deployments import get_deploy_client  # type: ignore[import-untyped]

        return get_deploy_client(target_uri)
    except ImportError as e:
        raise ImportError(
            "Failed to create the client. "
            "Please run `pip install mlflow` to install "
            "required dependencies."
        ) from e


# Utility function for Maximal Marginal Relevance (MMR) reranking.
# Copied from langchain_community/vectorstores/utils.py to avoid cross-dependency
Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]


def maximal_marginal_relevance(
    query_embedding: np.ndarray,
    embedding_list: list,
    lambda_mult: float = 0.5,
    k: int = 4,
) -> List[int]:
    """Calculate maximal marginal relevance.

    Args:
        query_embedding: Query embedding.
        embedding_list: List of embeddings to select from.
        lambda_mult: Number between 0 and 1 that determines the degree
            of diversity among the results with 0 corresponding
            to maximum diversity and 1 to minimum diversity.
            Defaults to 0.5.
        k: Number of Documents to return. Defaults to 4.

    Returns:
        List of indices of embeddings selected by maximal marginal relevance.
    """
    if min(k, len(embedding_list)) <= 0:
        return []
    if query_embedding.ndim == 1:
        query_embedding = np.expand_dims(query_embedding, axis=0)
    similarity_to_query = cosine_similarity(query_embedding, embedding_list)[0]
    most_similar = int(np.argmax(similarity_to_query))
    idxs = [most_similar]
    selected = np.array([embedding_list[most_similar]])
    while len(idxs) < min(k, len(embedding_list)):
        best_score = -np.inf
        idx_to_add = -1
        similarity_to_selected = cosine_similarity(embedding_list, selected)
        for i, query_score in enumerate(similarity_to_query):
            if i in idxs:
                continue
            redundant_score = max(similarity_to_selected[i])
            equation_score = (
                lambda_mult * query_score - (1 - lambda_mult) * redundant_score
            )
            if equation_score > best_score:
                best_score = equation_score
                idx_to_add = i
        idxs.append(idx_to_add)
        selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)
    return idxs


def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
    """Row-wise cosine similarity between two equal-width matrices.

    Raises:
        ValueError: If the number of columns in X and Y are not the same.
    """
    if len(X) == 0 or len(Y) == 0:
        return np.array([])

    X = np.array(X)
    Y = np.array(Y)
    if X.shape[1] != Y.shape[1]:
        raise ValueError(
            "Number of columns in X and Y must be the same. X has shape "
            f"{X.shape} "
            f"and Y has shape {Y.shape}."
        )

    X_norm = np.linalg.norm(X, axis=1)
    Y_norm = np.linalg.norm(Y, axis=1)
    # Ignore divide-by-zero runtime warnings, as those are handled below.
    with np.errstate(divide="ignore", invalid="ignore"):
        similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
        similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
    return similarity

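`maximal_marginal_relevance` trades query similarity against redundancy with already-selected vectors. A small numpy example, assuming the relocated `langchain-databricks` package is installed (a low `lambda_mult` favours diversity):

```python
import numpy as np

from langchain_databricks.utils import maximal_marginal_relevance

query = np.array([1.0, 0.0])
candidates = [
    np.array([0.9, 0.1]),    # very similar to the query
    np.array([0.89, 0.11]),  # near-duplicate of the first candidate
    np.array([0.0, 1.0]),    # orthogonal, adds diversity
]
# With lambda_mult=0.1 the redundancy penalty dominates, so the
# near-duplicate is skipped in favour of the diverse vector.
print(maximal_marginal_relevance(query, candidates, lambda_mult=0.1, k=2))
# [0, 2]
```
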
libs/partners/databricks/langchain_databricks/vectorstores.py

@@ -1,837 +0,0 @@
from __future__ import annotations

import asyncio
import json
import logging
import uuid
from enum import Enum
from functools import partial
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
)

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VST, VectorStore

from langchain_databricks.utils import maximal_marginal_relevance

logger = logging.getLogger(__name__)


class IndexType(str, Enum):
    DIRECT_ACCESS = "DIRECT_ACCESS"
    DELTA_SYNC = "DELTA_SYNC"


_DIRECT_ACCESS_ONLY_MSG = "`%s` is only supported for direct-access index."
_NON_MANAGED_EMB_ONLY_MSG = (
    "`%s` is not supported for index with Databricks-managed embeddings."
)


class DatabricksVectorSearch(VectorStore):
    """Databricks vector store integration.

    Setup:
        Install ``langchain-databricks`` and ``databricks-vectorsearch`` python packages.

        .. code-block:: bash

            pip install -U langchain-databricks databricks-vectorsearch

        If you don't have a Databricks Vector Search endpoint already, you can create one by following the instructions here: https://docs.databricks.com/en/generative-ai/create-query-vector-search.html

        If you are outside Databricks, set the Databricks workspace
        hostname and personal access token to environment variables:

        .. code-block:: bash

            export DATABRICKS_HOSTNAME="https://your-databricks-workspace"
            export DATABRICKS_TOKEN="your-personal-access-token"

    Key init args — indexing params:

        endpoint: The name of the Databricks Vector Search endpoint.
        index_name: The name of the index to use. Format: "catalog.schema.index".
        embedding: The embedding model.
            Required for direct-access index or delta-sync index
            with self-managed embeddings.
        text_column: The name of the text column to use for the embeddings.
            Required for direct-access index or delta-sync index
            with self-managed embeddings.
            Make sure the text column specified is in the index.
        columns: The list of column names to get when doing the search.
            Defaults to ``[primary_key, text_column]``.

    Instantiate:

        `DatabricksVectorSearch` supports two types of indexes:

        * **Delta Sync Index** automatically syncs with a source Delta Table, automatically and incrementally updating the index as the underlying data in the Delta Table changes.

        * **Direct Vector Access Index** supports direct read and write of vectors and metadata. The user is responsible for updating this table using the REST API or the Python SDK.

        Also for delta-sync index, you can choose to use Databricks-managed embeddings or self-managed embeddings (via LangChain embeddings classes).

        If you are using a delta-sync index with Databricks-managed embeddings:

        .. code-block:: python

            from langchain_databricks.vectorstores import DatabricksVectorSearch

            vector_store = DatabricksVectorSearch(
                endpoint="<your-endpoint-name>",
                index_name="<your-index-name>"
            )

        If you are using a direct-access index or a delta-sync index with self-managed embeddings,
        you also need to provide the embedding model and text column in your source table to
        use for the embeddings:

        .. code-block:: python

            from langchain_openai import OpenAIEmbeddings

            vector_store = DatabricksVectorSearch(
                endpoint="<your-endpoint-name>",
                index_name="<your-index-name>",
                embedding=OpenAIEmbeddings(),
                text_column="document_content"
            )

    Add Documents:
        .. code-block:: python

            from langchain_core.documents import Document

            document_1 = Document(page_content="foo", metadata={"baz": "bar"})
            document_2 = Document(page_content="thud", metadata={"bar": "baz"})
            document_3 = Document(page_content="i will be deleted :(")
            documents = [document_1, document_2, document_3]
            ids = ["1", "2", "3"]
            vector_store.add_documents(documents=documents, ids=ids)

    Delete Documents:
        .. code-block:: python

            vector_store.delete(ids=["3"])

        .. note::

            The `delete` method is only supported for direct-access index.

    Search:
        .. code-block:: python

            results = vector_store.similarity_search(query="thud",k=1)
            for doc in results:
                print(f"* {doc.page_content} [{doc.metadata}]")

        .. code-block:: python

            * thud [{'id': '2'}]

        .. note::

            By default, similarity search only returns the primary key and text column.
            If you want to retrieve the custom metadata associated with the document,
            pass the additional columns in the `columns` parameter when initializing the vector store.

            .. code-block:: python

                vector_store = DatabricksVectorSearch(
                    endpoint="<your-endpoint-name>",
                    index_name="<your-index-name>",
                    columns=["baz", "bar"],
                )

                vector_store.similarity_search(query="thud",k=1)
                # Output: * thud [{'bar': 'baz', 'baz': None, 'id': '2'}]

    Search with filter:
        .. code-block:: python

            results = vector_store.similarity_search(query="thud",k=1,filter={"bar": "baz"})
            for doc in results:
                print(f"* {doc.page_content} [{doc.metadata}]")

        .. code-block:: python

            * thud [{'id': '2'}]

    Search with score:
        .. code-block:: python

            results = vector_store.similarity_search_with_score(query="qux",k=1)
            for doc, score in results:
                print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")

        .. code-block:: python

            * [SIM=0.748804] foo [{'id': '1'}]

    Async:
        .. code-block:: python

            # add documents
            await vector_store.aadd_documents(documents=documents, ids=ids)
            # delete documents
            await vector_store.adelete(ids=["3"])
            # search
            results = vector_store.asimilarity_search(query="thud",k=1)
            # search with score
            results = await vector_store.asimilarity_search_with_score(query="qux",k=1)
            for doc,score in results:
                print(f"* [SIM={score:3f}] {doc.page_content} [{doc.metadata}]")

        .. code-block:: python

            * [SIM=0.748807] foo [{'id': '1'}]

    Use as Retriever:
        .. code-block:: python

            retriever = vector_store.as_retriever(
                search_type="mmr",
                search_kwargs={"k": 1, "fetch_k": 2, "lambda_mult": 0.5},
            )
            retriever.invoke("thud")

        .. code-block:: python

            [Document(metadata={'id': '2'}, page_content='thud')]
    """  # noqa: E501

    def __init__(
        self,
        endpoint: str,
        index_name: str,
        embedding: Optional[Embeddings] = None,
        text_column: Optional[str] = None,
        columns: Optional[List[str]] = None,
    ):
        try:
            from databricks.vector_search.client import (  # type: ignore[import]
                VectorSearchClient,
            )
        except ImportError as e:
            raise ImportError(
                "Could not import databricks-vectorsearch python package. "
                "Please install it with `pip install databricks-vectorsearch`."
            ) from e

        self.index = VectorSearchClient().get_index(endpoint, index_name)
        self._index_details = IndexDetails(self.index)

        _validate_embedding(embedding, self._index_details)
        self._embeddings = embedding
        self._text_column = _validate_and_get_text_column(
            text_column, self._index_details
        )
        self._columns = _validate_and_get_return_columns(
            columns or [], self._text_column, self._index_details
        )
        self._primary_key = self._index_details.primary_key

    @property
    def embeddings(self) -> Optional[Embeddings]:
        """Access the query embedding object if available."""
        return self._embeddings

    @classmethod
    def from_texts(
        cls: Type[VST],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[Dict]] = None,
        **kwargs: Any,
    ) -> VST:
        raise NotImplementedError(
            "`from_texts` is not supported. "
            "Use `add_texts` to add to existing direct-access index."
        )

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[Dict]] = None,
        ids: Optional[List[Any]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Add texts to the index.

        .. note::

            This method is only supported for a direct-access index.

        Args:
            texts: List of texts to add.
            metadatas: List of metadata for each text. Defaults to None.
            ids: List of ids for each text. Defaults to None.
                If not provided, a random uuid will be generated for each text.

        Returns:
            List of ids from adding the texts into the index.
        """
        if self._index_details.is_delta_sync_index():
            raise NotImplementedError(_DIRECT_ACCESS_ONLY_MSG % "add_texts")

        # Wrap to list if input texts is a single string
        if isinstance(texts, str):
            texts = [texts]
        texts = list(texts)
        vectors = self._embeddings.embed_documents(texts)  # type: ignore[union-attr]
        ids = ids or [str(uuid.uuid4()) for _ in texts]
        metadatas = metadatas or [{} for _ in texts]

        updates = [
            {
                self._primary_key: id_,
                self._text_column: text,
                self._index_details.embedding_vector_column["name"]: vector,
                **metadata,
            }
            for text, vector, id_, metadata in zip(texts, vectors, ids, metadatas)
        ]

        upsert_resp = self.index.upsert(updates)
        if upsert_resp.get("status") in ("PARTIAL_SUCCESS", "FAILURE"):
            failed_ids = upsert_resp.get("result", dict()).get(
                "failed_primary_keys", []
            )
            if upsert_resp.get("status") == "FAILURE":
                logger.error("Failed to add texts to the index.")
            else:
                logger.warning("Some texts failed to be added to the index.")
            return [id_ for id_ in ids if id_ not in failed_ids]

        return ids

    async def aadd_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any,
    ) -> List[str]:
        return await asyncio.get_running_loop().run_in_executor(
            None, partial(self.add_texts, **kwargs), texts, metadatas
        )

    def delete(self, ids: Optional[List[Any]] = None, **kwargs: Any) -> Optional[bool]:
        """Delete documents from the index.

        .. note::

            This method is only supported for a direct-access index.

        Args:
            ids: List of ids of documents to delete.

        Returns:
            True if successful.
        """
        if self._index_details.is_delta_sync_index():
            raise NotImplementedError(_DIRECT_ACCESS_ONLY_MSG % "delete")

        if ids is None:
            raise ValueError("ids must be provided.")
        self.index.delete(ids)
        return True

    def similarity_search(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        *,
        query_type: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Filters to apply to the query. Defaults to None.
            query_type: The type of this query. Supported values are "ANN" and "HYBRID".

        Returns:
            List of Documents most similar to the embedding.
        """
        docs_with_score = self.similarity_search_with_score(
            query=query,
            k=k,
            filter=filter,
            query_type=query_type,
            **kwargs,
        )
        return [doc for doc, _ in docs_with_score]

    async def asimilarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        # This is a temporary workaround to make the similarity search
        # asynchronous. The proper solution is to make the similarity search
        # asynchronous in the vector store implementations.
        func = partial(self.similarity_search, query, k=k, **kwargs)
        return await asyncio.get_event_loop().run_in_executor(None, func)

    def similarity_search_with_score(
        self,
        query: str,
        k: int = 4,
        filter: Optional[Dict[str, Any]] = None,
        *,
        query_type: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query, along with scores.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Filters to apply to the query. Defaults to None.
            query_type: The type of this query. Supported values are "ANN" and "HYBRID".

        Returns:
            List of Documents most similar to the embedding and score for each.
        """
        if self._index_details.is_databricks_managed_embeddings():
            query_text = query
            query_vector = None
        else:
            # The value for `query_text` needs to be specified only for hybrid search.
            if query_type is not None and query_type.upper() == "HYBRID":
                query_text = query
            else:
                query_text = None
            query_vector = self._embeddings.embed_query(query)  # type: ignore[union-attr]

        search_resp = self.index.similarity_search(
            columns=self._columns,
            query_text=query_text,
            query_vector=query_vector,
            filters=filter,
            num_results=k,
            query_type=query_type,
        )
        return self._parse_search_response(search_resp)

    def _select_relevance_score_fn(self) -> Callable[[float], float]:
        """
        Databricks Vector search uses a normalized score 1/(1+d) where d
        is the L2 distance. Hence, we simply return the identity function.
        """
        return lambda score: score

    async def asimilarity_search_with_score(
        self, *args: Any, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        # This is a temporary workaround to make the similarity search
        # asynchronous. The proper solution is to make the similarity search
        # asynchronous in the vector store implementations.
        func = partial(self.similarity_search_with_score, *args, **kwargs)
        return await asyncio.get_event_loop().run_in_executor(None, func)

    def similarity_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[Any] = None,
        *,
        query_type: Optional[str] = None,
        query: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Filters to apply to the query. Defaults to None.
            query_type: The type of this query. Supported values are "ANN" and "HYBRID".

        Returns:
            List of Documents most similar to the embedding.
        """
        if self._index_details.is_databricks_managed_embeddings():
            raise NotImplementedError(
                _NON_MANAGED_EMB_ONLY_MSG % "similarity_search_by_vector"
            )

        docs_with_score = self.similarity_search_by_vector_with_score(
            embedding=embedding,
            k=k,
            filter=filter,
            query_type=query_type,
            query=query,
            **kwargs,
        )
        return [doc for doc, _ in docs_with_score]

    async def asimilarity_search_by_vector(
        self, embedding: List[float], k: int = 4, **kwargs: Any
    ) -> List[Document]:
        # This is a temporary workaround to make the similarity search
        # asynchronous. The proper solution is to make the similarity search
        # asynchronous in the vector store implementations.
        func = partial(self.similarity_search_by_vector, embedding, k=k, **kwargs)
        return await asyncio.get_event_loop().run_in_executor(None, func)

    def similarity_search_by_vector_with_score(
        self,
        embedding: List[float],
        k: int = 4,
        filter: Optional[Any] = None,
        *,
        query_type: Optional[str] = None,
        query: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to embedding vector, along with scores.

        .. note::

            This method is not supported for index with Databricks-managed embeddings.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            filter: Filters to apply to the query. Defaults to None.
            query_type: The type of this query. Supported values are "ANN" and "HYBRID".

        Returns:
            List of Documents most similar to the embedding and score for each.
        """
        if self._index_details.is_databricks_managed_embeddings():
            raise NotImplementedError(
                _NON_MANAGED_EMB_ONLY_MSG % "similarity_search_by_vector_with_score"
            )

        if query_type is not None and query_type.upper() == "HYBRID":
            if query is None:
                raise ValueError(
                    "A value for `query` must be specified for hybrid search."
                )
            query_text = query
        else:
            if query is not None:
                raise ValueError(
                    (
                        "Cannot specify both `embedding` and "
                        '`query` unless `query_type="HYBRID"'
                    )
                )
            query_text = None

        search_resp = self.index.similarity_search(
            columns=self._columns,
            query_vector=embedding,
            query_text=query_text,
            filters=filter,
            num_results=k,
            query_type=query_type,
        )
        return self._parse_search_response(search_resp)

    def max_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Dict[str, Any]] = None,
        *,
        query_type: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        .. note::

            This method is not supported for index with Databricks-managed embeddings.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter: Filters to apply to the query. Defaults to None.
            query_type: The type of this query. Supported values are "ANN" and "HYBRID".
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        if self._index_details.is_databricks_managed_embeddings():
            raise NotImplementedError(
                _NON_MANAGED_EMB_ONLY_MSG % "max_marginal_relevance_search"
            )

        query_vector = self._embeddings.embed_query(query)  # type: ignore[union-attr]
        docs = self.max_marginal_relevance_search_by_vector(
            query_vector,
            k,
            fetch_k,
            lambda_mult=lambda_mult,
            filter=filter,
            query_type=query_type,
        )
        return docs

    async def amax_marginal_relevance_search(
        self,
        query: str,
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        # This is a temporary workaround to make the similarity search
        # asynchronous. The proper solution is to make the similarity search
        # asynchronous in the vector store implementations.
        func = partial(
            self.max_marginal_relevance_search,
            query,
            k=k,
            fetch_k=fetch_k,
            lambda_mult=lambda_mult,
            **kwargs,
        )
        return await asyncio.get_event_loop().run_in_executor(None, func)

    def max_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        filter: Optional[Any] = None,
        *,
        query_type: Optional[str] = None,
        **kwargs: Any,
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        .. note::

            This method is not supported for index with Databricks-managed embeddings.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            lambda_mult: Number between 0 and 1 that determines the degree
                of diversity among the results with 0 corresponding
                to maximum diversity and 1 to minimum diversity.
                Defaults to 0.5.
            filter: Filters to apply to the query. Defaults to None.
            query_type: The type of this query. Supported values are "ANN" and "HYBRID".
        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        if self._index_details.is_databricks_managed_embeddings():
            raise NotImplementedError(
                _NON_MANAGED_EMB_ONLY_MSG % "max_marginal_relevance_search_by_vector"
            )

        embedding_column = self._index_details.embedding_vector_column["name"]
        search_resp = self.index.similarity_search(
            columns=list(set(self._columns + [embedding_column])),
            query_text=None,
            query_vector=embedding,
            filters=filter,
            num_results=fetch_k,
            query_type=query_type,
        )

        embeddings_result_index = (
            search_resp.get("manifest").get("columns").index({"name": embedding_column})
        )
        embeddings = [
            doc[embeddings_result_index]
            for doc in search_resp.get("result").get("data_array")
        ]

        mmr_selected = maximal_marginal_relevance(
            np.array(embedding, dtype=np.float32),
            embeddings,
            k=k,
            lambda_mult=lambda_mult,
        )

        ignore_cols: List = (
            [embedding_column] if embedding_column not in self._columns else []
        )
        candidates = self._parse_search_response(search_resp, ignore_cols=ignore_cols)
        selected_results = [r[0] for i, r in enumerate(candidates) if i in mmr_selected]
        return selected_results

    async def amax_marginal_relevance_search_by_vector(
        self,
        embedding: List[float],
        k: int = 4,
        fetch_k: int = 20,
        lambda_mult: float = 0.5,
        **kwargs: Any,
    ) -> List[Document]:
        raise NotImplementedError

    def _parse_search_response(
        self, search_resp: Dict, ignore_cols: Optional[List[str]] = None
    ) -> List[Tuple[Document, float]]:
        """Parse the search response into a list of Documents with score."""
        if ignore_cols is None:
            ignore_cols = []

        columns = [
            col["name"]
            for col in search_resp.get("manifest", dict()).get("columns", [])
        ]
        docs_with_score = []
        for result in search_resp.get("result", dict()).get("data_array", []):
            doc_id = result[columns.index(self._primary_key)]
            text_content = result[columns.index(self._text_column)]
            ignore_cols = [self._primary_key, self._text_column] + ignore_cols
            metadata = {
                col: value
                for col, value in zip(columns[:-1], result[:-1])
                if col not in ignore_cols
            }
            metadata[self._primary_key] = doc_id
            score = result[-1]
            doc = Document(page_content=text_content, metadata=metadata)
            docs_with_score.append((doc, score))
        return docs_with_score


def _validate_and_get_text_column(
    text_column: Optional[str], index_details: IndexDetails
) -> str:
    if index_details.is_databricks_managed_embeddings():
        index_source_column: str = index_details.embedding_source_column["name"]
        # check if input text column matches the source column of the index
        if text_column is not None:
            raise ValueError(
                f"The index '{index_details.name}' has the source column configured as "
                f"'{index_source_column}'. Do not pass the `text_column` parameter."
            )
        return index_source_column
    else:
        if text_column is None:
            raise ValueError("The `text_column` parameter is required for this index.")
        return text_column


def _validate_and_get_return_columns(
    columns: List[str], text_column: str, index_details: IndexDetails
) -> List[str]:
    """
    Get a list of columns to retrieve from the index.

    If the index is direct-access index, validate the given columns against the schema.
    """
    # add primary key column and source column if not in columns
    if index_details.primary_key not in columns:
        columns.append(index_details.primary_key)
    if text_column and text_column not in columns:
        columns.append(text_column)

    # Validate specified columns are in the index
    if index_details.is_direct_access_index() and (
        index_schema := index_details.schema
    ):
        if missing_columns := [c for c in columns if c not in index_schema]:
            raise ValueError(
                "Some columns specified in `columns` are not "
                f"in the index schema: {missing_columns}"
            )
    return columns


def _validate_embedding(
    embedding: Optional[Embeddings], index_details: IndexDetails
) -> None:
    if index_details.is_databricks_managed_embeddings():
        if embedding is not None:
            raise ValueError(
                f"The index '{index_details.name}' uses Databricks-managed embeddings. "
                "Do not pass the `embedding` parameter when initializing vector store."
            )
    else:
        if not embedding:
            raise ValueError(
                "The `embedding` parameter is required for a direct-access index "
                "or delta-sync index with self-managed embedding."
            )
        _validate_embedding_dimension(embedding, index_details)


def _validate_embedding_dimension(
    embeddings: Embeddings, index_details: IndexDetails
) -> None:
    """Validate that the embedding dimension matches the index's configuration."""
    if index_embedding_dimension := index_details.embedding_vector_column.get(
        "embedding_dimension"
    ):
        # Infer the embedding dimension from the embedding function.
        actual_dimension = len(embeddings.embed_query("test"))
        if actual_dimension != index_embedding_dimension:
            raise ValueError(
                f"The specified embedding model's dimension '{actual_dimension}' does "
                f"not match with the index configuration '{index_embedding_dimension}'."
            )


class IndexDetails:
    """A utility class to store the configuration details of an index."""

    def __init__(self, index: Any):
        self._index_details = index.describe()

    @property
    def name(self) -> str:
        return self._index_details["name"]

    @property
    def schema(self) -> Optional[Dict]:
        if self.is_direct_access_index():
            schema_json = self.index_spec.get("schema_json")
            if schema_json is not None:
                return json.loads(schema_json)
        return None

    @property
    def primary_key(self) -> str:
        return self._index_details["primary_key"]

    @property
    def index_spec(self) -> Dict:
        return (
            self._index_details.get("delta_sync_index_spec", {})
            if self.is_delta_sync_index()
            else self._index_details.get("direct_access_index_spec", {})
        )

    @property
    def embedding_vector_column(self) -> Dict:
        if vector_columns := self.index_spec.get("embedding_vector_columns"):
            return vector_columns[0]
        return {}

    @property
    def embedding_source_column(self) -> Dict:
        if source_columns := self.index_spec.get("embedding_source_columns"):
            return source_columns[0]
        return {}

    def is_delta_sync_index(self) -> bool:
        return self._index_details["index_type"] == IndexType.DELTA_SYNC.value

    def is_direct_access_index(self) -> bool:
        return self._index_details["index_type"] == IndexType.DIRECT_ACCESS.value

    def is_databricks_managed_embeddings(self) -> bool:
        return (
            self.is_delta_sync_index()
            and self.embedding_source_column.get("name") is not None
        )

2525
libs/partners/databricks/poetry.lock
generated
2525
libs/partners/databricks/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -1,100 +0,0 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-databricks"
|
||||
version = "0.1.0"
|
||||
description = "An integration package connecting Databricks and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/langchain-ai/langchain"
|
||||
license = "MIT"
|
||||
|
||||
[tool.poetry.urls]
|
||||
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/databricks"
|
||||
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22databricks%3D%3D0%22&expanded=true"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
# TODO: Replace <3.12 to <4.0 once https://github.com/mlflow/mlflow/commit/04370119fcc1b2ccdbcd9a50198ab00566d58cd2 is released
|
||||
python = ">=3.8.1,<3.12"
|
||||
langchain-core = "^0.2.0"
|
||||
mlflow = ">=2.9"
|
||||
|
||||
# MLflow depends on following libraries, which require different version for Python 3.8 vs 3.12
|
||||
numpy = [
|
||||
{version = ">=1.26.0", python = ">=3.12"},
|
||||
{version = ">=1.24.0", python = "<3.12"},
|
||||
]
|
||||
scipy = [
|
||||
{version = ">=1.11", python = ">=3.12"},
|
||||
{version = "<2", python = "<3.12"}
|
||||
]
|
||||
databricks-vectorsearch = "^0.40"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^7.4.3"
|
||||
pytest-asyncio = "^0.23.2"
|
||||
pytest-socket = "^0.7.0"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.codespell.dependencies]
|
||||
codespell = "^2.2.6"
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.5"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^1.10"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"T201", # print
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = ["tests/*"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# --strict-markers will raise errors on unknown marks.
|
||||
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
|
||||
#
|
||||
# https://docs.pytest.org/en/7.1.x/reference/reference.html
|
||||
# --strict-config any warnings encountered while parsing the `pytest`
|
||||
# section of the configuration file raise errors.
|
||||
#
|
||||
# https://github.com/tophat/syrupy
|
||||
addopts = "--strict-markers --strict-config --durations=5"
|
||||
# Registering custom markers.
|
||||
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||
markers = [
|
||||
"compile: mark placeholder test used to compile integration tests without running them",
|
||||
]
|
||||
asyncio_mode = "auto"
|
@ -1,17 +0,0 @@
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.machinery import SourceFileLoader
|
||||
|
||||
if __name__ == "__main__":
|
||||
files = sys.argv[1:]
|
||||
has_failure = False
|
||||
for file in files:
|
||||
try:
|
||||
SourceFileLoader("x", file).load_module()
|
||||
except Exception:
|
||||
has_failure = True
|
||||
print(file) # noqa: T201
|
||||
traceback.print_exc()
|
||||
print() # noqa: T201
|
||||
|
||||
sys.exit(1 if has_failure else 0)
|
@ -1,27 +0,0 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# This script searches for lines starting with "import pydantic" or "from pydantic"
|
||||
# in tracked files within a Git repository.
|
||||
#
|
||||
# Usage: ./scripts/check_pydantic.sh /path/to/repository
|
||||
|
||||
# Check if a path argument is provided
|
||||
if [ $# -ne 1 ]; then
|
||||
echo "Usage: $0 /path/to/repository"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
repository_path="$1"
|
||||
|
||||
# Search for lines matching the pattern within the specified repository
|
||||
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
|
||||
|
||||
# Check if any matching lines were found
|
||||
if [ -n "$result" ]; then
|
||||
echo "ERROR: The following lines need to be updated:"
|
||||
echo "$result"
|
||||
echo "Please replace the code with an import from langchain_core.pydantic_v1."
|
||||
echo "For example, replace 'from pydantic import BaseModel'"
|
||||
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
||||
exit 1
|
||||
fi
|
@ -1,18 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -eu
|
||||
|
||||
# Initialize a variable to keep track of errors
|
||||
errors=0
|
||||
|
||||
# make sure not importing from langchain, langchain_experimental, or langchain_community
|
||||
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
|
||||
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
|
||||
git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1))
|
||||
|
||||
# Decide on an exit status based on the errors
|
||||
if [ "$errors" -gt 0 ]; then
|
||||
exit 1
|
||||
else
|
||||
exit 0
|
||||
fi
|
@ -1,7 +0,0 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.compile
|
||||
def test_placeholder() -> None:
|
||||
"""Used for compiling integration tests without running any real tests."""
|
||||
pass
|
@ -1,321 +0,0 @@
|
||||
"""Test chat model integration."""
|
||||
|
||||
import json
|
||||
from typing import Generator
|
||||
from unittest import mock
|
||||
|
||||
import mlflow # type: ignore # noqa: F401
|
||||
import pytest
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
ToolMessageChunk,
|
||||
)
|
||||
from langchain_core.messages.tool import ToolCallChunk
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
from langchain_databricks.chat_models import (
|
||||
ChatDatabricks,
|
||||
_convert_dict_to_message,
|
||||
_convert_dict_to_message_chunk,
|
||||
_convert_message_to_dict,
|
||||
)
|
||||
|
||||
_MOCK_CHAT_RESPONSE = {
|
||||
"id": "chatcmpl_id",
|
||||
"object": "chat.completion",
|
||||
"created": 1721875529,
|
||||
"model": "meta-llama-3.1-70b-instruct-072424",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "To calculate the result of 36939 multiplied by 8922.4, "
|
||||
"I get:\n\n36939 x 8922.4 = 329,511,111.6",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
|
||||
}
|
||||
|
||||
_MOCK_STREAM_RESPONSE = [
|
||||
{
|
||||
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1721877054,
|
||||
"model": "meta-llama-3.1-70b-instruct-072424",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": "36939"},
|
||||
"finish_reason": None,
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 30, "completion_tokens": 20, "total_tokens": 50},
|
||||
},
|
||||
{
|
||||
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1721877054,
|
||||
"model": "meta-llama-3.1-70b-instruct-072424",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": "x"},
|
||||
"finish_reason": None,
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 30, "completion_tokens": 22, "total_tokens": 52},
|
||||
},
|
||||
{
|
||||
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1721877054,
|
||||
"model": "meta-llama-3.1-70b-instruct-072424",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": "8922.4"},
|
||||
"finish_reason": None,
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 30, "completion_tokens": 24, "total_tokens": 54},
|
||||
},
|
||||
{
|
||||
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1721877054,
|
||||
"model": "meta-llama-3.1-70b-instruct-072424",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": " = "},
|
||||
"finish_reason": None,
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 30, "completion_tokens": 28, "total_tokens": 58},
|
||||
},
|
||||
{
|
||||
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1721877054,
|
||||
"model": "meta-llama-3.1-70b-instruct-072424",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": "329,511,111.6"},
|
||||
"finish_reason": None,
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 30, "completion_tokens": 30, "total_tokens": 60},
|
||||
},
|
||||
{
|
||||
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1721877054,
|
||||
"model": "meta-llama-3.1-70b-instruct-072424",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": ""},
|
||||
"finish_reason": "stop",
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_client() -> Generator:
|
||||
client = mock.MagicMock()
|
||||
client.predict.return_value = _MOCK_CHAT_RESPONSE
|
||||
client.predict_stream.return_value = _MOCK_STREAM_RESPONSE
|
||||
with mock.patch("mlflow.deployments.get_deploy_client", return_value=client):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm() -> ChatDatabricks:
|
||||
return ChatDatabricks(
|
||||
endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks"
|
||||
)
|
||||
|
||||
|
||||
def test_chat_mlflow_predict(llm: ChatDatabricks) -> None:
|
||||
res = llm.invoke(
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "36939 * 8922.4"},
|
||||
]
|
||||
)
|
||||
assert res.content == _MOCK_CHAT_RESPONSE["choices"][0]["message"]["content"] # type: ignore[index]
|
||||
|
||||
|
||||
def test_chat_mlflow_stream(llm: ChatDatabricks) -> None:
|
||||
res = llm.stream(
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "36939 * 8922.4"},
|
||||
]
|
||||
)
|
||||
for chunk, expected in zip(res, _MOCK_STREAM_RESPONSE):
|
||||
assert chunk.content == expected["choices"][0]["delta"]["content"] # type: ignore[index]
|
||||
|
||||
|
||||
def test_chat_mlflow_bind_tools(llm: ChatDatabricks) -> None:
|
||||
class GetWeather(BaseModel):
|
||||
"""Get the current weather in a given location"""
|
||||
|
||||
location: str = Field(
|
||||
..., description="The city and state, e.g. San Francisco, CA"
|
||||
)
|
||||
|
||||
class GetPopulation(BaseModel):
|
||||
"""Get the current population in a given location"""
|
||||
|
||||
location: str = Field(
|
||||
..., description="The city and state, e.g. San Francisco, CA"
|
||||
)
|
||||
|
||||
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
|
||||
response = llm_with_tools.invoke(
|
||||
"Which city is hotter today and which is bigger: LA or NY?"
|
||||
)
|
||||
assert isinstance(response, AIMessage)
|
||||
|
||||
|
||||
### Test data conversion functions ###
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("role", "expected_output"),
|
||||
[
|
||||
("user", HumanMessage("foo")),
|
||||
("system", SystemMessage("foo")),
|
||||
("assistant", AIMessage("foo")),
|
||||
("any_role", ChatMessage(content="foo", role="any_role")),
|
||||
],
|
||||
)
|
||||
def test_convert_message(role: str, expected_output: BaseMessage) -> None:
|
||||
message = {"role": role, "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
assert result == expected_output
|
||||
|
||||
# convert back
|
||||
dict_result = _convert_message_to_dict(result)
|
||||
assert dict_result == message
|
||||
|
||||
|
||||
def test_convert_message_with_tool_calls() -> None:
|
||||
ID = "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12"
|
||||
tool_calls = [
|
||||
{
|
||||
"id": ID,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "main__test__python_exec",
|
||||
"arguments": '{"code": "result = 36939 * 8922.4"}',
|
||||
},
|
||||
}
|
||||
]
|
||||
message_with_tools = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": tool_calls,
|
||||
"id": ID,
|
||||
}
|
||||
result = _convert_dict_to_message(message_with_tools)
|
||||
expected_output = AIMessage(
|
||||
content="",
|
||||
additional_kwargs={"tool_calls": tool_calls},
|
||||
id=ID,
|
||||
tool_calls=[
|
||||
{
|
||||
"name": tool_calls[0]["function"]["name"], # type: ignore[index]
|
||||
"args": json.loads(tool_calls[0]["function"]["arguments"]), # type: ignore[index]
|
||||
"id": ID,
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
)
|
||||
assert result == expected_output
|
||||
|
||||
# convert back
|
||||
dict_result = _convert_message_to_dict(result)
|
||||
assert dict_result == message_with_tools
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("role", "expected_output"),
|
||||
[
|
||||
("user", HumanMessageChunk(content="foo")),
|
||||
("system", SystemMessageChunk(content="foo")),
|
||||
("assistant", AIMessageChunk(content="foo")),
|
||||
("any_role", ChatMessageChunk(content="foo", role="any_role")),
|
||||
],
|
||||
)
|
||||
def test_convert_message_chunk(role: str, expected_output: BaseMessage) -> None:
|
||||
delta = {"role": role, "content": "foo"}
|
||||
result = _convert_dict_to_message_chunk(delta, "default_role")
|
||||
assert result == expected_output
|
||||
|
||||
# convert back
|
||||
dict_result = _convert_message_to_dict(result)
|
||||
assert dict_result == delta
|
||||
|
||||
|
||||
def test_convert_message_chunk_with_tool_calls() -> None:
|
||||
delta_with_tools = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"index": 0, "function": {"arguments": " }"}}],
|
||||
}
|
||||
result = _convert_dict_to_message_chunk(delta_with_tools, "role")
|
||||
expected_output = AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={"tool_calls": delta_with_tools["tool_calls"]},
|
||||
id=None,
|
||||
tool_call_chunks=[ToolCallChunk(name=None, args=" }", id=None, index=0)],
|
||||
)
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
def test_convert_tool_message_chunk() -> None:
|
||||
delta = {
|
||||
"role": "tool",
|
||||
"content": "foo",
|
||||
"tool_call_id": "tool_call_id",
|
||||
"id": "some_id",
|
||||
}
|
||||
result = _convert_dict_to_message_chunk(delta, "default_role")
|
||||
expected_output = ToolMessageChunk(
|
||||
content="foo", id="some_id", tool_call_id="tool_call_id"
|
||||
)
|
||||
assert result == expected_output
|
||||
|
||||
# convert back
|
||||
dict_result = _convert_message_to_dict(result)
|
||||
assert dict_result == delta
|
||||
|
||||
|
||||
def test_convert_message_to_dict_function() -> None:
|
||||
with pytest.raises(ValueError, match="Function messages are not supported"):
|
||||
_convert_message_to_dict(FunctionMessage(content="", name="name"))
|
@ -1,69 +0,0 @@
|
||||
"""Test Together AI embeddings."""
|
||||
|
||||
from typing import Any, Dict, Generator
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from mlflow.deployments import BaseDeploymentClient # type: ignore[import-untyped]
|
||||
|
||||
from langchain_databricks import DatabricksEmbeddings
|
||||
|
||||
|
||||
def _mock_embeddings(endpoint: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"embedding": list(range(1536)),
|
||||
"index": 0,
|
||||
}
|
||||
for _ in inputs["input"]
|
||||
],
|
||||
"model": "text-embedding-3-small",
|
||||
"usage": {"prompt_tokens": 8, "total_tokens": 8},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client() -> Generator:
|
||||
client = mock.MagicMock()
|
||||
client.predict.side_effect = _mock_embeddings
|
||||
with mock.patch("mlflow.deployments.get_deploy_client", return_value=client):
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def embeddings() -> DatabricksEmbeddings:
|
||||
return DatabricksEmbeddings(
|
||||
endpoint="text-embedding-3-small",
|
||||
documents_params={"fruit": "apple"},
|
||||
query_params={"fruit": "banana"},
|
||||
)
|
||||
|
||||
|
||||
def test_embed_documents(
|
||||
mock_client: BaseDeploymentClient, embeddings: DatabricksEmbeddings
|
||||
) -> None:
|
||||
documents = ["foo"] * 30
|
||||
output = embeddings.embed_documents(documents)
|
||||
assert len(output) == 30
|
||||
assert len(output[0]) == 1536
|
||||
assert mock_client.predict.call_count == 2
|
||||
assert all(
|
||||
call_arg[1]["inputs"]["fruit"] == "apple"
|
||||
for call_arg in mock_client().predict.call_args_list
|
||||
)
|
||||
|
||||
|
||||
def test_embed_query(
|
||||
mock_client: BaseDeploymentClient, embeddings: DatabricksEmbeddings
|
||||
) -> None:
|
||||
query = "foo bar"
|
||||
output = embeddings.embed_query(query)
|
||||
assert len(output) == 1536
|
||||
mock_client.predict.assert_called_once()
|
||||
assert mock_client.predict.call_args[1] == {
|
||||
"endpoint": "text-embedding-3-small",
|
||||
"inputs": {"input": [query], "fruit": "banana"},
|
||||
}
|
@ -1,12 +0,0 @@
|
||||
from langchain_databricks import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"ChatDatabricks",
|
||||
"DatabricksEmbeddings",
|
||||
"DatabricksVectorSearch",
|
||||
"__version__",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert sorted(EXPECTED_ALL) == sorted(__all__)
|
@ -1,629 +0,0 @@
|
||||
import uuid
|
||||
from typing import Any, Dict, Generator, List, Optional, Set
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from langchain_databricks.vectorstores import DatabricksVectorSearch
|
||||
|
||||
INPUT_TEXTS = ["foo", "bar", "baz"]
|
||||
DEFAULT_VECTOR_DIMENSION = 4
|
||||
|
||||
|
||||
class FakeEmbeddings(Embeddings):
|
||||
"""Fake embeddings functionality for testing."""
|
||||
|
||||
def __init__(self, dimension: int = DEFAULT_VECTOR_DIMENSION):
|
||||
super().__init__()
|
||||
self.dimension = dimension
|
||||
|
||||
def embed_documents(self, embedding_texts: List[str]) -> List[List[float]]:
|
||||
"""Return simple embeddings."""
|
||||
return [
|
||||
[float(1.0)] * (self.dimension - 1) + [float(i)]
|
||||
for i in range(len(embedding_texts))
|
||||
]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Return simple embeddings."""
|
||||
return [float(1.0)] * (self.dimension - 1) + [float(0.0)]
|
||||
|
||||
|
||||
EMBEDDING_MODEL = FakeEmbeddings()
|
||||
|
||||
|
||||
### Dummy similarity_search() Response ###
|
||||
EXAMPLE_SEARCH_RESPONSE = {
|
||||
"manifest": {
|
||||
"column_count": 3,
|
||||
"columns": [
|
||||
{"name": "id"},
|
||||
{"name": "text"},
|
||||
{"name": "text_vector"},
|
||||
{"name": "score"},
|
||||
],
|
||||
},
|
||||
"result": {
|
||||
"row_count": len(INPUT_TEXTS),
|
||||
"data_array": sorted(
|
||||
[
|
||||
[str(uuid.uuid4()), s, e, 0.5]
|
||||
for s, e in zip(
|
||||
INPUT_TEXTS, EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
|
||||
)
|
||||
],
|
||||
key=lambda x: x[2], # type: ignore
|
||||
reverse=True,
|
||||
),
|
||||
},
|
||||
"next_page_token": "",
|
||||
}
|
||||
|
||||
|
||||
### Dummy Indices ####
|
||||
|
||||
ENDPOINT_NAME = "test-endpoint"
|
||||
DIRECT_ACCESS_INDEX = "test-direct-access-index"
|
||||
DELTA_SYNC_INDEX = "test-delta-sync-index"
|
||||
DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX = "test-delta-sync-self-managed-index"
|
||||
ALL_INDEX_NAMES = {
|
||||
DIRECT_ACCESS_INDEX,
|
||||
DELTA_SYNC_INDEX,
|
||||
DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX,
|
||||
}
|
||||
|
||||
INDEX_DETAILS = {
|
||||
DELTA_SYNC_INDEX: {
|
||||
"name": DELTA_SYNC_INDEX,
|
||||
"endpoint_name": ENDPOINT_NAME,
|
||||
"index_type": "DELTA_SYNC",
|
||||
"primary_key": "id",
|
||||
"delta_sync_index_spec": {
|
||||
"source_table": "ml.llm.source_table",
|
||||
"pipeline_type": "CONTINUOUS",
|
||||
"embedding_source_columns": [
|
||||
{
|
||||
"name": "text",
|
||||
"embedding_model_endpoint_name": "openai-text-embedding",
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX: {
|
||||
"name": DELTA_SYNC_SELF_MANAGED_EMBEDDINGS_INDEX,
|
||||
"endpoint_name": ENDPOINT_NAME,
|
||||
"index_type": "DELTA_SYNC",
|
||||
"primary_key": "id",
|
||||
"delta_sync_index_spec": {
|
||||
"source_table": "ml.llm.source_table",
|
||||
"pipeline_type": "CONTINUOUS",
|
||||
"embedding_vector_columns": [
|
||||
{
|
||||
"name": "text_vector",
|
||||
"embedding_dimension": DEFAULT_VECTOR_DIMENSION,
|
||||
}
|
||||
],
|
||||
},
|
||||
},
|
||||
DIRECT_ACCESS_INDEX: {
|
||||
"name": DIRECT_ACCESS_INDEX,
|
||||
"endpoint_name": ENDPOINT_NAME,
|
||||
"index_type": "DIRECT_ACCESS",
|
||||
"primary_key": "id",
|
||||
"direct_access_index_spec": {
|
||||
"embedding_vector_columns": [
|
||||
{
|
||||
"name": "text_vector",
|
||||
"embedding_dimension": DEFAULT_VECTOR_DIMENSION,
|
||||
}
|
||||
],
|
||||
"schema_json": f"{{"
|
||||
f'"{"id"}": "int", '
|
||||
f'"feat1": "str", '
|
||||
f'"feat2": "float", '
|
||||
f'"text": "string", '
|
||||
f'"{"text_vector"}": "array<float>"'
|
||||
f"}}",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_vs_client() -> Generator:
|
||||
def _get_index(endpoint: str, index_name: str) -> MagicMock:
|
||||
from databricks.vector_search.client import VectorSearchIndex # type: ignore
|
||||
|
||||
if endpoint != ENDPOINT_NAME:
|
||||
raise ValueError(f"Unknown endpoint: {endpoint}")
|
||||
|
||||
index = MagicMock(spec=VectorSearchIndex)
|
||||
index.describe.return_value = INDEX_DETAILS[index_name]
|
||||
index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE
|
||||
return index
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_index.side_effect = _get_index
|
||||
with mock.patch(
|
||||
"databricks.vector_search.client.VectorSearchClient",
|
||||
return_value=mock_client,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
def init_vector_search(
|
||||
index_name: str, columns: Optional[List[str]] = None
|
||||
) -> DatabricksVectorSearch:
|
||||
kwargs: Dict[str, Any] = {
|
||||
"endpoint": ENDPOINT_NAME,
|
||||
"index_name": index_name,
|
||||
"columns": columns,
|
||||
}
|
||||
if index_name != DELTA_SYNC_INDEX:
|
||||
kwargs.update(
|
||||
{
|
||||
"embedding": EMBEDDING_MODEL,
|
||||
"text_column": "text",
|
||||
}
|
||||
)
|
||||
return DatabricksVectorSearch(**kwargs) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
|
||||
def test_init(index_name: str) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
assert vectorsearch.index.describe() == INDEX_DETAILS[index_name]
|
||||
|
||||
|
||||
def test_init_fail_text_column_mismatch() -> None:
|
||||
with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' has"):
|
||||
DatabricksVectorSearch(
|
||||
endpoint=ENDPOINT_NAME,
|
||||
index_name=DELTA_SYNC_INDEX,
|
||||
text_column="some_other_column",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
|
||||
def test_init_fail_no_text_column(index_name: str) -> None:
|
||||
with pytest.raises(ValueError, match="The `text_column` parameter is required"):
|
||||
DatabricksVectorSearch(
|
||||
endpoint=ENDPOINT_NAME,
|
||||
index_name=index_name,
|
||||
embedding=EMBEDDING_MODEL,
|
||||
)
|
||||
|
||||
|
||||
def test_init_fail_columns_not_in_schema() -> None:
|
||||
columns = ["some_random_column"]
|
||||
with pytest.raises(ValueError, match="Some columns specified in `columns`"):
|
||||
init_vector_search(DIRECT_ACCESS_INDEX, columns=columns)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
|
||||
def test_init_fail_no_embedding(index_name: str) -> None:
|
||||
with pytest.raises(ValueError, match="The `embedding` parameter is required"):
|
||||
DatabricksVectorSearch(
|
||||
endpoint=ENDPOINT_NAME,
|
||||
index_name=index_name,
|
||||
text_column="text",
|
||||
)
|
||||
|
||||
|
||||
def test_init_fail_embedding_already_specified_in_source() -> None:
|
||||
with pytest.raises(ValueError, match=f"The index '{DELTA_SYNC_INDEX}' uses"):
|
||||
DatabricksVectorSearch(
|
||||
endpoint=ENDPOINT_NAME,
|
||||
index_name=DELTA_SYNC_INDEX,
|
||||
embedding=EMBEDDING_MODEL,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
|
||||
def test_init_fail_embedding_dim_mismatch(index_name: str) -> None:
|
||||
with pytest.raises(
|
||||
ValueError, match="embedding model's dimension '1000' does not match"
|
||||
):
|
||||
DatabricksVectorSearch(
|
||||
endpoint=ENDPOINT_NAME,
|
||||
index_name=index_name,
|
||||
text_column="text",
|
||||
embedding=FakeEmbeddings(1000),
|
||||
)
|
||||
|
||||
|
||||
def test_from_texts_not_supported() -> None:
|
||||
with pytest.raises(NotImplementedError, match="`from_texts` is not supported"):
|
||||
DatabricksVectorSearch.from_texts(INPUT_TEXTS, EMBEDDING_MODEL)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DIRECT_ACCESS_INDEX})
|
||||
def test_add_texts_not_supported_for_delta_sync_index(index_name: str) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="`add_texts` is only supported for direct-access index.",
|
||||
):
|
||||
vectorsearch.add_texts(INPUT_TEXTS)
|
||||
|
||||
|
||||
def is_valid_uuid(val: str) -> bool:
|
||||
try:
|
||||
uuid.UUID(str(val))
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def test_add_texts() -> None:
|
||||
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
|
||||
ids = [idx for idx, i in enumerate(INPUT_TEXTS)]
|
||||
vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
|
||||
|
||||
added_ids = vectorsearch.add_texts(INPUT_TEXTS, ids=ids)
|
||||
vectorsearch.index.upsert.assert_called_once_with(
|
||||
[
|
||||
{
|
||||
"id": id_,
|
||||
"text": text,
|
||||
"text_vector": vector,
|
||||
}
|
||||
for text, vector, id_ in zip(INPUT_TEXTS, vectors, ids)
|
||||
]
|
||||
)
|
||||
assert len(added_ids) == len(INPUT_TEXTS)
|
||||
assert added_ids == ids
|
||||
|
||||
|
||||
def test_add_texts_handle_single_text() -> None:
|
||||
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
|
||||
vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
|
||||
|
||||
added_ids = vectorsearch.add_texts(INPUT_TEXTS[0])
|
||||
vectorsearch.index.upsert.assert_called_once_with(
|
||||
[
|
||||
{
|
||||
"id": id_,
|
||||
"text": text,
|
||||
"text_vector": vector,
|
||||
}
|
||||
for text, vector, id_ in zip(INPUT_TEXTS, vectors, added_ids)
|
||||
]
|
||||
)
|
||||
assert len(added_ids) == 1
|
||||
assert is_valid_uuid(added_ids[0])
|
||||
|
||||
|
||||
def test_add_texts_with_default_id() -> None:
|
||||
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
|
||||
vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
|
||||
|
||||
added_ids = vectorsearch.add_texts(INPUT_TEXTS)
|
||||
vectorsearch.index.upsert.assert_called_once_with(
|
||||
[
|
||||
{
|
||||
"id": id_,
|
||||
"text": text,
|
||||
"text_vector": vector,
|
||||
}
|
||||
for text, vector, id_ in zip(INPUT_TEXTS, vectors, added_ids)
|
||||
]
|
||||
)
|
||||
assert len(added_ids) == len(INPUT_TEXTS)
|
||||
assert all([is_valid_uuid(id_) for id_ in added_ids])
|
||||
|
||||
|
||||
def test_add_texts_with_metadata() -> None:
|
||||
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
|
||||
vectors = EMBEDDING_MODEL.embed_documents(INPUT_TEXTS)
|
||||
metadatas = [{"feat1": str(i), "feat2": i + 1000} for i in range(len(INPUT_TEXTS))]
|
||||
|
||||
added_ids = vectorsearch.add_texts(INPUT_TEXTS, metadatas=metadatas)
|
||||
vectorsearch.index.upsert.assert_called_once_with(
|
||||
[
|
||||
{
|
||||
"id": id_,
|
||||
"text": text,
|
||||
"text_vector": vector,
|
||||
**metadata, # type: ignore[arg-type]
|
||||
}
|
||||
for text, vector, id_, metadata in zip(
|
||||
INPUT_TEXTS, vectors, added_ids, metadatas
|
||||
)
|
||||
]
|
||||
)
|
||||
assert len(added_ids) == len(INPUT_TEXTS)
|
||||
assert all([is_valid_uuid(id_) for id_ in added_ids])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
|
||||
def test_embeddings_property(index_name: str) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
assert vectorsearch.embeddings == EMBEDDING_MODEL
|
||||
|
||||
|
||||
def test_delete() -> None:
|
||||
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
|
||||
vectorsearch.delete(["some id"])
|
||||
vectorsearch.index.delete.assert_called_once_with(["some id"])
|
||||
|
||||
|
||||
def test_delete_fail_no_ids() -> None:
|
||||
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
|
||||
with pytest.raises(ValueError, match="ids must be provided."):
|
||||
vectorsearch.delete()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DIRECT_ACCESS_INDEX})
|
||||
def test_delete_not_supported_for_delta_sync_index(index_name: str) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
with pytest.raises(
|
||||
NotImplementedError, match="`delete` is only supported for direct-access"
|
||||
):
|
||||
vectorsearch.delete(["some id"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
|
||||
@pytest.mark.parametrize("query_type", [None, "ANN"])
|
||||
def test_similarity_search(index_name: str, query_type: Optional[str]) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
query = "foo"
|
||||
filters = {"some filter": True}
|
||||
limit = 7
|
||||
|
||||
search_result = vectorsearch.similarity_search(
|
||||
query, k=limit, filter=filters, query_type=query_type
|
||||
)
|
||||
if index_name == DELTA_SYNC_INDEX:
|
||||
vectorsearch.index.similarity_search.assert_called_once_with(
|
||||
columns=["id", "text"],
|
||||
query_text=query,
|
||||
query_vector=None,
|
||||
filters=filters,
|
||||
num_results=limit,
|
||||
query_type=query_type,
|
||||
)
|
||||
else:
|
||||
vectorsearch.index.similarity_search.assert_called_once_with(
|
||||
columns=["id", "text"],
|
||||
query_text=None,
|
||||
query_vector=EMBEDDING_MODEL.embed_query(query),
|
||||
filters=filters,
|
||||
num_results=limit,
|
||||
query_type=query_type,
|
||||
)
|
||||
assert len(search_result) == len(INPUT_TEXTS)
|
||||
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
|
||||
assert all(["id" in d.metadata for d in search_result])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
|
||||
def test_similarity_search_hybrid(index_name: str) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
query = "foo"
|
||||
filters = {"some filter": True}
|
||||
limit = 7
|
||||
|
||||
search_result = vectorsearch.similarity_search(
|
||||
query, k=limit, filter=filters, query_type="HYBRID"
|
||||
)
|
||||
if index_name == DELTA_SYNC_INDEX:
|
||||
vectorsearch.index.similarity_search.assert_called_once_with(
|
||||
columns=["id", "text"],
|
||||
query_text=query,
|
||||
query_vector=None,
|
||||
filters=filters,
|
||||
num_results=limit,
|
||||
query_type="HYBRID",
|
||||
)
|
||||
else:
|
||||
vectorsearch.index.similarity_search.assert_called_once_with(
|
||||
columns=["id", "text"],
|
||||
query_text=query,
|
||||
query_vector=EMBEDDING_MODEL.embed_query(query),
|
||||
filters=filters,
|
||||
num_results=limit,
|
||||
query_type="HYBRID",
|
||||
)
|
||||
assert len(search_result) == len(INPUT_TEXTS)
|
||||
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
|
||||
assert all(["id" in d.metadata for d in search_result])
|
||||
|
||||
|
||||
def test_similarity_search_both_filter_and_filters_passed() -> None:
|
||||
vectorsearch = init_vector_search(DIRECT_ACCESS_INDEX)
|
||||
query = "foo"
|
||||
filter = {"some filter": True}
|
||||
filters = {"some other filter": False}
|
||||
|
||||
vectorsearch.similarity_search(query, filter=filter, filters=filters)
|
||||
vectorsearch.index.similarity_search.assert_called_once_with(
|
||||
columns=["id", "text"],
|
||||
query_vector=EMBEDDING_MODEL.embed_query(query),
|
||||
# `filter` should prevail over `filters`
|
||||
filters=filter,
|
||||
num_results=4,
|
||||
query_text=None,
|
||||
query_type=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
|
||||
@pytest.mark.parametrize(
|
||||
"columns, expected_columns",
|
||||
[
|
||||
(None, {"id"}),
|
||||
(["id", "text", "text_vector"], {"text_vector", "id"}),
|
||||
],
|
||||
)
|
||||
def test_mmr_search(
|
||||
index_name: str, columns: Optional[List[str]], expected_columns: Set[str]
|
||||
) -> None:
|
||||
vectorsearch = init_vector_search(index_name, columns=columns)
|
||||
|
||||
query = INPUT_TEXTS[0]
|
||||
filters = {"some filter": True}
|
||||
limit = 1
|
||||
|
||||
search_result = vectorsearch.max_marginal_relevance_search(
|
||||
query, k=limit, filters=filters
|
||||
)
|
||||
assert [doc.page_content for doc in search_result] == [INPUT_TEXTS[0]]
|
||||
assert [set(doc.metadata.keys()) for doc in search_result] == [expected_columns]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
|
||||
def test_mmr_parameters(index_name: str) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
|
||||
query = INPUT_TEXTS[0]
|
||||
limit = 1
|
||||
fetch_k = 3
|
||||
lambda_mult = 0.25
|
||||
filters = {"some filter": True}
|
||||
|
||||
with patch(
|
||||
"langchain_databricks.vectorstores.maximal_marginal_relevance"
|
||||
) as mock_mmr:
|
||||
mock_mmr.return_value = [2]
|
||||
retriever = vectorsearch.as_retriever(
|
||||
search_type="mmr",
|
||||
search_kwargs={
|
||||
"k": limit,
|
||||
"fetch_k": fetch_k,
|
||||
"lambda_mult": lambda_mult,
|
||||
"filter": filters,
|
||||
},
|
||||
)
|
||||
search_result = retriever.invoke(query)
|
||||
|
||||
mock_mmr.assert_called_once()
|
||||
assert mock_mmr.call_args[1]["lambda_mult"] == lambda_mult
|
||||
assert vectorsearch.index.similarity_search.call_args[1]["num_results"] == fetch_k
|
||||
assert vectorsearch.index.similarity_search.call_args[1]["filters"] == filters
|
||||
assert len(search_result) == limit
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
|
||||
@pytest.mark.parametrize("threshold", [0.4, 0.5, 0.8])
|
||||
def test_similarity_score_threshold(index_name: str, threshold: float) -> None:
|
||||
query = INPUT_TEXTS[0]
|
||||
limit = len(INPUT_TEXTS)
|
||||
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
retriever = vectorsearch.as_retriever(
|
||||
search_type="similarity_score_threshold",
|
||||
search_kwargs={"k": limit, "score_threshold": threshold},
|
||||
)
|
||||
search_result = retriever.invoke(query)
|
||||
if threshold <= 0.5:
|
||||
assert len(search_result) == len(INPUT_TEXTS)
|
||||
else:
|
||||
assert len(search_result) == 0
|
||||
|
||||
|
||||
def test_standard_params() -> None:
|
||||
vectorstore = init_vector_search(DIRECT_ACCESS_INDEX)
|
||||
retriever = vectorstore.as_retriever()
|
||||
ls_params = retriever._get_ls_params()
|
||||
assert ls_params == {
|
||||
"ls_retriever_name": "vectorstore",
|
||||
"ls_vector_store_provider": "DatabricksVectorSearch",
|
||||
"ls_embedding_provider": "FakeEmbeddings",
|
||||
}
|
||||
|
||||
vectorstore = init_vector_search(DELTA_SYNC_INDEX)
|
||||
retriever = vectorstore.as_retriever()
|
||||
ls_params = retriever._get_ls_params()
|
||||
assert ls_params == {
|
||||
"ls_retriever_name": "vectorstore",
|
||||
"ls_vector_store_provider": "DatabricksVectorSearch",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
|
||||
@pytest.mark.parametrize("query_type", [None, "ANN"])
|
||||
def test_similarity_search_by_vector(
|
||||
index_name: str, query_type: Optional[str]
|
||||
) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
query_embedding = EMBEDDING_MODEL.embed_query("foo")
|
||||
filters = {"some filter": True}
|
||||
limit = 7
|
||||
|
||||
search_result = vectorsearch.similarity_search_by_vector(
|
||||
query_embedding, k=limit, filter=filters, query_type=query_type
|
||||
)
|
||||
vectorsearch.index.similarity_search.assert_called_once_with(
|
||||
columns=["id", "text"],
|
||||
query_vector=query_embedding,
|
||||
filters=filters,
|
||||
num_results=limit,
|
||||
query_type=query_type,
|
||||
query_text=None,
|
||||
)
|
||||
assert len(search_result) == len(INPUT_TEXTS)
|
||||
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
|
||||
assert all(["id" in d.metadata for d in search_result])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES - {DELTA_SYNC_INDEX})
|
||||
def test_similarity_search_by_vector_hybrid(index_name: str) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
query_embedding = EMBEDDING_MODEL.embed_query("foo")
|
||||
filters = {"some filter": True}
|
||||
limit = 7
|
||||
|
||||
search_result = vectorsearch.similarity_search_by_vector(
|
||||
query_embedding, k=limit, filter=filters, query_type="HYBRID", query="foo"
|
||||
)
|
||||
vectorsearch.index.similarity_search.assert_called_once_with(
|
||||
columns=["id", "text"],
|
||||
query_vector=query_embedding,
|
||||
filters=filters,
|
||||
num_results=limit,
|
||||
query_type="HYBRID",
|
||||
query_text="foo",
|
||||
)
|
||||
assert len(search_result) == len(INPUT_TEXTS)
|
||||
assert sorted([d.page_content for d in search_result]) == sorted(INPUT_TEXTS)
|
||||
assert all(["id" in d.metadata for d in search_result])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES)
|
||||
def test_similarity_search_empty_result(index_name: str) -> None:
|
||||
vectorsearch = init_vector_search(index_name)
|
||||
vectorsearch.index.similarity_search.return_value = {
|
||||
"manifest": {
|
||||
"column_count": 3,
|
||||
"columns": [
|
||||
{"name": "id"},
|
||||
{"name": "text"},
|
||||
{"name": "score"},
|
||||
],
|
||||
},
|
||||
"result": {
|
||||
"row_count": 0,
|
||||
"data_array": [],
|
||||
},
|
||||
"next_page_token": "",
|
||||
}
|
||||
|
||||
search_result = vectorsearch.similarity_search("foo")
|
||||
assert len(search_result) == 0
|
||||
|
||||
|
||||
def test_similarity_search_by_vector_not_supported_for_managed_embedding() -> None:
|
||||
vectorsearch = init_vector_search(DELTA_SYNC_INDEX)
|
||||
query_embedding = EMBEDDING_MODEL.embed_query("foo")
|
||||
filters = {"some filter": True}
|
||||
limit = 7
|
||||
|
||||
with pytest.raises(
|
||||
NotImplementedError, match="`similarity_search_by_vector` is not supported"
|
||||
):
|
||||
vectorsearch.similarity_search_by_vector(
|
||||
query_embedding, k=limit, filters=filters
|
||||
)
|
Loading…
Reference in New Issue
Block a user