langchain[minor]: Migrate mlflow and databricks classes to deployments APIs. (#13699)

## Description

Related to https://github.com/mlflow/mlflow/pull/10420. The MLflow AI Gateway will be deprecated and replaced by the `mlflow.deployments` module, so this PR migrates the MLflow and Databricks classes to the deployments APIs. Happy to split this PR if it's too large.

To try out this PR, install `langchain` from this branch:

```
pip install git+https://github.com/langchain-ai/langchain.git@refs/pull/13699/merge#subdirectory=libs/langchain
```
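
For context, the deprecated gateway classes map onto the new deployments-based classes roughly as follows. This is only a minimal sketch for reviewers; the URIs and endpoint names are placeholders, and the old `gateway_uri`/`route` arguments are shown from memory purely for contrast with the new `target_uri`/`endpoint` ones:

```python
# Before (deprecated): MLflow AI Gateway wrapper.
from langchain.llms import MlflowAIGateway

old_llm = MlflowAIGateway(
    gateway_uri="http://127.0.0.1:5000",  # AI Gateway server (placeholder)
    route="completions",                  # gateway route name
)

# After (this PR): deployments-based wrapper.
from langchain.llms import Mlflow

new_llm = Mlflow(
    target_uri="http://127.0.0.1:5000",   # deployments server (placeholder)
    endpoint="completions",               # deployments endpoint name
    params={"temperature": 0.1},
)
```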

## Dependencies

Install mlflow from https://github.com/mlflow/mlflow/pull/10420:

```
pip install git+https://github.com/mlflow/mlflow.git@refs/pull/10420/merge
```

## Testing plan

The following code runs fine both locally and on Databricks:

<details><summary>Click</summary>
<p>

```python
"""
Setup
-----
mlflow deployments start-server --config-path examples/gateway/openai/config.yaml
databricks secrets create-scope <scope>
databricks secrets put-secret <scope> openai-api-key --string-value $OPENAI_API_KEY

Run
---
python /path/to/this/file.py secrets/<scope>/openai-api-key
"""
from langchain.chat_models import ChatMlflow, ChatDatabricks
from langchain.embeddings import MlflowEmbeddings, DatabricksEmbeddings
from langchain.llms import Databricks, Mlflow
from langchain.schema.messages import HumanMessage
from langchain.chains.loading import load_chain
from mlflow.deployments import get_deploy_client
import uuid
import sys
import tempfile
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

###############################
# MLflow
###############################
chat = ChatMlflow(
    target_uri="http://127.0.0.1:5000", endpoint="chat", params={"temperature": 0.1}
)
print(chat([HumanMessage(content="hello")]))

embeddings = MlflowEmbeddings(target_uri="http://127.0.0.1:5000", endpoint="embeddings")
print(embeddings.embed_query("hello")[:3])
print(embeddings.embed_documents(["hello", "world"])[0][:3])

llm = Mlflow(
    target_uri="http://127.0.0.1:5000",
    endpoint="completions",
    params={"temperature": 0.1},
)
print(llm("I am"))

llm_chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate(
        input_variables=["adjective"],
        template="Tell me a {adjective} joke",
    ),
)
print(llm_chain.run(adjective="funny"))

# serialization/deserialization
with tempfile.TemporaryDirectory() as tmpdir:
    print(tmpdir)
    path = f"{tmpdir}/llm.yaml"
    llm_chain.save(path)
    loaded_chain = load_chain(path)
    print(loaded_chain("funny"))

###############################
# Databricks
###############################
secret = sys.argv[1]
client = get_deploy_client("databricks")

# External - chat
name = f"chat-{uuid.uuid4()}"
client.create_endpoint(
    name=name,
    config={
        "served_entities": [
            {
                "name": "test",
                "external_model": {
                    "name": "gpt-4",
                    "provider": "openai",
                    "task": "llm/v1/chat",
                    "openai_config": {
                        "openai_api_key": "{{" + secret + "}}",
                    },
                },
            }
        ],
    },
)
try:
    chat = ChatDatabricks(
        target_uri="databricks", endpoint=name, params={"temperature": 0.1}
    )
    print(chat([HumanMessage(content="hello")]))
finally:
    client.delete_endpoint(endpoint=name)

# External - embeddings
name = f"embeddings-{uuid.uuid4()}"
client.create_endpoint(
    name=name,
    config={
        "served_entities": [
            {
                "name": "test",
                "external_model": {
                    "name": "text-embedding-ada-002",
                    "provider": "openai",
                    "task": "llm/v1/embeddings",
                    "openai_config": {
                        "openai_api_key": "{{" + secret + "}}",
                    },
                },
            }
        ],
    },
)
try:
    embeddings = DatabricksEmbeddings(target_uri="databricks", endpoint=name)
    print(embeddings.embed_query("hello")[:3])
    print(embeddings.embed_documents(["hello", "world"])[0][:3])
finally:
    client.delete_endpoint(endpoint=name)

# External - completions
name = f"completions-{uuid.uuid4()}"
client.create_endpoint(
    name=name,
    config={
        "served_entities": [
            {
                "name": "test",
                "external_model": {
                    "name": "gpt-3.5-turbo-instruct",
                    "provider": "openai",
                    "task": "llm/v1/completions",
                    "openai_config": {
                        "openai_api_key": "{{" + secret + "}}",
                    },
                },
            }
        ],
    },
)
try:
    llm = Databricks(
        endpoint_name=name,
        model_kwargs={"temperature": 0.1},
    )
    print(llm("I am"))
finally:
    client.delete_endpoint(endpoint=name)


# Foundation model - chat
chat = ChatDatabricks(
    endpoint="databricks-llama-2-70b-chat", params={"temperature": 0.1}
)
print(chat([HumanMessage(content="hello")]))

# Foundation model - embeddings
embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
print(embeddings.embed_query("hello")[:3])

# Foundation model - completions
llm = Databricks(
    endpoint_name="databricks-mpt-7b-instruct", model_kwargs={"temperature": 0.1}
)
print(llm("hello"))
llm_chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate(
        input_variables=["adjective"],
        template="Tell me a {adjective} joke",
    ),
)
print(llm_chain.run(adjective="funny"))

# serialization/deserialization
with tempfile.TemporaryDirectory() as tmpdir:
    print(tmpdir)
    path = f"{tmpdir}/llm.yaml"
    llm_chain.save(path)
    loaded_chain = load_chain(path)
    print(loaded_chain("funny"))

```

Output:

```
content='Hello! How can I assist you today?'
[-0.025058426, -0.01938856, -0.027781019]
[-0.025058426, -0.01938856, -0.027781019]
sorry, but I cannot continue the sentence as it is incomplete. Can you please provide more information or context?
Sure, here's a classic one for you:

Why don't scientists trust atoms?

Because they make up everything!
/var/folders/dz/cd_nvlf14g9g__n3ph0d_0pm0000gp/T/tmpx_4no6ad
{'adjective': 'funny', 'text': "Sure, here's a classic one for you:\n\nWhy don't scientists trust atoms?\n\nBecause they make up everything!"}
content='Hello! How can I assist you today?'
[-0.025058426, -0.01938856, -0.027781019]
[-0.025058426, -0.01938856, -0.027781019]
 a 23 year old female and I am currently studying for my master's degree
content="\nHello! It's nice to meet you. Is there something I can help you with or would you like to chat for a bit?"
[0.051055908203125, 0.007221221923828125, 0.003879547119140625]
[0.051055908203125, 0.007221221923828125, 0.003879547119140625]

hello back
 Well, I don't really know many jokes, but I do know this funny story...
/var/folders/dz/cd_nvlf14g9g__n3ph0d_0pm0000gp/T/tmp7_ds72ex
{'adjective': 'funny', 'text': " Well, I don't really know many jokes, but I do know this funny story..."}
```

</p>
</details>

The existing workflow, where a serving endpoint is backed by a model registered with MLflow, doesn't break:

<details><summary>Click</summary>
<p>

```python
import uuid

import mlflow
from mlflow.models import ModelSignature
from mlflow.types.schema import ColSpec, Schema


class MyModel(mlflow.pyfunc.PythonModel):
    def predict(self, context, model_input):
        return str(uuid.uuid4())


with mlflow.start_run():
    mlflow.pyfunc.log_model(
        "model",
        python_model=MyModel(),
        pip_requirements=["mlflow==2.8.1", "cloudpickle<3"],
        signature=ModelSignature(
            inputs=Schema(
                [
                    ColSpec("string", "prompt"),
                    ColSpec("string", "stop"),
                ]
            ),
            outputs=Schema(
                [
                    ColSpec(name=None, type="string"),
                ]
            ),
        ),
        registered_model_name=f"lang-{uuid.uuid4()}",
    )

# Manually create a serving endpoint with the registered model and run
from langchain.llms import Databricks

llm = Databricks(endpoint_name="<name>")
llm("hello")  # 9d0b2491-3d13-487c-bc02-1287f06ecae7
```

</p>
</details> 

## Follow-up tasks

(This PR is already large, so I'll file a separate PR for these follow-up tasks.)

- Update `docs/docs/integrations/providers/mlflow_ai_gateway.mdx` and
`docs/docs/integrations/providers/databricks.md`.

---------

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Harutaka Kawamura 2023-12-01 08:06:58 +09:00 committed by GitHub
parent dc31714ec5
commit 0d08a692a3
14 changed files with 666 additions and 22 deletions

View File

@ -24,6 +24,7 @@ from langchain.chat_models.baichuan import ChatBaichuan
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain.chat_models.bedrock import BedrockChat
from langchain.chat_models.cohere import ChatCohere
from langchain.chat_models.databricks import ChatDatabricks
from langchain.chat_models.ernie import ErnieBotChat
from langchain.chat_models.everlyai import ChatEverlyAI
from langchain.chat_models.fake import FakeListChatModel
@ -37,6 +38,7 @@ from langchain.chat_models.jinachat import JinaChat
from langchain.chat_models.konko import ChatKonko
from langchain.chat_models.litellm import ChatLiteLLM
from langchain.chat_models.minimax import MiniMaxChat
from langchain.chat_models.mlflow import ChatMlflow
from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
from langchain.chat_models.ollama import ChatOllama
from langchain.chat_models.openai import ChatOpenAI
@ -52,10 +54,12 @@ __all__ = [
"AzureChatOpenAI",
"FakeListChatModel",
"PromptLayerChatOpenAI",
"ChatDatabricks",
"ChatEverlyAI",
"ChatAnthropic",
"ChatCohere",
"ChatGooglePalm",
"ChatMlflow",
"ChatMLflowAIGateway",
"ChatOllama",
"ChatVertexAI",

View File

@ -0,0 +1,46 @@
import logging
from urllib.parse import urlparse
from langchain.chat_models.mlflow import ChatMlflow
logger = logging.getLogger(__name__)
class ChatDatabricks(ChatMlflow):
"""`Databricks` chat models API.
To use, you should have the ``mlflow`` python package installed.
For more information, see https://mlflow.org/docs/latest/llms/deployments/databricks.html.
Example:
.. code-block:: python
from langchain.chat_models import ChatDatabricks
chat = ChatDatabricks(
target_uri="databricks",
endpoint="chat",
temperature=0.1,
)
"""
target_uri: str = "databricks"
"""The target URI to use. Defaults to ``databricks``."""
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "databricks-chat"
@property
def _mlflow_extras(self) -> str:
return ""
def _validate_uri(self) -> None:
if self.target_uri == "databricks":
return
if urlparse(self.target_uri).scheme != "databricks":
raise ValueError(
"Invalid target URI. The target URI must be a valid databricks URI."
)

View File

@ -0,0 +1,217 @@
import asyncio
import logging
from functools import partial
from typing import Any, Dict, List, Mapping, Optional
from urllib.parse import urlparse
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
BaseMessage,
ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import (
Field,
PrivateAttr,
)
logger = logging.getLogger(__name__)
class ChatMlflow(BaseChatModel):
"""`MLflow` chat models API.
To use, you should have the `mlflow[genai]` python package installed.
For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.
Example:
.. code-block:: python
from langchain.chat_models import ChatMlflow
chat = ChatMlflow(
target_uri="http://localhost:5000",
endpoint="chat",
temperature=0.1,
)
"""
endpoint: str
"""The endpoint to use."""
target_uri: str
"""The target URI to use."""
temperature: float = 0.0
"""The sampling temperature."""
n: int = 1
"""The number of completion choices to generate."""
stop: Optional[List[str]] = None
"""The stop sequence."""
max_tokens: Optional[int] = None
"""The maximum number of tokens to generate."""
extra_params: dict = Field(default_factory=dict)
"""Any extra parameters to pass to the endpoint."""
_client: Any = PrivateAttr()
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._validate_uri()
try:
from mlflow.deployments import get_deploy_client
self._client = get_deploy_client(self.target_uri)
except ImportError as e:
raise ImportError(
"Failed to create the client. "
f"Please run `pip install mlflow{self._mlflow_extras}` to install "
"required dependencies."
) from e
@property
def _mlflow_extras(self) -> str:
return "[genai]"
def _validate_uri(self) -> None:
if self.target_uri == "databricks":
return
allowed = ["http", "https", "databricks"]
if urlparse(self.target_uri).scheme not in allowed:
raise ValueError(
f"Invalid target URI: {self.target_uri}. "
f"The scheme must be one of {allowed}."
)
@property
def _default_params(self) -> Dict[str, Any]:
params: Dict[str, Any] = {
"target_uri": self.target_uri,
"endpoint": self.endpoint,
"temperature": self.temperature,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens,
"extra_params": self.extra_params,
}
return params
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts = [
ChatMlflow._convert_message_to_dict(message) for message in messages
]
data: Dict[str, Any] = {
"messages": message_dicts,
"temperature": self.temperature,
"n": self.n,
"stop": stop or self.stop,
"max_tokens": self.max_tokens,
**self.extra_params,
}
resp = self._client.predict(endpoint=self.endpoint, inputs=data)
return ChatMlflow._create_chat_result(resp)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)
@property
def _identifying_params(self) -> Dict[str, Any]:
return self._default_params
def _get_invocation_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> Dict[str, Any]:
"""Get the parameters used to invoke the model FOR THE CALLBACKS."""
return {
**self._default_params,
**super()._get_invocation_params(stop=stop, **kwargs),
}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "mlflow-chat"
@staticmethod
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
content = _dict["content"]
if role == "user":
return HumanMessage(content=content)
elif role == "assistant":
return AIMessage(content=content)
elif role == "system":
return SystemMessage(content=content)
else:
return ChatMessage(content=content, role=role)
@staticmethod
def _raise_functions_not_supported() -> None:
raise ValueError(
"Function messages are not supported by Databricks. Please"
" create a feature request at https://github.com/mlflow/mlflow/issues."
)
@staticmethod
def _convert_message_to_dict(message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
raise ValueError(
"Function messages are not supported by Databricks. Please"
" create a feature request at https://github.com/mlflow/mlflow/issues."
)
else:
raise ValueError(f"Got unknown message type: {message}")
if "function_call" in message.additional_kwargs:
ChatMlflow._raise_functions_not_supported()
if message.additional_kwargs:
logger.warning(
"Additional message arguments are unsupported by Databricks"
" and will be ignored: %s",
message.additional_kwargs,
)
return message_dict
@staticmethod
def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
generations = []
for choice in response["choices"]:
message = ChatMlflow._convert_dict_to_message(choice["message"])
usage = choice.get("usage", {})
gen = ChatGeneration(
message=message,
generation_info=usage,
)
generations.append(gen)
usage = response.get("usage", {})
return ChatResult(generations=generations, llm_output=usage)
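
For reviewers, a rough sketch of the payload `_generate` sends to `client.predict` and the response shape `_create_chat_result` consumes; all values below are illustrative, not taken from a real endpoint:

```python
# Request body built by ChatMlflow._generate before
# self._client.predict(endpoint=self.endpoint, inputs=data):
data = {
    "messages": [{"role": "user", "content": "hello"}],
    "temperature": 0.0,
    "n": 1,
    "stop": None,
    "max_tokens": None,
    # plus anything passed via `extra_params`
}

# Response shape expected by ChatMlflow._create_chat_result:
response = {
    "choices": [
        {
            "message": {"role": "assistant", "content": "Hello!"},
            "usage": {"completion_tokens": 2},  # illustrative
        }
    ],
    "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
}
```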

View File

@ -1,5 +1,6 @@
import asyncio
import logging
import warnings
from functools import partial
from typing import Any, Dict, List, Mapping, Optional
@ -59,6 +60,11 @@ class ChatMLflowAIGateway(BaseChatModel):
"""
def __init__(self, **kwargs: Any):
warnings.warn(
"`ChatMLflowAIGateway` is deprecated. Use `ChatMlflow` or "
"`ChatDatabricks` instead.",
DeprecationWarning,
)
try:
import mlflow.gateway
except ImportError as e:

View File

@ -26,6 +26,7 @@ from langchain.embeddings.cache import CacheBackedEmbeddings
from langchain.embeddings.clarifai import ClarifaiEmbeddings
from langchain.embeddings.cohere import CohereEmbeddings
from langchain.embeddings.dashscope import DashScopeEmbeddings
from langchain.embeddings.databricks import DatabricksEmbeddings
from langchain.embeddings.deepinfra import DeepInfraEmbeddings
from langchain.embeddings.edenai import EdenAiEmbeddings
from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
@ -50,6 +51,7 @@ from langchain.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings
from langchain.embeddings.llamacpp import LlamaCppEmbeddings
from langchain.embeddings.localai import LocalAIEmbeddings
from langchain.embeddings.minimax import MiniMaxEmbeddings
from langchain.embeddings.mlflow import MlflowEmbeddings
from langchain.embeddings.mlflow_gateway import MlflowAIGatewayEmbeddings
from langchain.embeddings.modelscope_hub import ModelScopeEmbeddings
from langchain.embeddings.mosaicml import MosaicMLInstructorEmbeddings
@ -78,6 +80,7 @@ __all__ = [
"CacheBackedEmbeddings",
"ClarifaiEmbeddings",
"CohereEmbeddings",
"DatabricksEmbeddings",
"ElasticsearchEmbeddings",
"FastEmbedEmbeddings",
"HuggingFaceEmbeddings",
@ -87,6 +90,7 @@ __all__ = [
"JinaEmbeddings",
"LlamaCppEmbeddings",
"HuggingFaceHubEmbeddings",
"MlflowEmbeddings",
"MlflowAIGatewayEmbeddings",
"ModelScopeEmbeddings",
"TensorflowHubEmbeddings",

View File

@ -0,0 +1,45 @@
from __future__ import annotations
from typing import Iterator, List
from urllib.parse import urlparse
from langchain.embeddings.mlflow import MlflowEmbeddings
def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
for i in range(0, len(texts), size):
yield texts[i : i + size]
class DatabricksEmbeddings(MlflowEmbeddings):
"""Wrapper around embeddings LLMs in Databricks.
To use, you should have the ``mlflow`` python package installed.
For more information, see https://mlflow.org/docs/latest/llms/deployments/databricks.html.
Example:
.. code-block:: python
from langchain.embeddings import DatabricksEmbeddings
embeddings = DatabricksEmbeddings(
target_uri="databricks",
endpoint="embeddings",
)
"""
target_uri: str = "databricks"
"""The target URI to use. Defaults to ``databricks``."""
@property
def _mlflow_extras(self) -> str:
return ""
def _validate_uri(self) -> None:
if self.target_uri == "databricks":
return
if urlparse(self.target_uri).scheme != "databricks":
raise ValueError(
"Invalid target URI. The target URI must be a valid databricks URI."
)

View File

@ -0,0 +1,74 @@
from __future__ import annotations
from typing import Any, Iterator, List
from urllib.parse import urlparse
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
for i in range(0, len(texts), size):
yield texts[i : i + size]
class MlflowEmbeddings(Embeddings, BaseModel):
"""Wrapper around embeddings LLMs in MLflow.
To use, you should have the `mlflow[genai]` python package installed.
For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.
Example:
.. code-block:: python
from langchain.embeddings import MlflowEmbeddings
embeddings = MlflowEmbeddings(
target_uri="http://localhost:5000",
endpoint="embeddings",
)
"""
endpoint: str
"""The endpoint to use."""
target_uri: str
"""The target URI to use."""
_client: Any = PrivateAttr()
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._validate_uri()
try:
from mlflow.deployments import get_deploy_client
self._client = get_deploy_client(self.target_uri)
except ImportError as e:
raise ImportError(
"Failed to create the client. "
f"Please run `pip install mlflow{self._mlflow_extras}` to install "
"required dependencies."
) from e
@property
def _mlflow_extras(self) -> str:
return "[genai]"
def _validate_uri(self) -> None:
if self.target_uri == "databricks":
return
allowed = ["http", "https", "databricks"]
if urlparse(self.target_uri).scheme not in allowed:
raise ValueError(
f"Invalid target URI: {self.target_uri}. "
f"The scheme must be one of {allowed}."
)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
embeddings: List[List[float]] = []
for txt in _chunk(texts, 20):
resp = self._client.predict(endpoint=self.endpoint, inputs={"input": txt})
embeddings.extend(r["embedding"] for r in resp["data"])
return embeddings
def embed_query(self, text: str) -> List[float]:
return self.embed_documents([text])[0]
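
Similarly, a sketch of the per-batch request/response shape `embed_documents` assumes; the embedding values here are made up:

```python
# Each batch of up to 20 texts is sent as
#   self._client.predict(endpoint=self.endpoint, inputs={"input": ["hello", "world"]})
# and the vectors are read back from the "data" field:
resp = {
    "data": [
        {"embedding": [-0.025, -0.019, -0.028]},
        {"embedding": [0.051, 0.007, 0.004]},
    ]
}
vectors = [r["embedding"] for r in resp["data"]]
```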

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import warnings
from typing import Any, Iterator, List, Optional
from langchain_core.embeddings import Embeddings
@ -35,6 +36,11 @@ class MlflowAIGatewayEmbeddings(Embeddings, BaseModel):
"""The URI for the MLflow AI Gateway API."""
def __init__(self, **kwargs: Any):
warnings.warn(
"`MlflowAIGatewayEmbeddings` is deprecated. Use `MlflowEmbeddings` or "
"`DatabricksEmbeddings` instead.",
DeprecationWarning,
)
try:
import mlflow.gateway
except ImportError as e:

View File

@ -148,6 +148,12 @@ def _import_databricks() -> Any:
return Databricks
def _import_databricks_chat() -> Any:
from langchain.chat_models.databricks import ChatDatabricks
return ChatDatabricks
def _import_deepinfra() -> Any:
from langchain.llms.deepinfra import DeepInfra
@ -276,6 +282,18 @@ def _import_minimax() -> Any:
return Minimax
def _import_mlflow() -> Any:
from langchain.llms.mlflow import Mlflow
return Mlflow
def _import_mlflow_chat() -> Any:
from langchain.chat_models.mlflow import ChatMlflow
return ChatMlflow
def _import_mlflow_ai_gateway() -> Any:
from langchain.llms.mlflow_ai_gateway import MlflowAIGateway
@ -595,6 +613,8 @@ def __getattr__(name: str) -> Any:
return _import_manifest()
elif name == "Minimax":
return _import_minimax()
elif name == "Mlflow":
return _import_mlflow()
elif name == "MlflowAIGateway":
return _import_mlflow_ai_gateway()
elif name == "Modal":
@ -789,6 +809,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"ctransformers": _import_ctransformers,
"ctranslate2": _import_ctranslate2,
"databricks": _import_databricks,
"databricks-chat": _import_databricks_chat,
"deepinfra": _import_deepinfra,
"deepsparse": _import_deepsparse,
"edenai": _import_edenai,
@ -808,6 +829,8 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"llamacpp": _import_llamacpp,
"textgen": _import_textgen,
"minimax": _import_minimax,
"mlflow": _import_mlflow,
"mlflow-chat": _import_mlflow_chat,
"mlflow-ai-gateway": _import_mlflow_ai_gateway,
"modal": _import_modal,
"mosaic": _import_mosaicml,

View File

@ -1,8 +1,11 @@
import os
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Mapping, Optional
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM
from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
@ -12,9 +15,6 @@ from langchain_core.pydantic_v1 import (
validator,
)
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
__all__ = ["Databricks"]
@ -24,24 +24,66 @@ class _DatabricksClientBase(BaseModel, ABC):
api_url: str
api_token: str
def post_raw(self, request: Any) -> Any:
def request(self, method: str, url: str, request: Any) -> Any:
headers = {"Authorization": f"Bearer {self.api_token}"}
response = requests.post(self.api_url, headers=headers, json=request)
response = requests.request(
method=method, url=url, headers=headers, json=request
)
# TODO: error handling and automatic retries
if not response.ok:
raise ValueError(f"HTTP {response.status_code} error: {response.text}")
return response.json()
def _get(self, url: str) -> Any:
return self.request("GET", url, None)
def _post(self, url: str, request: Any) -> Any:
return self.request("POST", url, request)
@abstractmethod
def post(self, request: Any) -> Any:
def post(
self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
) -> Any:
...
def _transform_completions(response: Dict[str, Any]) -> str:
return response["choices"][0]["text"]
def _transform_chat(response: Dict[str, Any]) -> str:
return response["choices"][0]["message"]["content"]
class _DatabricksServingEndpointClient(_DatabricksClientBase):
"""An API client that talks to a Databricks serving endpoint."""
host: str
endpoint_name: str
databricks_uri: str
client: Any = None
external_or_foundation: bool = False
task: Optional[str] = None
def __init__(self, **data: Any):
super().__init__(**data)
try:
from mlflow.deployments import get_deploy_client
self.client = get_deploy_client(self.databricks_uri)
except ImportError as e:
raise ImportError(
"Failed to create the client. "
"Please install mlflow with `pip install mlflow`."
) from e
endpoint = self.client.get_endpoint(self.endpoint_name)
self.external_or_foundation = endpoint.get("endpoint_type", "").lower() in (
"external_model",
"foundation_model_api",
)
self.task = endpoint.get("task")
@root_validator(pre=True)
def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@ -52,14 +94,30 @@ class _DatabricksServingEndpointClient(_DatabricksClientBase):
values["api_url"] = api_url
return values
def post(self, request: Any) -> Any:
# See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
wrapped_request = {"dataframe_records": [request]}
response = self.post_raw(wrapped_request)["predictions"]
# For a single-record query, the result is not a list.
if isinstance(response, list):
response = response[0]
return response
def post(
self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
) -> Any:
if self.external_or_foundation:
resp = self.client.predict(endpoint=self.endpoint_name, inputs=request)
if transform_output_fn:
return transform_output_fn(resp)
if self.task == "llm/v1/chat":
return _transform_chat(resp)
elif self.task == "llm/v1/completions":
return _transform_completions(resp)
return resp
else:
# See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
wrapped_request = {"dataframe_records": [request]}
response = self.client.predict(
endpoint=self.endpoint_name, inputs=wrapped_request
)
preds = response["predictions"]
# For a single-record query, the result is not a list.
pred = preds[0] if isinstance(preds, list) else preds
return transform_output_fn(pred) if transform_output_fn else pred
class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
@ -79,8 +137,8 @@ class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
values["api_url"] = api_url
return values
def post(self, request: Any) -> Any:
return self.post_raw(request)
def post(self, request: Any, transform: Optional[Callable[..., str]] = None) -> Any:
return self._post(self.api_url, request)
def get_repl_context() -> Any:
@ -137,16 +195,19 @@ def get_default_api_token() -> str:
class Databricks(LLM):
"""Databricks serving endpoint or a cluster driver proxy app for LLM.
It supports two endpoint types:
* **Serving endpoint** (recommended for both production and development).
We assume that an LLM was registered and deployed to a serving endpoint.
We assume that an LLM was deployed to a serving endpoint.
To wrap it as an LLM you must have "Can Query" permission to the endpoint.
Set ``endpoint_name`` accordingly and do not set ``cluster_id`` and
``cluster_driver_port``.
The expected model signature is:
If the underlying model is a model registered by MLflow, the expected model
signature is:
* inputs::
@ -155,6 +216,10 @@ class Databricks(LLM):
* outputs: ``[{"type": "string"}]``
If the underlying model is an external or foundation model, the response from the
endpoint is automatically transformed to the expected format unless
``transform_output_fn`` is provided.
* **Cluster driver proxy app** (recommended for interactive development).
One can load an LLM on a Databricks interactive cluster and start a local HTTP
server on the driver node to serve the model at ``/`` using HTTP POST method
@ -220,9 +285,14 @@ class Databricks(LLM):
We recommend the server using a port number between ``[3000, 8000]``.
"""
model_kwargs: Optional[Dict[str, Any]] = None
params: Optional[Dict[str, Any]] = None
"""Extra parameters to pass to the endpoint."""
model_kwargs: Optional[Dict[str, Any]] = None
"""
Deprecated. Please use ``params`` instead. Extra parameters to pass to the endpoint.
"""
transform_input_fn: Optional[Callable] = None
"""A function that transforms ``{prompt, stop, **kwargs}`` into a JSON-compatible
request object that the endpoint accepts.
@ -233,6 +303,9 @@ class Databricks(LLM):
"""A function that transforms the output from the endpoint to the generated text.
"""
databricks_uri: str = "databricks"
"""The databricks URI. Only used when using a serving endpoint."""
_client: _DatabricksClientBase = PrivateAttr()
class Config:
@ -283,11 +356,19 @@ class Databricks(LLM):
def __init__(self, **data: Any):
super().__init__(**data)
if self.model_kwargs is not None and self.params is not None:
raise ValueError("Cannot set both model_kwargs and params.")
elif self.model_kwargs is not None:
warnings.warn(
"model_kwargs is deprecated. Please use params instead.",
DeprecationWarning,
)
if self.endpoint_name:
self._client = _DatabricksServingEndpointClient(
host=self.host,
api_token=self.api_token,
endpoint_name=self.endpoint_name,
databricks_uri=self.databricks_uri,
)
elif self.cluster_id and self.cluster_driver_port:
self._client = _DatabricksClusterDriverProxyClient(
@ -301,6 +382,31 @@ class Databricks(LLM):
"Must specify either endpoint_name or cluster_id/cluster_driver_port."
)
@property
def _params(self) -> Optional[Dict[str, Any]]:
return self.model_kwargs or self.params
@property
def _default_params(self) -> Dict[str, Any]:
"""Return default params."""
return {
"host": self.host,
# "api_token": self.api_token, # Never save the token
"endpoint_name": self.endpoint_name,
"cluster_id": self.cluster_id,
"cluster_driver_port": self.cluster_driver_port,
"databricks_uri": self.databricks_uri,
"model_kwargs": self.model_kwargs,
"params": self.params,
# TODO: Support saving transform_input_fn and transform_output_fn
# "transform_input_fn": self.transform_input_fn,
# "transform_output_fn": self.transform_output_fn,
}
@property
def _identifying_params(self) -> Mapping[str, Any]:
return self._default_params
@property
def _llm_type(self) -> str:
"""Return type of llm."""
@ -319,8 +425,8 @@ class Databricks(LLM):
request = {"prompt": prompt, "stop": stop}
request.update(kwargs)
if self.model_kwargs:
request.update(self.model_kwargs)
if self._params:
request.update(self._params)
if self.transform_input_fn:
request = self.transform_input_fn(**request)
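
A hedged sketch of how `transform_output_fn` interacts with the new code path: for an external or foundation model endpoint, the raw response dict is handed to the callback. The endpoint name below is a placeholder, and the function simply mirrors the built-in `llm/v1/completions` transform:

```python
from typing import Any

from langchain.llms import Databricks


def extract_first_choice(response: Any) -> str:
    # Mirrors _transform_completions: pull the generated text out of the raw response.
    return response["choices"][0]["text"]


llm = Databricks(
    endpoint_name="<external-completions-endpoint>",  # placeholder
    params={"temperature": 0.1},
    transform_output_fn=extract_first_choice,
)
```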

View File

@ -0,0 +1,104 @@
from __future__ import annotations
from typing import Any, Dict, List, Mapping, Optional
from urllib.parse import urlparse
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM
from langchain_core.pydantic_v1 import BaseModel, Extra, PrivateAttr
# Ignoring type because below is valid pydantic code
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class Params(BaseModel, extra=Extra.allow): # type: ignore[call-arg]
"""Parameters for MLflow"""
temperature: float = 0.0
n: int = 1
stop: Optional[List[str]] = None
max_tokens: Optional[int] = None
class Mlflow(LLM):
"""Wrapper around completions LLMs in MLflow.
To use, you should have the `mlflow[genai]` python package installed.
For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.
Example:
.. code-block:: python
from langchain.llms import Mlflow
completions = Mlflow(
target_uri="http://localhost:5000",
endpoint="test",
params={"temperature": 0.1}
)
"""
endpoint: str
"""The endpoint to use."""
target_uri: str
"""The target URI to use."""
params: Optional[Params] = None
"""Extra parameters such as `temperature`."""
_client: Any = PrivateAttr()
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._validate_uri()
try:
from mlflow.deployments import get_deploy_client
self._client = get_deploy_client(self.target_uri)
except ImportError as e:
raise ImportError(
"Failed to create the client. "
"Please run `pip install mlflow[genai]` to install "
"required dependencies."
) from e
def _validate_uri(self) -> None:
if self.target_uri == "databricks":
return
allowed = ["http", "https", "databricks"]
if urlparse(self.target_uri).scheme not in allowed:
raise ValueError(
f"Invalid target URI: {self.target_uri}. "
f"The scheme must be one of {allowed}."
)
@property
def _default_params(self) -> Dict[str, Any]:
params: Dict[str, Any] = {
"target_uri": self.target_uri,
"endpoint": self.endpoint,
}
if self.params:
params["params"] = self.params.dict()
return params
@property
def _identifying_params(self) -> Mapping[str, Any]:
return self._default_params
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
data: Dict[str, Any] = {
"prompt": prompt,
**(self.params.dict() if self.params else {}),
}
if s := (stop or (self.params.stop if self.params else None)):
data["stop"] = s
resp = self._client.predict(endpoint=self.endpoint, inputs=data)
return resp["choices"][0]["text"]
@property
def _llm_type(self) -> str:
return "mlflow"

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import warnings
from typing import Any, Dict, List, Mapping, Optional
from langchain_core.pydantic_v1 import BaseModel, Extra
@ -46,6 +47,10 @@ class MlflowAIGateway(LLM):
params: Optional[Params] = None
def __init__(self, **kwargs: Any):
warnings.warn(
"`MlflowAIGateway` is deprecated. Use `Mlflow` or `Databricks` instead.",
DeprecationWarning,
)
try:
import mlflow.gateway
except ImportError as e:

View File

@ -9,7 +9,9 @@ EXPECTED_ALL = [
"ChatEverlyAI",
"ChatAnthropic",
"ChatCohere",
"ChatDatabricks",
"ChatGooglePalm",
"ChatMlflow",
"ChatMLflowAIGateway",
"ChatOllama",
"ChatVertexAI",

View File

@ -6,6 +6,7 @@ EXPECTED_ALL = [
"CacheBackedEmbeddings",
"ClarifaiEmbeddings",
"CohereEmbeddings",
"DatabricksEmbeddings",
"ElasticsearchEmbeddings",
"FastEmbedEmbeddings",
"HuggingFaceEmbeddings",
@ -16,6 +17,7 @@ EXPECTED_ALL = [
"LlamaCppEmbeddings",
"HuggingFaceHubEmbeddings",
"MlflowAIGatewayEmbeddings",
"MlflowEmbeddings",
"ModelScopeEmbeddings",
"TensorflowHubEmbeddings",
"SagemakerEndpointEmbeddings",