mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-17 16:39:52 +00:00
databricks: Add partner package directory and ChatDatabricks implementation (#25430)
### Summary
Create `langchain-databricks` as a new partner packages. This PR does
not migrate all existing Databricks integration, but the package will
eventually contain:
* `ChatDatabricks` (implemented in this PR)
* `DatabricksVectorSearch`
* `DatabricksEmbeddings`
* ~`UCFunctionToolkit`~ (will be done after UC SDK work which
drastically simplify implementation)
Also, this PR does not add integration tests yet. This will be added
once the Databricks test workspace is ready.
Tagging @efriis as POC
### Tracker
[✍️] Create a package and imgrate ChatDatabricks
[ ] Migrate DatabricksVectorSearch, DatabricksEmbeddings, and their docs
~[ ] Migrate UCFunctionToolkit and its doc~
[ ] Add provider document and update README.md
[ ] Add integration tests and set up secrets (after moved to an external
package)
[ ] Add deprecation note to the community implementations.
---------
Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
fb1d67edf6
commit
3981d736df
@ -18,7 +18,7 @@ for dir; do \
|
||||
if find "$$dir" -maxdepth 1 -type f \( -name "pyproject.toml" -o -name "setup.py" \) | grep -q .; then \
|
||||
echo "$$dir"; \
|
||||
fi \
|
||||
done' sh {} + | grep -vE "airbyte|ibm|couchbase" | tr '\n' ' ')
|
||||
done' sh {} + | grep -vE "airbyte|ibm|couchbase|databricks" | tr '\n' ' ')
|
||||
|
||||
PORT ?= 3001
|
||||
|
||||
|
@ -31,7 +31,7 @@
|
||||
"\n",
|
||||
"| Class | Package | Local | Serializable | Package downloads | Package latest |\n",
|
||||
"| :--- | :--- | :---: | :---: | :---: | :---: |\n",
|
||||
"| [ChatDatabricks](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.databricks.ChatDatabricks.html) | [langchain-community](https://api.python.langchain.com/en/latest/community_api_reference.html) | ❌ | beta |  |  |\n",
|
||||
"| [ChatDatabricks](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.databricks.ChatDatabricks.html) | [langchain-databricks](https://api.python.langchain.com/en/latest/databricks_api_reference.html) | ❌ | beta |  |  |\n",
|
||||
"\n",
|
||||
"### Model features\n",
|
||||
"| [Tool calling](/docs/how_to/tool_calling/) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
|
||||
@ -99,7 +99,7 @@
|
||||
"source": [
|
||||
"### Installation\n",
|
||||
"\n",
|
||||
"The LangChain Databricks integration lives in the `langchain-community` package. Also, `mlflow >= 2.9 ` is required to run the code in this notebook."
|
||||
"The LangChain Databricks integration lives in the `langchain-databricks` package."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -108,7 +108,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install -qU langchain-community mlflow>=2.9.0"
|
||||
"%pip install -qU langchain-databricks"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -133,7 +133,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.chat_models import ChatDatabricks\n",
|
||||
"from langchain_databricks import ChatDatabricks\n",
|
||||
"\n",
|
||||
"chat_model = ChatDatabricks(\n",
|
||||
" endpoint=\"databricks-dbrx-instruct\",\n",
|
||||
@ -245,9 +245,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Invocation (streaming)\n",
|
||||
"\n",
|
||||
"`ChatDatabricks` supports streaming response by `stream` method since `langchain-community>=0.2.1`."
|
||||
"## Invocation (streaming)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -299,7 +297,7 @@
|
||||
"* An LLM was registered and deployed to [a Databricks serving endpoint](https://docs.databricks.com/machine-learning/model-serving/index.html) via MLflow. The endpoint must have OpenAI-compatible chat input/output format ([reference](https://mlflow.org/docs/latest/llms/deployments/index.html#chat))\n",
|
||||
"* You have [\"Can Query\" permission](https://docs.databricks.com/security/auth-authz/access-control/serving-endpoint-acl.html) to the endpoint.\n",
|
||||
"\n",
|
||||
"Once the endpoint is ready, the usage pattern is completely same as Foundation Models."
|
||||
"Once the endpoint is ready, the usage pattern is identical to that of Foundation Models."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -332,7 +330,7 @@
|
||||
"\n",
|
||||
"First, create a new Databricks serving endpoint that proxies requests to the target external model. The endpoint creation should be fairy quick for proxying external models.\n",
|
||||
"\n",
|
||||
"This requires registering OpenAI API Key in Databricks secret manager with the following comment:\n",
|
||||
"This requires registering your OpenAI API Key within the Databricks secret manager as follows:\n",
|
||||
"```sh\n",
|
||||
"# Replace `<scope>` with your scope\n",
|
||||
"databricks secrets create-scope <scope>\n",
|
||||
@ -417,8 +415,6 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_community.chat_models.databricks import ChatDatabricks\n",
|
||||
"\n",
|
||||
"llm = ChatDatabricks(endpoint=\"databricks-meta-llama-3-70b-instruct\")\n",
|
||||
"tools = [\n",
|
||||
" {\n",
|
||||
@ -461,7 +457,7 @@
|
||||
"source": [
|
||||
"## API reference\n",
|
||||
"\n",
|
||||
"For detailed documentation of all ChatDatabricks features and configurations head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.ChatDatabricks.html"
|
||||
"For detailed documentation of all ChatDatabricks features and configurations head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_databricks.chat_models.ChatDatabricks.html"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
1
libs/partners/databricks/.gitignore
vendored
Normal file
1
libs/partners/databricks/.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
__pycache__
|
21
libs/partners/databricks/LICENSE
Normal file
21
libs/partners/databricks/LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
62
libs/partners/databricks/Makefile
Normal file
62
libs/partners/databricks/Makefile
Normal file
@ -0,0 +1,62 @@
|
||||
.PHONY: all format lint test tests integration_tests docker_tests help extended_tests
|
||||
|
||||
# Default target executed when no arguments are given to make.
|
||||
all: help
|
||||
|
||||
# Define a variable for the test file path.
|
||||
TEST_FILE ?= tests/unit_tests/
|
||||
integration_test integration_tests: TEST_FILE = tests/integration_tests/
|
||||
|
||||
|
||||
# unit tests are run with the --disable-socket flag to prevent network calls
|
||||
test tests:
|
||||
poetry run pytest --disable-socket --allow-unix-socket $(TEST_FILE)
|
||||
|
||||
# integration tests are run without the --disable-socket flag to allow network calls
|
||||
integration_test integration_tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
######################
|
||||
# LINTING AND FORMATTING
|
||||
######################
|
||||
|
||||
# Define a variable for Python and notebook files.
|
||||
PYTHON_FILES=.
|
||||
MYPY_CACHE=.mypy_cache
|
||||
lint format: PYTHON_FILES=.
|
||||
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/databricks --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
|
||||
lint_package: PYTHON_FILES=langchain_databricks
|
||||
lint_tests: PYTHON_FILES=tests
|
||||
lint_tests: MYPY_CACHE=.mypy_cache_test
|
||||
|
||||
lint lint_diff lint_package lint_tests:
|
||||
poetry run ruff check .
|
||||
poetry run ruff format $(PYTHON_FILES) --diff
|
||||
poetry run ruff check --select I $(PYTHON_FILES)
|
||||
mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
|
||||
|
||||
format format_diff:
|
||||
poetry run ruff format $(PYTHON_FILES)
|
||||
poetry run ruff check --select I --fix $(PYTHON_FILES)
|
||||
|
||||
spell_check:
|
||||
poetry run codespell --toml pyproject.toml
|
||||
|
||||
spell_fix:
|
||||
poetry run codespell --toml pyproject.toml -w
|
||||
|
||||
check_imports: $(shell find langchain_databricks -name '*.py')
|
||||
poetry run python ./scripts/check_imports.py $^
|
||||
|
||||
######################
|
||||
# HELP
|
||||
######################
|
||||
|
||||
help:
|
||||
@echo '----'
|
||||
@echo 'check_imports - check imports'
|
||||
@echo 'format - run code formatters'
|
||||
@echo 'lint - run linters'
|
||||
@echo 'test - run unit tests'
|
||||
@echo 'tests - run unit tests'
|
||||
@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
24
libs/partners/databricks/README.md
Normal file
24
libs/partners/databricks/README.md
Normal file
@ -0,0 +1,24 @@
|
||||
# langchain-databricks
|
||||
|
||||
This package contains the LangChain integration with Databricks
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install -U langchain-databricks
|
||||
```
|
||||
|
||||
And you should configure credentials by setting the following environment variables:
|
||||
|
||||
* TODO: fill this out
|
||||
|
||||
## Chat Models
|
||||
|
||||
`ChatDatabricks` class exposes chat models from Databricks.
|
||||
|
||||
```python
|
||||
from langchain_databricks import ChatDatabricks
|
||||
|
||||
llm = ChatDatabricks()
|
||||
llm.invoke("Sing a ballad of LangChain.")
|
||||
```
|
15
libs/partners/databricks/langchain_databricks/__init__.py
Normal file
15
libs/partners/databricks/langchain_databricks/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
from importlib import metadata
|
||||
|
||||
from langchain_databricks.chat_models import ChatDatabricks
|
||||
|
||||
try:
|
||||
__version__ = metadata.version(__package__)
|
||||
except metadata.PackageNotFoundError:
|
||||
# Case where package metadata is not available.
|
||||
__version__ = ""
|
||||
del metadata # optional, avoids polluting the results of dir(__package__)
|
||||
|
||||
__all__ = [
|
||||
"ChatDatabricks",
|
||||
"__version__",
|
||||
]
|
573
libs/partners/databricks/langchain_databricks/chat_models.py
Normal file
573
libs/partners/databricks/langchain_databricks/chat_models.py
Normal file
@ -0,0 +1,573 @@
|
||||
"""Databricks chat models."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.language_models.base import LanguageModelInput
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
ToolMessage,
|
||||
ToolMessageChunk,
|
||||
)
|
||||
from langchain_core.messages.tool import tool_call_chunk
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
make_invalid_tool_call,
|
||||
parse_tool_call,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
)
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatDatabricks(BaseChatModel):
|
||||
"""Databricks chat model integration.
|
||||
|
||||
Setup:
|
||||
Install ``langchain-databricks``.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install -U langchain-databricks
|
||||
|
||||
If you are outside Databricks, set the Databricks workspace hostname and personal access token to environment variables:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
export DATABRICKS_HOSTNAME="https://your-databricks-workspace"
|
||||
export DATABRICKS_TOKEN="your-personal-access-token"
|
||||
|
||||
Key init args — completion params:
|
||||
endpoint: str
|
||||
Name of Databricks Model Serving endpoint to query.
|
||||
target_uri: str
|
||||
The target URI to use. Defaults to ``databricks``.
|
||||
temperature: float
|
||||
Sampling temperature. Higher values make the model more creative.
|
||||
n: Optional[int]
|
||||
The number of completion choices to generate.
|
||||
stop: Optional[List[str]]
|
||||
List of strings to stop generation at.
|
||||
max_tokens: Optional[int]
|
||||
Max number of tokens to generate.
|
||||
extra_params: Optional[Dict[str, Any]]
|
||||
Any extra parameters to pass to the endpoint.
|
||||
|
||||
Instantiate:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_databricks import ChatDatabricks
|
||||
llm = ChatDatabricks(
|
||||
endpoint="databricks-meta-llama-3-1-405b-instruct",
|
||||
temperature=0,
|
||||
max_tokens=500,
|
||||
)
|
||||
|
||||
Invoke:
|
||||
.. code-block:: python
|
||||
|
||||
messages = [
|
||||
("system", "You are a helpful translator. Translate the user sentence to French."),
|
||||
("human", "I love programming."),
|
||||
]
|
||||
llm.invoke(messages)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
AIMessage(
|
||||
content="J'adore la programmation.",
|
||||
response_metadata={
|
||||
'prompt_tokens': 32,
|
||||
'completion_tokens': 9,
|
||||
'total_tokens': 41
|
||||
},
|
||||
id='run-64eebbdd-88a8-4a25-b508-21e9a5f146c5-0'
|
||||
)
|
||||
|
||||
Stream:
|
||||
.. code-block:: python
|
||||
|
||||
for chunk in llm.stream(messages):
|
||||
print(chunk)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
content='J' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
||||
content="'" id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
||||
content='ad' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
||||
content='ore' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
||||
content=' la' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
||||
content=' programm' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
||||
content='ation' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
||||
content='.' id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
||||
content='' response_metadata={'finish_reason': 'stop'} id='run-609b8f47-e580-4691-9ee4-e2109f53155e'
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
stream = llm.stream(messages)
|
||||
full = next(stream)
|
||||
for chunk in stream:
|
||||
full += chunk
|
||||
full
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
AIMessageChunk(
|
||||
content="J'adore la programmation.",
|
||||
response_metadata={
|
||||
'finish_reason': 'stop'
|
||||
},
|
||||
id='run-4cef851f-6223-424f-ad26-4a54e5852aa5'
|
||||
)
|
||||
|
||||
Async:
|
||||
.. code-block:: python
|
||||
|
||||
await llm.ainvoke(messages)
|
||||
|
||||
# stream:
|
||||
# async for chunk in llm.astream(messages)
|
||||
|
||||
# batch:
|
||||
# await llm.abatch([messages])
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
AIMessage(
|
||||
content="J'adore la programmation.",
|
||||
response_metadata={
|
||||
'prompt_tokens': 32,
|
||||
'completion_tokens': 9,
|
||||
'total_tokens': 41
|
||||
},
|
||||
id='run-e4bb043e-772b-4e1d-9f98-77ccc00c0271-0'
|
||||
)
|
||||
|
||||
Tool calling:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
class GetWeather(BaseModel):
|
||||
'''Get the current weather in a given location'''
|
||||
|
||||
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
|
||||
|
||||
class GetPopulation(BaseModel):
|
||||
'''Get the current population in a given location'''
|
||||
|
||||
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
|
||||
|
||||
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
|
||||
ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?")
|
||||
ai_msg.tool_calls
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
[
|
||||
{
|
||||
'name': 'GetWeather',
|
||||
'args': {
|
||||
'location': 'Los Angeles, CA'
|
||||
},
|
||||
'id': 'call_ea0a6004-8e64-4ae8-a192-a40e295bfa24',
|
||||
'type': 'tool_call'
|
||||
}
|
||||
]
|
||||
|
||||
To use tool calls, your model endpoint must support ``tools`` parameter. See [Function calling on Databricks](https://python.langchain.com/v0.2/docs/integrations/chat/databricks/#function-calling-on-databricks) for more information.
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
endpoint: str
|
||||
"""Name of Databricks Model Serving endpoint to query."""
|
||||
target_uri: str = "databricks"
|
||||
"""The target URI to use. Defaults to ``databricks``."""
|
||||
temperature: float = 0.0
|
||||
"""Sampling temperature. Higher values make the model more creative."""
|
||||
n: int = 1
|
||||
"""The number of completion choices to generate."""
|
||||
stop: Optional[List[str]] = None
|
||||
"""List of strings to stop generation at."""
|
||||
max_tokens: Optional[int] = None
|
||||
"""The maximum number of tokens to generate."""
|
||||
extra_params: dict = Field(default_factory=dict)
|
||||
"""Any extra parameters to pass to the endpoint."""
|
||||
_client: Any = PrivateAttr()
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self._validate_uri()
|
||||
try:
|
||||
from mlflow.deployments import get_deploy_client # type: ignore
|
||||
|
||||
self._client = get_deploy_client(self.target_uri)
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Failed to create the client. Please run `pip install mlflow` to "
|
||||
"install required dependencies."
|
||||
) from e
|
||||
|
||||
def _validate_uri(self) -> None:
|
||||
if self.target_uri == "databricks":
|
||||
return
|
||||
|
||||
if urlparse(self.target_uri).scheme != "databricks":
|
||||
raise ValueError(
|
||||
"Invalid target URI. The target URI must be a valid databricks URI."
|
||||
)
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
params: Dict[str, Any] = {
|
||||
"target_uri": self.target_uri,
|
||||
"endpoint": self.endpoint,
|
||||
"temperature": self.temperature,
|
||||
"n": self.n,
|
||||
"stop": self.stop,
|
||||
"max_tokens": self.max_tokens,
|
||||
"extra_params": self.extra_params,
|
||||
}
|
||||
return params
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
data = self._prepare_inputs(messages, stop, **kwargs)
|
||||
resp = self._client.predict(endpoint=self.endpoint, inputs=data)
|
||||
return self._convert_response_to_chat_result(resp)
|
||||
|
||||
def _prepare_inputs(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
data: Dict[str, Any] = {
|
||||
"messages": [_convert_message_to_dict(msg) for msg in messages],
|
||||
"temperature": self.temperature,
|
||||
"n": self.n,
|
||||
**self.extra_params,
|
||||
**kwargs,
|
||||
}
|
||||
if stop := self.stop or stop:
|
||||
data["stop"] = stop
|
||||
if self.max_tokens is not None:
|
||||
data["max_tokens"] = self.max_tokens
|
||||
|
||||
return data
|
||||
|
||||
def _convert_response_to_chat_result(
|
||||
self, response: Mapping[str, Any]
|
||||
) -> ChatResult:
|
||||
generations = [
|
||||
ChatGeneration(
|
||||
message=_convert_dict_to_message(choice["message"]),
|
||||
generation_info=choice.get("usage", {}),
|
||||
)
|
||||
for choice in response["choices"]
|
||||
]
|
||||
usage = response.get("usage", {})
|
||||
return ChatResult(generations=generations, llm_output=usage)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
data = self._prepare_inputs(messages, stop, **kwargs)
|
||||
first_chunk_role = None
|
||||
for chunk in self._client.predict_stream(endpoint=self.endpoint, inputs=data):
|
||||
if chunk["choices"]:
|
||||
choice = chunk["choices"][0]
|
||||
|
||||
chunk_delta = choice["delta"]
|
||||
if first_chunk_role is None:
|
||||
first_chunk_role = chunk_delta.get("role")
|
||||
|
||||
chunk_message = _convert_dict_to_message_chunk(
|
||||
chunk_delta, first_chunk_role
|
||||
)
|
||||
|
||||
generation_info = {}
|
||||
if finish_reason := choice.get("finish_reason"):
|
||||
generation_info["finish_reason"] = finish_reason
|
||||
if logprobs := choice.get("logprobs"):
|
||||
generation_info["logprobs"] = logprobs
|
||||
|
||||
chunk = ChatGenerationChunk(
|
||||
message=chunk_message, generation_info=generation_info or None
|
||||
)
|
||||
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(
|
||||
chunk.text, chunk=chunk, logprobs=logprobs
|
||||
)
|
||||
|
||||
yield chunk
|
||||
else:
|
||||
# Handle the case where choices are empty if needed
|
||||
continue
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
*,
|
||||
tool_choice: Optional[
|
||||
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
|
||||
] = None,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
|
||||
Assumes model is compatible with OpenAI tool-calling API.
|
||||
|
||||
Args:
|
||||
tools: A list of tool definitions to bind to this chat model.
|
||||
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
|
||||
models, callables, and BaseTools will be automatically converted to
|
||||
their schema dictionary representation.
|
||||
tool_choice: Which tool to require the model to call.
|
||||
Options are:
|
||||
name of the tool (str): calls corresponding tool;
|
||||
"auto": automatically selects a tool (including no tool);
|
||||
"none": model does not generate any tool calls and instead must
|
||||
generate a standard assistant message;
|
||||
"required": the model picks the most relevant tool in tools and
|
||||
must generate a tool call;
|
||||
|
||||
or a dict of the form:
|
||||
{"type": "function", "function": {"name": <<tool_name>>}}.
|
||||
**kwargs: Any additional parameters to pass to the
|
||||
:class:`~langchain.runnable.Runnable` constructor.
|
||||
"""
|
||||
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
if tool_choice:
|
||||
if isinstance(tool_choice, str):
|
||||
# tool_choice is a tool/function name
|
||||
if tool_choice not in ("auto", "none", "required"):
|
||||
tool_choice = {
|
||||
"type": "function",
|
||||
"function": {"name": tool_choice},
|
||||
}
|
||||
elif isinstance(tool_choice, dict):
|
||||
tool_names = [
|
||||
formatted_tool["function"]["name"]
|
||||
for formatted_tool in formatted_tools
|
||||
]
|
||||
if not any(
|
||||
tool_name == tool_choice["function"]["name"]
|
||||
for tool_name in tool_names
|
||||
):
|
||||
raise ValueError(
|
||||
f"Tool choice {tool_choice} was specified, but the only "
|
||||
f"provided tools were {tool_names}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unrecognized tool_choice type. Expected str, bool or dict. "
|
||||
f"Received: {tool_choice}"
|
||||
)
|
||||
kwargs["tool_choice"] = tool_choice
|
||||
return super().bind(tools=formatted_tools, **kwargs)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
return "chat-databricks"
|
||||
|
||||
|
||||
### Conversion function to convert Pydantic models to dictionaries and vice versa. ###
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict = {"content": message.content}
|
||||
|
||||
# OpenAI supports "name" field in messages.
|
||||
if (name := message.name or message.additional_kwargs.get("name")) is not None:
|
||||
message_dict["name"] = name
|
||||
|
||||
if id := message.id:
|
||||
message_dict["id"] = id
|
||||
|
||||
if isinstance(message, ChatMessage):
|
||||
return {"role": message.role, **message_dict}
|
||||
elif isinstance(message, HumanMessage):
|
||||
return {"role": "user", **message_dict}
|
||||
elif isinstance(message, AIMessage):
|
||||
if tool_calls := _get_tool_calls_from_ai_message(message):
|
||||
message_dict["tool_calls"] = tool_calls # type: ignore[assignment]
|
||||
# If tool calls present, content null value should be None not empty string.
|
||||
message_dict["content"] = message_dict["content"] or None # type: ignore[assignment]
|
||||
return {"role": "assistant", **message_dict}
|
||||
elif isinstance(message, SystemMessage):
|
||||
return {"role": "system", **message_dict}
|
||||
elif isinstance(message, ToolMessage):
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": message.tool_call_id,
|
||||
**message_dict,
|
||||
}
|
||||
elif (
|
||||
isinstance(message, FunctionMessage)
|
||||
or "function_call" in message.additional_kwargs
|
||||
):
|
||||
raise ValueError(
|
||||
"Function messages are not supported by Databricks. Please"
|
||||
" create a feature request at https://github.com/mlflow/mlflow/issues."
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Got unknown message type: {type(message)}")
|
||||
|
||||
|
||||
def _get_tool_calls_from_ai_message(message: AIMessage) -> List[Dict]:
|
||||
tool_calls = [
|
||||
{
|
||||
"type": "function",
|
||||
"id": tc["id"],
|
||||
"function": {
|
||||
"name": tc["name"],
|
||||
"arguments": json.dumps(tc["args"]),
|
||||
},
|
||||
}
|
||||
for tc in message.tool_calls
|
||||
]
|
||||
|
||||
invalid_tool_calls = [
|
||||
{
|
||||
"type": "function",
|
||||
"id": tc["id"],
|
||||
"function": {
|
||||
"name": tc["name"],
|
||||
"arguments": tc["args"],
|
||||
},
|
||||
}
|
||||
for tc in message.invalid_tool_calls
|
||||
]
|
||||
|
||||
if tool_calls or invalid_tool_calls:
|
||||
return tool_calls + invalid_tool_calls
|
||||
|
||||
# Get tool calls from additional kwargs if present.
|
||||
return [
|
||||
{
|
||||
k: v
|
||||
for k, v in tool_call.items() # type: ignore[union-attr]
|
||||
if k in {"id", "type", "function"}
|
||||
}
|
||||
for tool_call in message.additional_kwargs.get("tool_calls", [])
|
||||
]
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Dict) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
content = _dict.get("content")
|
||||
content = content if content is not None else ""
|
||||
|
||||
if role == "user":
|
||||
return HumanMessage(content=content)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=content)
|
||||
elif role == "assistant":
|
||||
additional_kwargs: Dict = {}
|
||||
tool_calls = []
|
||||
invalid_tool_calls = []
|
||||
if raw_tool_calls := _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
for raw_tool_call in raw_tool_calls:
|
||||
try:
|
||||
tool_calls.append(parse_tool_call(raw_tool_call, return_id=True))
|
||||
except Exception as e:
|
||||
invalid_tool_calls.append(
|
||||
make_invalid_tool_call(raw_tool_call, str(e))
|
||||
)
|
||||
return AIMessage(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
id=_dict.get("id"),
|
||||
tool_calls=tool_calls,
|
||||
invalid_tool_calls=invalid_tool_calls,
|
||||
)
|
||||
else:
|
||||
return ChatMessage(content=content, role=role)
|
||||
|
||||
|
||||
def _convert_dict_to_message_chunk(
|
||||
_dict: Mapping[str, Any], default_role: str
|
||||
) -> BaseMessageChunk:
|
||||
role = _dict.get("role", default_role)
|
||||
content = _dict.get("content")
|
||||
content = content if content is not None else ""
|
||||
|
||||
if role == "user":
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "system":
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role == "tool":
|
||||
return ToolMessageChunk(
|
||||
content=content, tool_call_id=_dict["tool_call_id"], id=_dict.get("id")
|
||||
)
|
||||
elif role == "assistant":
|
||||
additional_kwargs: Dict = {}
|
||||
tool_call_chunks = []
|
||||
if raw_tool_calls := _dict.get("tool_calls"):
|
||||
additional_kwargs["tool_calls"] = raw_tool_calls
|
||||
try:
|
||||
tool_call_chunks = [
|
||||
tool_call_chunk(
|
||||
name=tc["function"].get("name"),
|
||||
args=tc["function"].get("arguments"),
|
||||
id=tc.get("id"),
|
||||
index=tc["index"],
|
||||
)
|
||||
for tc in raw_tool_calls
|
||||
]
|
||||
except KeyError:
|
||||
pass
|
||||
return AIMessageChunk(
|
||||
content=content,
|
||||
additional_kwargs=additional_kwargs,
|
||||
id=_dict.get("id"),
|
||||
tool_call_chunks=tool_call_chunks,
|
||||
)
|
||||
else:
|
||||
return ChatMessageChunk(content=content, role=role)
|
2495
libs/partners/databricks/poetry.lock
generated
Normal file
2495
libs/partners/databricks/poetry.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
99
libs/partners/databricks/pyproject.toml
Normal file
99
libs/partners/databricks/pyproject.toml
Normal file
@ -0,0 +1,99 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-databricks"
|
||||
version = "0.1.0"
|
||||
description = "An integration package connecting Databricks and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/langchain-ai/langchain"
|
||||
license = "MIT"
|
||||
|
||||
[tool.poetry.urls]
|
||||
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/databricks"
|
||||
"Release Notes" = "https://github.com/langchain-ai/langchain/releases?q=tag%3A%22databricks%3D%3D0%22&expanded=true"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
# TODO: Replace <3.12 to <4.0 once https://github.com/mlflow/mlflow/commit/04370119fcc1b2ccdbcd9a50198ab00566d58cd2 is released
|
||||
python = ">=3.8.1,<3.12"
|
||||
langchain-core = "^0.2.0"
|
||||
mlflow = ">=2.9"
|
||||
|
||||
# MLflow depends on following libraries, which require different version for Python 3.8 vs 3.12
|
||||
numpy = [
|
||||
{version = ">=1.26.0", python = ">=3.12"},
|
||||
{version = ">=1.24.0", python = "<3.12"},
|
||||
]
|
||||
scipy = [
|
||||
{version = ">=1.11", python = ">=3.12"},
|
||||
{version = "<2", python = "<3.12"}
|
||||
]
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^7.4.3"
|
||||
pytest-asyncio = "^0.23.2"
|
||||
pytest-socket = "^0.7.0"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.codespell.dependencies]
|
||||
codespell = "^2.2.6"
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.5"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^1.10"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"T201", # print
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = ["tests/*"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# --strict-markers will raise errors on unknown marks.
|
||||
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
|
||||
#
|
||||
# https://docs.pytest.org/en/7.1.x/reference/reference.html
|
||||
# --strict-config any warnings encountered while parsing the `pytest`
|
||||
# section of the configuration file raise errors.
|
||||
#
|
||||
# https://github.com/tophat/syrupy
|
||||
addopts = "--strict-markers --strict-config --durations=5"
|
||||
# Registering custom markers.
|
||||
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||
markers = [
|
||||
"compile: mark placeholder test used to compile integration tests without running them",
|
||||
]
|
||||
asyncio_mode = "auto"
|
17
libs/partners/databricks/scripts/check_imports.py
Normal file
17
libs/partners/databricks/scripts/check_imports.py
Normal file
@ -0,0 +1,17 @@
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.machinery import SourceFileLoader
|
||||
|
||||
if __name__ == "__main__":
|
||||
files = sys.argv[1:]
|
||||
has_failure = False
|
||||
for file in files:
|
||||
try:
|
||||
SourceFileLoader("x", file).load_module()
|
||||
except Exception:
|
||||
has_failure = True
|
||||
print(file) # noqa: T201
|
||||
traceback.print_exc()
|
||||
print() # noqa: T201
|
||||
|
||||
sys.exit(1 if has_failure else 0)
|
27
libs/partners/databricks/scripts/check_pydantic.sh
Executable file
27
libs/partners/databricks/scripts/check_pydantic.sh
Executable file
@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# This script searches for lines starting with "import pydantic" or "from pydantic"
|
||||
# in tracked files within a Git repository.
|
||||
#
|
||||
# Usage: ./scripts/check_pydantic.sh /path/to/repository
|
||||
|
||||
# Check if a path argument is provided
|
||||
if [ $# -ne 1 ]; then
|
||||
echo "Usage: $0 /path/to/repository"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
repository_path="$1"
|
||||
|
||||
# Search for lines matching the pattern within the specified repository
|
||||
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
|
||||
|
||||
# Check if any matching lines were found
|
||||
if [ -n "$result" ]; then
|
||||
echo "ERROR: The following lines need to be updated:"
|
||||
echo "$result"
|
||||
echo "Please replace the code with an import from langchain_core.pydantic_v1."
|
||||
echo "For example, replace 'from pydantic import BaseModel'"
|
||||
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
||||
exit 1
|
||||
fi
|
18
libs/partners/databricks/scripts/lint_imports.sh
Executable file
18
libs/partners/databricks/scripts/lint_imports.sh
Executable file
@ -0,0 +1,18 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -eu
|
||||
|
||||
# Initialize a variable to keep track of errors
|
||||
errors=0
|
||||
|
||||
# make sure not importing from langchain, langchain_experimental, or langchain_community
|
||||
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
|
||||
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
|
||||
git --no-pager grep '^from langchain_community\.' . && errors=$((errors+1))
|
||||
|
||||
# Decide on an exit status based on the errors
|
||||
if [ "$errors" -gt 0 ]; then
|
||||
exit 1
|
||||
else
|
||||
exit 0
|
||||
fi
|
0
libs/partners/databricks/tests/__init__.py
Normal file
0
libs/partners/databricks/tests/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.compile
|
||||
def test_placeholder() -> None:
|
||||
"""Used for compiling integration tests without running any real tests."""
|
||||
pass
|
321
libs/partners/databricks/tests/unit_tests/test_chat_models.py
Normal file
321
libs/partners/databricks/tests/unit_tests/test_chat_models.py
Normal file
@ -0,0 +1,321 @@
|
||||
"""Test chat model integration."""
|
||||
|
||||
import json
|
||||
from typing import Generator
|
||||
from unittest import mock
|
||||
|
||||
import mlflow # type: ignore # noqa: F401
|
||||
import pytest
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
ToolMessageChunk,
|
||||
)
|
||||
from langchain_core.messages.tool import ToolCallChunk
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
from langchain_databricks.chat_models import (
|
||||
ChatDatabricks,
|
||||
_convert_dict_to_message,
|
||||
_convert_dict_to_message_chunk,
|
||||
_convert_message_to_dict,
|
||||
)
|
||||
|
||||
_MOCK_CHAT_RESPONSE = {
|
||||
"id": "chatcmpl_id",
|
||||
"object": "chat.completion",
|
||||
"created": 1721875529,
|
||||
"model": "meta-llama-3.1-70b-instruct-072424",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "To calculate the result of 36939 multiplied by 8922.4, "
|
||||
"I get:\n\n36939 x 8922.4 = 329,511,111.6",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
|
||||
}
|
||||
|
||||
_MOCK_STREAM_RESPONSE = [
|
||||
{
|
||||
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1721877054,
|
||||
"model": "meta-llama-3.1-70b-instruct-072424",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": "36939"},
|
||||
"finish_reason": None,
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 30, "completion_tokens": 20, "total_tokens": 50},
|
||||
},
|
||||
{
|
||||
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1721877054,
|
||||
"model": "meta-llama-3.1-70b-instruct-072424",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": "x"},
|
||||
"finish_reason": None,
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 30, "completion_tokens": 22, "total_tokens": 52},
|
||||
},
|
||||
{
|
||||
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1721877054,
|
||||
"model": "meta-llama-3.1-70b-instruct-072424",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": "8922.4"},
|
||||
"finish_reason": None,
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 30, "completion_tokens": 24, "total_tokens": 54},
|
||||
},
|
||||
{
|
||||
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1721877054,
|
||||
"model": "meta-llama-3.1-70b-instruct-072424",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": " = "},
|
||||
"finish_reason": None,
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 30, "completion_tokens": 28, "total_tokens": 58},
|
||||
},
|
||||
{
|
||||
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1721877054,
|
||||
"model": "meta-llama-3.1-70b-instruct-072424",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": "329,511,111.6"},
|
||||
"finish_reason": None,
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 30, "completion_tokens": 30, "total_tokens": 60},
|
||||
},
|
||||
{
|
||||
"id": "chatcmpl_bb1fce87-f14e-4ae1-ac22-89facc74898a",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1721877054,
|
||||
"model": "meta-llama-3.1-70b-instruct-072424",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"role": "assistant", "content": ""},
|
||||
"finish_reason": "stop",
|
||||
"logprobs": None,
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 30, "completion_tokens": 36, "total_tokens": 66},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_client() -> Generator:
|
||||
client = mock.MagicMock()
|
||||
client.predict.return_value = _MOCK_CHAT_RESPONSE
|
||||
client.predict_stream.return_value = _MOCK_STREAM_RESPONSE
|
||||
with mock.patch("mlflow.deployments.get_deploy_client", return_value=client):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm() -> ChatDatabricks:
|
||||
return ChatDatabricks(
|
||||
endpoint="databricks-meta-llama-3-70b-instruct", target_uri="databricks"
|
||||
)
|
||||
|
||||
|
||||
def test_chat_mlflow_predict(llm: ChatDatabricks) -> None:
|
||||
res = llm.invoke(
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "36939 * 8922.4"},
|
||||
]
|
||||
)
|
||||
assert res.content == _MOCK_CHAT_RESPONSE["choices"][0]["message"]["content"] # type: ignore[index]
|
||||
|
||||
|
||||
def test_chat_mlflow_stream(llm: ChatDatabricks) -> None:
|
||||
res = llm.stream(
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "36939 * 8922.4"},
|
||||
]
|
||||
)
|
||||
for chunk, expected in zip(res, _MOCK_STREAM_RESPONSE):
|
||||
assert chunk.content == expected["choices"][0]["delta"]["content"] # type: ignore[index]
|
||||
|
||||
|
||||
def test_chat_mlflow_bind_tools(llm: ChatDatabricks) -> None:
|
||||
class GetWeather(BaseModel):
|
||||
"""Get the current weather in a given location"""
|
||||
|
||||
location: str = Field(
|
||||
..., description="The city and state, e.g. San Francisco, CA"
|
||||
)
|
||||
|
||||
class GetPopulation(BaseModel):
|
||||
"""Get the current population in a given location"""
|
||||
|
||||
location: str = Field(
|
||||
..., description="The city and state, e.g. San Francisco, CA"
|
||||
)
|
||||
|
||||
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
|
||||
response = llm_with_tools.invoke(
|
||||
"Which city is hotter today and which is bigger: LA or NY?"
|
||||
)
|
||||
assert isinstance(response, AIMessage)
|
||||
|
||||
|
||||
### Test data conversion functions ###
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("role", "expected_output"),
|
||||
[
|
||||
("user", HumanMessage("foo")),
|
||||
("system", SystemMessage("foo")),
|
||||
("assistant", AIMessage("foo")),
|
||||
("any_role", ChatMessage(content="foo", role="any_role")),
|
||||
],
|
||||
)
|
||||
def test_convert_message(role: str, expected_output: BaseMessage) -> None:
|
||||
message = {"role": role, "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
assert result == expected_output
|
||||
|
||||
# convert back
|
||||
dict_result = _convert_message_to_dict(result)
|
||||
assert dict_result == message
|
||||
|
||||
|
||||
def test_convert_message_with_tool_calls() -> None:
|
||||
ID = "call_fb5f5e1a-bac0-4422-95e9-d06e6022ad12"
|
||||
tool_calls = [
|
||||
{
|
||||
"id": ID,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "main__test__python_exec",
|
||||
"arguments": '{"code": "result = 36939 * 8922.4"}',
|
||||
},
|
||||
}
|
||||
]
|
||||
message_with_tools = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": tool_calls,
|
||||
"id": ID,
|
||||
}
|
||||
result = _convert_dict_to_message(message_with_tools)
|
||||
expected_output = AIMessage(
|
||||
content="",
|
||||
additional_kwargs={"tool_calls": tool_calls},
|
||||
id=ID,
|
||||
tool_calls=[
|
||||
{
|
||||
"name": tool_calls[0]["function"]["name"], # type: ignore[index]
|
||||
"args": json.loads(tool_calls[0]["function"]["arguments"]), # type: ignore[index]
|
||||
"id": ID,
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
)
|
||||
assert result == expected_output
|
||||
|
||||
# convert back
|
||||
dict_result = _convert_message_to_dict(result)
|
||||
assert dict_result == message_with_tools
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("role", "expected_output"),
|
||||
[
|
||||
("user", HumanMessageChunk(content="foo")),
|
||||
("system", SystemMessageChunk(content="foo")),
|
||||
("assistant", AIMessageChunk(content="foo")),
|
||||
("any_role", ChatMessageChunk(content="foo", role="any_role")),
|
||||
],
|
||||
)
|
||||
def test_convert_message_chunk(role: str, expected_output: BaseMessage) -> None:
|
||||
delta = {"role": role, "content": "foo"}
|
||||
result = _convert_dict_to_message_chunk(delta, "default_role")
|
||||
assert result == expected_output
|
||||
|
||||
# convert back
|
||||
dict_result = _convert_message_to_dict(result)
|
||||
assert dict_result == delta
|
||||
|
||||
|
||||
def test_convert_message_chunk_with_tool_calls() -> None:
|
||||
delta_with_tools = {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{"index": 0, "function": {"arguments": " }"}}],
|
||||
}
|
||||
result = _convert_dict_to_message_chunk(delta_with_tools, "role")
|
||||
expected_output = AIMessageChunk(
|
||||
content="",
|
||||
additional_kwargs={"tool_calls": delta_with_tools["tool_calls"]},
|
||||
id=None,
|
||||
tool_call_chunks=[ToolCallChunk(name=None, args=" }", id=None, index=0)],
|
||||
)
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
def test_convert_tool_message_chunk() -> None:
|
||||
delta = {
|
||||
"role": "tool",
|
||||
"content": "foo",
|
||||
"tool_call_id": "tool_call_id",
|
||||
"id": "some_id",
|
||||
}
|
||||
result = _convert_dict_to_message_chunk(delta, "default_role")
|
||||
expected_output = ToolMessageChunk(
|
||||
content="foo", id="some_id", tool_call_id="tool_call_id"
|
||||
)
|
||||
assert result == expected_output
|
||||
|
||||
# convert back
|
||||
dict_result = _convert_message_to_dict(result)
|
||||
assert dict_result == delta
|
||||
|
||||
|
||||
def test_convert_message_to_dict_function() -> None:
|
||||
with pytest.raises(ValueError, match="Function messages are not supported"):
|
||||
_convert_message_to_dict(FunctionMessage(content="", name="name"))
|
10
libs/partners/databricks/tests/unit_tests/test_imports.py
Normal file
10
libs/partners/databricks/tests/unit_tests/test_imports.py
Normal file
@ -0,0 +1,10 @@
|
||||
from langchain_databricks import __all__
|
||||
|
||||
EXPECTED_ALL = [
|
||||
"ChatDatabricks",
|
||||
"__version__",
|
||||
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
|
||||
assert sorted(EXPECTED_ALL) == sorted(__all__)
|
Loading…
Reference in New Issue
Block a user