databricks: Add partner package directory and ChatDatabricks implementation (#25430)

### Summary

Create `langchain-databricks` as a new partner package. This PR does not
migrate all of the existing Databricks integrations, but the package will
eventually contain:

* `ChatDatabricks` (implemented in this PR)
* `DatabricksVectorSearch`
* `DatabricksEmbeddings`
* ~`UCFunctionToolkit`~ (will be done after the UC SDK work, which
drastically simplifies the implementation)

This PR also does not add integration tests yet; they will be added
once the Databricks test workspace is ready.
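For quick reference, a minimal sketch of the new import path (mirroring the updated notebook and README in this diff; the endpoint name is only an example, any OpenAI-compatible Databricks serving endpoint works):

```python
# Minimal sketch of the new package's usage (mirrors the docstring/notebook in this PR).
# The endpoint name is an example; substitute your own serving endpoint.
from langchain_databricks import ChatDatabricks

llm = ChatDatabricks(
    endpoint="databricks-meta-llama-3-1-405b-instruct",
    temperature=0,
    max_tokens=500,
)
llm.invoke("Sing a ballad of LangChain.")
```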

Tagging @efriis as POC


### Tracker
- [✍️] Create a package and migrate `ChatDatabricks` (this PR)
- [ ] Migrate `DatabricksVectorSearch`, `DatabricksEmbeddings`, and their docs
- ~[ ] Migrate `UCFunctionToolkit` and its doc~
- [ ] Add provider document and update README.md
- [ ] Add integration tests and set up secrets (after the move to the external package)
- [ ] Add a deprecation note to the community implementations (see the sketch below)
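A minimal sketch of what that community-side deprecation note could look like, assuming the `deprecated` decorator from `langchain_core._api` (version numbers are placeholders, not decided in this PR):

```python
# Hypothetical sketch only -- the real change would decorate the existing
# langchain_community.chat_models.databricks.ChatDatabricks class, and the
# version strings below are placeholders.
from langchain_core._api import deprecated


@deprecated(
    since="0.2.x",  # placeholder: community version that ships the note
    removal="1.0",
    alternative_import="langchain_databricks.ChatDatabricks",
)
class ChatDatabricks:  # stand-in for the community class
    """Deprecated: use ``langchain_databricks.ChatDatabricks`` instead."""
```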

---------

Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
Yuki Watanabe 2024-08-22 09:19:28 +09:00 committed by GitHub
parent fb1d67edf6
commit 3981d736df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 3699 additions and 13 deletions

View File

@@ -18,7 +18,7 @@ for dir; do \
if find "$$dir" -maxdepth 1 -type f \( -name "pyproject.toml" -o -name "setup.py" \) | grep -q .; then \
echo "$$dir"; \
fi \
done' sh {} + | grep -vE "airbyte|ibm|couchbase" | tr '\n' ' ')
done' sh {} + | grep -vE "airbyte|ibm|couchbase|databricks" | tr '\n' ' ')
PORT ?= 3001

View File

@@ -31,7 +31,7 @@
"\n",
"| Class | Package | Local | Serializable | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: |\n",
"| [ChatDatabricks](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.databricks.ChatDatabricks.html) | [langchain-community](https://api.python.langchain.com/en/latest/community_api_reference.html) | ❌ | beta | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain-community?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain-community?style=flat-square&label=%20) |\n",
"| [ChatDatabricks](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.databricks.ChatDatabricks.html) | [langchain-databricks](https://api.python.langchain.com/en/latest/databricks_api_reference.html) | ❌ | beta | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain-databricks?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain-databricks?style=flat-square&label=%20) |\n",
"\n",
"### Model features\n",
"| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
@@ -99,7 +99,7 @@
"source": [
"### Installation\n",
"\n",
"The LangChain Databricks integration lives in the `langchain-community` package. Also, `mlflow >= 2.9 ` is required to run the code in this notebook."
"The LangChain Databricks integration lives in the `langchain-databricks` package."
]
},
{
@@ -108,7 +108,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain-community mlflow>=2.9.0"
"%pip install -qU langchain-databricks"
]
},
{
@@ -133,7 +133,7 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models import ChatDatabricks\n",
"from langchain_databricks import ChatDatabricks\n",
"\n",
"chat_model = ChatDatabricks(\n",
" endpoint=\"databricks-dbrx-instruct\",\n",
@@ -245,9 +245,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Invocation (streaming)\n",
"\n",
"`ChatDatabricks` supports streaming response by `stream` method since `langchain-community>=0.2.1`."
"## Invocation (streaming)"
]
},
{
@@ -299,7 +297,7 @@
"* An LLM was registered and deployed to [a Databricks serving endpoint](https://docs.databricks.com/machine-learning/model-serving/index.html) via MLflow. The endpoint must have OpenAI-compatible chat input/output format ([reference](https://mlflow.org/docs/latest/llms/deployments/index.html#chat))\n",
"* You have [\"Can Query\" permission](https://docs.databricks.com/security/auth-authz/access-control/serving-endpoint-acl.html) to the endpoint.\n",
"\n",
"Once the endpoint is ready, the usage pattern is completely same as Foundation Models."
"Once the endpoint is ready, the usage pattern is identical to that of Foundation Models."
]
},
{
@@ -332,7 +330,7 @@
"\n",
"First, create a new Databricks serving endpoint that proxies requests to the target external model. The endpoint creation should be fairy quick for proxying external models.\n",
"\n",
"This requires registering OpenAI API Key in Databricks secret manager with the following comment:\n",
"This requires registering your OpenAI API Key within the Databricks secret manager as follows:\n",
"```sh\n",
"# Replace `<scope>` with your scope\n",
"databricks secrets create-scope <scope>\n",
@@ -417,8 +415,6 @@
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models.databricks import ChatDatabricks\n",
"\n",
"llm = ChatDatabricks(endpoint=\"databricks-meta-llama-3-70b-instruct\")\n",
"tools = [\n",
" {\n",
@@ -461,7 +457,7 @@
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all ChatDatabricks features and configurations head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.ChatDatabricks.html"
"For detailed documentation of all ChatDatabricks features and configurations head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_databricks.chat_models.ChatDatabricks.html"
]
}
],

1
libs/partners/databricks/.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
__pycache__

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 LangChain, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,62 @@
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
# Default target executed when no arguments are given to make.
all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/
integration_test integration_tests: TEST_FILE = tests/integration_tests/
# unit tests are run with the --disable-socket flag to prevent network calls
test tests:
poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
# integration tests are run without the --disable-socket flag to allow network calls
integration_test integration_tests:
poetry run pytest $(TEST_FILE)
######################
# LINTING AND FORMATTING
######################
# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/databricks --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_databricks
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test
lint lint_diff lint_package lint_tests:
poetry run ruff check .
poetry run ruff format $(PYTHON_FILES) --diff
poetry run ruff check --select I $(PYTHON_FILES)
mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
poetry run ruff format $(PYTHON_FILES)
poetry run ruff check --select I --fix $(PYTHON_FILES)
spell_check:
poetry run codespell --toml pyproject.toml
spell_fix:
poetry run codespell --toml pyproject.toml -w
check_imports: $(shell find langchain_databricks -name '*.py')
poetry run python ./scripts/check_imports.py $^
######################
# HELP
######################
help:
@echo '----'
@echo 'check_imports - check imports'
@echo 'format - run code formatters'
@echo 'lint - run linters'
@echo 'test - run unit tests'
@echo 'tests - run unit tests'
@echo 'test TEST_FILE=<test_file> - run all tests in file'

View File

@@ -0,0 +1,24 @@
# langchain-databricks
This package contains the LangChain integration with Databricks
## Installation
```bash
pip install -U langchain-databricks
```
And you should configure credentials by setting the following environment variables:
* TODO: fill this out
## Chat Models
`ChatDatabricks` class exposes chat models from Databricks.
```python
from langchain_databricks import ChatDatabricks
llm = ChatDatabricks()
llm.invoke("Sing a ballad of LangChain.")
```

View File

@@ -0,0 +1,15 @@
from importlib import metadata
from langchain_databricks.chat_models import ChatDatabricks
try:
__version__ = metadata.version(__package__)
except metadata.PackageNotFoundError:
# Case where package metadata is not available.
__version__ = ""
del metadata # optional, avoids polluting the results of dir(__package__)
__all__ = [
"ChatDatabricks",
"__version__",
]

View File

@@ -0,0 +1,573 @@
"""Databricks chat models."""
import json
import logging
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Type,
Union,
)
from urllib.parse import urlparse
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessage,
ToolMessageChunk,
)
from langchain_core.messages.tool import tool_call_chunk
from langchain_core.output_parsers.openai_tools import (
make_invalid_tool_call,
parse_tool_call,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
BaseModel,
Field,
PrivateAttr,
)
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
logger = logging.getLogger(__name__)
class ChatDatabricks(BaseChatModel):
"""Databricks chat model integration.
Setup:
Install ``langchain-databricks``.
.. code-block:: bash
pip install -U langchain-databricks
If you are outside Databricks, set the Databricks workspace hostname and personal access token to environment variables:
.. code-block:: bash
export DATABRICKS_HOSTNAME="https://your-databricks-workspace"
export DATABRICKS_TOKEN="your-personal-access-token"
Key init args completion params:
endpoint: str
Name of Databricks Model Serving endpoint to query.
target_uri: str
The target URI to use. Defaults to ``databricks``.
temperature: float
Sampling temperature. Higher values make the model more creative.
n: Optional[int]
The number of completion choices to generate.
stop: Optional[List[str]]
List of strings to stop generation at.
max_tokens: Optional[int]
Max number of tokens to generate.
extra_params: Optional[Dict[str, Any]]
Any extra parameters to pass to the endpoint.
Instantiate:
.. code-block:: python
from langchain_databricks import ChatDatabricks
llm = ChatDatabricks(
endpoint="databricks-meta-llama-3-1-405b-instruct",
temperature=0,
max_tokens=500,
)
Invoke:
.. code-block:: python
messages = [
("system", "You are a helpful translator. Translate the user sentence to French."),
("human", "I love programming."),
]
llm.invoke(messages)
.. code-block:: python
AIMessage(
content="J'adore la programmation.",
response_metadata={
'prompt_tokens': 32,
'completion_tokens': 9,
'total_tokens': 41
},
id='run-64eebbdd-88a8-4a25-b508-21e9a5f146c5-0'
)
Stream:
.. code-block:: python
for chunk in llm.stream(messages):
print(chunk)
.. code-block:: python
content='J' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
content="'" id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
content='ad' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
content='ore' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
content=' la' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
content=' programm' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
content='ation' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
content='.' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
content='' response_metadata={'finish_reason': 'stop'} id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
.. code-block:: python
stream = llm.stream(messages)
full = next(stream)
for chunk in stream:
full += chunk
full
.. code-block:: python
AIMessageChunk(
content="J'adore la programmation.",
response_metadata={
'finish_reason': 'stop'
},
id='run-4cef851f-6223-424f-ad26-4a54e5852aa5'
)
Async:
.. code-block:: python
await llm.ainvoke(messages)
# stream:
# async for chunk in llm.astream(messages)
# batch:
# await llm.abatch([messages])
.. code-block:: python
AIMessage(
content="J'adore la programmation.",
response_metadata={
'prompt_tokens': 32,
'completion_tokens': 9,
'total_tokens': 41
},
id='run-e4bb043e-772b-4e1d-9f98-77ccc00c0271-0'
)
Tool calling:
.. code-block:: python
from langchain_core.pydantic_v1 import BaseModel, Field
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
class GetPopulation(BaseModel):
'''Get the current population in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?")
ai_msg.tool_calls
.. code-block:: python
[
{
'name': 'GetWeather',
'args': {
'location': 'Los Angeles, CA'
},
'id': 'call_ea0a6004-8e64-4ae8-a192-a40e295bfa24',
'type': 'tool_call'
}
]
To use tool calls, your model endpoint must support ``tools`` parameter. See [Function calling on Databricks](https://python.langchain.com/v0.2/docs/integrations/chat/databricks/#function-calling-on-databricks) for more information.
""" # noqa: E501
endpoint: str
"""Name of Databricks Model Serving endpoint to query."""
target_uri: str = "databricks"
"""The target URI to use. Defaults to ``databricks``."""
temperature: float = 0.0
"""Sampling temperature. Higher values make the model more creative."""
n: int = 1
"""The number of completion choices to generate."""
stop: Optional[List[str]] = None
"""List of strings to stop generation at."""
max_tokens: Optional[int] = None
"""The maximum number of tokens to generate."""
extra_params: dict = Field(default_factory=dict)
"""Any extra parameters to pass to the endpoint."""
_client: Any = PrivateAttr()
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._validate_uri()
try:
from mlflow.deployments import get_deploy_client # type: ignore
self._client = get_deploy_client(self.target_uri)
except ImportError as e:
raise ImportError(
"Failed to create the client. Please run `pip install mlflow` to "
"install required dependencies."
) from e
def _validate_uri(self) -> None:
if self.target_uri == "databricks":
return
if urlparse(self.target_uri).scheme != "databricks":
raise ValueError(
"Invalid target URI. The target URI must be a valid databricks URI."
)
@property
def _default_params(self) -> Dict[str, Any]:
params: Dict[str, Any] = {
"target_uri": self.target_uri,
"endpoint": self.endpoint,
"temperature": self.temperature,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens,
"extra_params": self.extra_params,
}
return params
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
data = self._prepare_inputs(messages, stop, **kwargs)
resp = self._client.predict(endpoint=self.endpoint, inputs=data)
return self._convert_response_to_chat_result(resp)
def _prepare_inputs(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict[str, Any]:
data: Dict[str, Any] = {
"messages": [_convert_message_to_dict(msg) for msg in messages],
"temperature": self.temperature,
"n": self.n,
**self.extra_params,
**kwargs,
}
if stop := self.stop or stop:
data["stop"] = stop
if self.max_tokens is not None:
data["max_tokens"] = self.max_tokens
return data
def _convert_response_to_chat_result(
self, response: Mapping[str, Any]
) -> ChatResult:
generations = [
ChatGeneration(
message=_convert_dict_to_message(choice["message"]),
generation_info=choice.get("usage", {}),
)
for choice in response["choices"]
]
usage = response.get("usage", {})
return ChatResult(generations=generations, llm_output=usage)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
data = self._prepare_inputs(messages, stop, **kwargs)
first_chunk_role = None
for chunk in self._client.predict_stream(endpoint=self.endpoint, inputs=data):
if chunk["choices"]:
choice = chunk["choices"][0]
chunk_delta = choice["delta"]
if first_chunk_role is None:
first_chunk_role = chunk_delta.get("role")
chunk_message = _convert_dict_to_message_chunk(
chunk_delta, first_chunk_role
)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
if logprobs := choice.get("logprobs"):
generation_info["logprobs"] = logprobs
chunk = ChatGenerationChunk(
message=chunk_message, generation_info=generation_info or None
)
if run_manager:
run_manager.on_llm_new_token(
chunk.text, chunk=chunk, logprobs=logprobs
)
yield chunk
else:
# Handle the case where choices are empty if needed
continue
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
Assumes model is compatible with OpenAI tool-calling API.
Args:
tools: A list of tool definitions to bind to this chat model.
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
tool_choice: Which tool to require the model to call.
Options are:
name of the tool (str): calls corresponding tool;
"auto": automatically selects a tool (including no tool);
"none": model does not generate any tool calls and instead must
generate a standard assistant message;
"required": the model picks the most relevant tool in tools and
must generate a tool call;
or a dict of the form:
{"type": "function", "function": {"name": <<tool_name>>}}.
**kwargs: Any additional parameters to pass to the
:class:`~langchain.runnable.Runnable` constructor.
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
if tool_choice:
if isinstance(tool_choice, str):
# tool_choice is a tool/function name
if tool_choice not in ("auto", "none", "required"):
tool_choice = {
"type": "function",
"function": {"name": tool_choice},
}
elif isinstance(tool_choice, dict):
tool_names = [
formatted_tool["function"]["name"]
for formatted_tool in formatted_tools
]
if not any(
tool_name == tool_choice["function"]["name"]
for tool_name in tool_names
):
raise ValueError(
f"Tool choice {tool_choice} was specified, but the only "
f"provided tools were {tool_names}."
)
else:
raise ValueError(
f"Unrecognized tool_choice type. Expected str, bool or dict. "
f"Received: {tool_choice}"
)
kwargs["tool_choice"] = tool_choice
return super().bind(tools=formatted_tools, **kwargs)
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "chat-databricks"
### Conversion function to convert Pydantic models to dictionaries and vice versa. ###
def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict = {"content": message.content}
# OpenAI supports "name" field in messages.
if (name := message.name or message.additional_kwargs.get("name")) is not None:
message_dict["name"] = name
if id := message.id:
message_dict["id"] = id
if isinstance(message, ChatMessage):
return {"role": message.role, **message_dict}
elif isinstance(message, HumanMessage):
return {"role": "user", **message_dict}
elif isinstance(message, AIMessage):
if tool_calls := _get_tool_calls_from_ai_message(message):
message_dict["tool_calls"] = tool_calls # type: ignore[assignment]
# If tool calls present, content null value should be None not empty string.
message_dict["content"] = message_dict["content"] or None # type: ignore[assignment]
return {"role": "assistant", **message_dict}
elif isinstance(message, SystemMessage):
return {"role": "system", **message_dict}
elif isinstance(message, ToolMessage):
return {
"role": "tool",
"tool_call_id": message.tool_call_id,
**message_dict,
}
elif (
isinstance(message, FunctionMessage)
or "function_call" in message.additional_kwargs
):
raise ValueError(
"Function messages are not supported by Databricks. Please"
" create a feature request at https://github.com/mlflow/mlflow/issues."
)
else:
raise ValueError(f"Got unknown message type: {type(message)}")
def _get_tool_calls_from_ai_message(message: AIMessage) -> List[Dict]:
tool_calls = [
{
"type": "function",
"id": tc["id"],
"function": {
"name": tc["name"],
"arguments": json.dumps(tc["args"]),
},
}
for tc in message.tool_calls
]
invalid_tool_calls = [
{
"type": "function",
"id": tc["id"],
"function": {
"name": tc["name"],
"arguments": tc["args"],
},
}
for tc in message.invalid_tool_calls
]
if tool_calls or invalid_tool_calls:
return tool_calls + invalid_tool_calls
# Get tool calls from additional kwargs if present.
return [
{
k: v
for k, v in tool_call.items() # type: ignore[union-attr]
if k in {"id", "type", "function"}
}
for tool_call in message.additional_kwargs.get("tool_calls", [])
]
def _convert_dict_to_message(_dict: Dict) -> BaseMessage:
role = _dict["role"]
content = _dict.get("content")
content = content if content is not None else ""
if role == "user":
return HumanMessage(content=content)
elif role == "system":
return SystemMessage(content=content)
elif role == "assistant":
additional_kwargs: Dict = {}
tool_calls = []
invalid_tool_calls = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
try:
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
)
return AIMessage(
content=content,
additional_kwargs=additional_kwargs,
id=_dict.get("id"),
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
)
else:
return ChatMessage(content=content, role=role)
def _convert_dict_to_message_chunk(
_dict: Mapping[str, Any], default_role: str
) -> BaseMessageChunk:
role = _dict.get("role", default_role)
content = _dict.get("content")
content = content if content is not None else ""
if role == "user":
return HumanMessageChunk(content=content)
elif role == "system":
return SystemMessageChunk(content=content)
elif role == "tool":
return ToolMessageChunk(
content=content, tool_call_id=_dict["tool_call_id"], id=_dict.get("id")
)
elif role == "assistant":
additional_kwargs: Dict = {}
tool_call_chunks = []
if raw_tool_calls := _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = raw_tool_calls
try:
tool_call_chunks = [
tool_call_chunk(
name=tc["function"].get("name"),
args=tc["function"].get("arguments"),
id=tc.get("id"),
index=tc["index"],
)
for tc in raw_tool_calls
]
except KeyError:
pass
return AIMessageChunk(
content=content,
additional_kwargs=additional_kwargs,
id=_dict.get("id"),
tool_call_chunks=tool_call_chunks,
)
else:
return ChatMessageChunk(content=content, role=role)

2495
libs/partners/databricks/poetry.lock generated Normal file

File diff suppressed because it is too large.

View File

@@ -0,0 +1,99 @@
[tool.poetry]
name = "langchain-databricks"
version = "0.1.0"
description = "An integration package connecting Databricks and LangChain"
authors = []
readme = "README.md"
repository = "https://github.com/langchain-ai/langchain"
license = "MIT"
[tool.poetry.urls]
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/databricks"
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22databricks%3D%3D0%22&expanded=true"
[tool.poetry.dependencies]
# TODO: Replace <3.12 to <4.0 once https://github.com/mlflow/mlflow/commit/04370119fcc1b2ccdbcd9a50198ab00566d58cd2 is released
python = ">=3.8.1,<3.12"
langchain-core = "^0.2.0"
mlflow = ">=2.9"
# MLflow depends on following libraries, which require different version for Python 3.8 vs 3.12
numpy = [
{version = ">=1.26.0", python = ">=3.12"},
{version = ">=1.24.0", python = "<3.12"},
]
scipy = [
{version = ">=1.11", python = ">=3.12"},
{version = "<2", python = "<3.12"}
]
[tool.poetry.group.test]
optional = true
[tool.poetry.group.test.dependencies]
pytest = "^7.4.3"
pytest-asyncio = "^0.23.2"
pytest-socket = "^0.7.0"
langchain-core = { path = "../../core", develop = true }
[tool.poetry.group.codespell]
optional = true
[tool.poetry.group.codespell.dependencies]
codespell = "^2.2.6"
[tool.poetry.group.test_integration]
optional = true
[tool.poetry.group.test_integration.dependencies]
[tool.poetry.group.lint]
optional = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.5"
[tool.poetry.group.typing.dependencies]
mypy = "^1.10"
langchain-core = { path = "../../core", develop = true }
[tool.poetry.group.dev]
optional = true
[tool.poetry.group.dev.dependencies]
langchain-core = { path = "../../core", develop = true }
[tool.ruff.lint]
select = [
"E", # pycodestyle
"F", # pyflakes
"I", # isort
"T201", # print
]
[tool.mypy]
disallow_untyped_defs = "True"
[tool.coverage.run]
omit = ["tests/*"]
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
# --strict-markers will raise errors on unknown marks.
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
#
# https://docs.pytest.org/en/7.1.x/reference/reference.html
# --strict-config any warnings encountered while parsing the `pytest`
# section of the configuration file raise errors.
#
# https://github.com/tophat/syrupy
addopts = "--strict-markers --strict-config --durations=5"
# Registering custom markers.
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
markers = [
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"

View File

@@ -0,0 +1,17 @@
import sys
import traceback
from importlib.machinery import SourceFileLoader
if __name__ == "__main__":
files = sys.argv[1:]
has_failure = False
for file in files:
try:
SourceFileLoader("x", file).load_module()
except Exception:
has_failure = True
print(file) # noqa: T201
traceback.print_exc()
print() # noqa: T201
sys.exit(1 if has_failure else 0)

View File

@@ -0,0 +1,27 @@
#!/bin/bash
#
# This script searches for lines starting with "import pydantic" or "from pydantic"
# in tracked files within a Git repository.
#
# Usage: ./scripts/check_pydantic.sh /path/to/repository
# Check if a path argument is provided
if [ $# -ne 1 ]; then
echo "Usage: $0 /path/to/repository"
exit 1
fi
repository_path="$1"
# Search for lines matching the pattern within the specified repository
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
# Check if any matching lines were found
if [ -n "$result" ]; then
echo "ERROR: The following lines need to be updated:"
echo "$result"
echo "Please replace the code with an import from langchain_core.pydantic_v1."
echo "For example, replace 'from pydantic import BaseModel'"
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
exit 1
fi

View File

@@ -0,0 +1,18 @@
#!/bin/bash
set -eu
# Initialize a variable to keep track of errors
errors=0
# make sure not importing from langchain, langchain_experimental, or langchain_community
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1))
# Decide on an exit status based on the errors
if [ "$errors" -gt 0 ]; then
exit 1
else
exit 0
fi

View File

@@ -0,0 +1,7 @@
import pytest
@pytest.mark.compile
def test_placeholder() -> None:
"""Used for compiling integration tests without running any real tests."""
pass

View File

@@ -0,0 +1,321 @@
"""Test chat model integration."""
import json
from typing import Generator
from unittest import mock
import mlflow # type: ignore # noqa: F401
import pytest
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
ToolMessageChunk,
)
from langchain_core.messages.tool import ToolCallChunk
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_databricks.chat_models import (
ChatDatabricks,
_convert_dict_to_message,
_convert_dict_to_message_chunk,
_convert_message_to_dict,
)
_MOCK_CHAT_RESPONSE = {
"id": "chatcmpl_id",
"object": "chat.completion",
"created": 1721875529,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "To calculate the result of 36939 multiplied by 8922.4, "
"I get:\n\n36939 x 8922.4 = 329,511,111.6",
},
"finish_reason": "stop",
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
}
_MOCK_STREAM_RESPONSE = [
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "36939"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 20, "total_tokens": 50},
},
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "x"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 22, "total_tokens": 52},
},
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "8922.4"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 24, "total_tokens": 54},
},
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": " = "},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 28, "total_tokens": 58},
},
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "329,511,111.6"},
"finish_reason": None,
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 30, "total_tokens": 60},
},
{
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
"object": "chat.completion.chunk",
"created": 1721877054,
"model": "meta-llama-3.1-70b-instruct-072424",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": ""},
"finish_reason": "stop",
"logprobs": None,
}
],
"usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
},
]
@pytest.fixture(autouse=True)
def mock_client() -> Generator:
client = mock.MagicMock()
client.predict.return_value = _MOCK_CHAT_RESPONSE
client.predict_stream.return_value = _MOCK_STREAM_RESPONSE
with mock.patch("mlflow.deployments.get_deploy_client", return_value=client):
yield
@pytest.fixture
def llm() -> ChatDatabricks:
return ChatDatabricks(
endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks"
)
def test_chat_mlflow_predict(llm: ChatDatabricks) -> None:
res = llm.invoke(
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "36939 * 8922.4"},
]
)
assert res.content == _MOCK_CHAT_RESPONSE["choices"][0]["message"]["content"] # type: ignore[index]
def test_chat_mlflow_stream(llm: ChatDatabricks) -> None:
res = llm.stream(
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "36939 * 8922.4"},
]
)
for chunk, expected in zip(res, _MOCK_STREAM_RESPONSE):
assert chunk.content == expected["choices"][0]["delta"]["content"] # type: ignore[index]
def test_chat_mlflow_bind_tools(llm: ChatDatabricks) -> None:
class GetWeather(BaseModel):
"""Get the current weather in a given location"""
location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
class GetPopulation(BaseModel):
"""Get the current population in a given location"""
location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
response = llm_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?"
)
assert isinstance(response, AIMessage)
### Test data conversion functions ###
@pytest.mark.parametrize(
("role", "expected_output"),
[
("user", HumanMessage("foo")),
("system", SystemMessage("foo")),
("assistant", AIMessage("foo")),
("any_role", ChatMessage(content="foo", role="any_role")),
],
)
def test_convert_message(role: str, expected_output: BaseMessage) -> None:
message = {"role": role, "content": "foo"}
result = _convert_dict_to_message(message)
assert result == expected_output
# convert back
dict_result = _convert_message_to_dict(result)
assert dict_result == message
def test_convert_message_with_tool_calls() -> None:
ID = "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12"
tool_calls = [
{
"id": ID,
"type": "function",
"function": {
"name": "main__test__python_exec",
"arguments": '{"code": "result = 36939 * 8922.4"}',
},
}
]
message_with_tools = {
"role": "assistant",
"content": None,
"tool_calls": tool_calls,
"id": ID,
}
result = _convert_dict_to_message(message_with_tools)
expected_output = AIMessage(
content="",
additional_kwargs={"tool_calls": tool_calls},
id=ID,
tool_calls=[
{
"name": tool_calls[0]["function"]["name"], # type: ignore[index]
"args": json.loads(tool_calls[0]["function"]["arguments"]), # type: ignore[index]
"id": ID,
"type": "tool_call",
}
],
)
assert result == expected_output
# convert back
dict_result = _convert_message_to_dict(result)
assert dict_result == message_with_tools
@pytest.mark.parametrize(
("role", "expected_output"),
[
("user", HumanMessageChunk(content="foo")),
("system", SystemMessageChunk(content="foo")),
("assistant", AIMessageChunk(content="foo")),
("any_role", ChatMessageChunk(content="foo", role="any_role")),
],
)
def test_convert_message_chunk(role: str, expected_output: BaseMessage) -> None:
delta = {"role": role, "content": "foo"}
result = _convert_dict_to_message_chunk(delta, "default_role")
assert result == expected_output
# convert back
dict_result = _convert_message_to_dict(result)
assert dict_result == delta
def test_convert_message_chunk_with_tool_calls() -> None:
delta_with_tools = {
"role": "assistant",
"content": None,
"tool_calls": [{"index": 0, "function": {"arguments": " }"}}],
}
result = _convert_dict_to_message_chunk(delta_with_tools, "role")
expected_output = AIMessageChunk(
content="",
additional_kwargs={"tool_calls": delta_with_tools["tool_calls"]},
id=None,
tool_call_chunks=[ToolCallChunk(name=None, args=" }", id=None, index=0)],
)
assert result == expected_output
def test_convert_tool_message_chunk() -> None:
delta = {
"role": "tool",
"content": "foo",
"tool_call_id": "tool_call_id",
"id": "some_id",
}
result = _convert_dict_to_message_chunk(delta, "default_role")
expected_output = ToolMessageChunk(
content="foo", id="some_id", tool_call_id="tool_call_id"
)
assert result == expected_output
# convert back
dict_result = _convert_message_to_dict(result)
assert dict_result == delta
def test_convert_message_to_dict_function() -> None:
with pytest.raises(ValueError, match="Function messages are not supported"):
_convert_message_to_dict(FunctionMessage(content="", name="name"))

View File

@@ -0,0 +1,10 @@
from langchain_databricks import __all__
EXPECTED_ALL = [
"ChatDatabricks",
"__version__",
]
def test_all_imports() -> None:
assert sorted(EXPECTED_ALL) == sorted(__all__)