diff --git a/libs/langchain/langchain/chat_models/__init__.py b/libs/langchain/langchain/chat_models/__init__.py
index b2db80b470f..3f510129b99 100644
--- a/libs/langchain/langchain/chat_models/__init__.py
+++ b/libs/langchain/langchain/chat_models/__init__.py
@@ -24,6 +24,7 @@ from langchain.chat_models.baichuan import ChatBaichuan
 from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
 from langchain.chat_models.bedrock import BedrockChat
 from langchain.chat_models.cohere import ChatCohere
+from langchain.chat_models.databricks import ChatDatabricks
 from langchain.chat_models.ernie import ErnieBotChat
 from langchain.chat_models.everlyai import ChatEverlyAI
 from langchain.chat_models.fake import FakeListChatModel
@@ -37,6 +38,7 @@ from langchain.chat_models.jinachat import JinaChat
 from langchain.chat_models.konko import ChatKonko
 from langchain.chat_models.litellm import ChatLiteLLM
 from langchain.chat_models.minimax import MiniMaxChat
+from langchain.chat_models.mlflow import ChatMlflow
 from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
 from langchain.chat_models.ollama import ChatOllama
 from langchain.chat_models.openai import ChatOpenAI
@@ -52,10 +54,12 @@ __all__ = [
     "AzureChatOpenAI",
     "FakeListChatModel",
     "PromptLayerChatOpenAI",
+    "ChatDatabricks",
     "ChatEverlyAI",
     "ChatAnthropic",
     "ChatCohere",
     "ChatGooglePalm",
+    "ChatMlflow",
     "ChatMLflowAIGateway",
     "ChatOllama",
     "ChatVertexAI",
diff --git a/libs/langchain/langchain/chat_models/databricks.py b/libs/langchain/langchain/chat_models/databricks.py
new file mode 100644
index 00000000000..ea473d3e659
--- /dev/null
+++ b/libs/langchain/langchain/chat_models/databricks.py
@@ -0,0 +1,46 @@
+import logging
+from urllib.parse import urlparse
+
+from langchain.chat_models.mlflow import ChatMlflow
+
+logger = logging.getLogger(__name__)
+
+
+class ChatDatabricks(ChatMlflow):
+    """`Databricks` chat models API.
+
+    To use, you should have the ``mlflow`` python package installed.
+    For more information, see https://mlflow.org/docs/latest/llms/deployments/databricks.html.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chat_models import ChatDatabricks
+
+            chat = ChatDatabricks(
+                target_uri="databricks",
+                endpoint="chat",
+                temperature=0.1,
+            )
+    """
+
+    target_uri: str = "databricks"
+    """The target URI to use. Defaults to ``databricks``."""
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of chat model."""
+        return "databricks-chat"
+
+    @property
+    def _mlflow_extras(self) -> str:
+        return ""
+
+    def _validate_uri(self) -> None:
+        if self.target_uri == "databricks":
+            return
+
+        if urlparse(self.target_uri).scheme != "databricks":
+            raise ValueError(
+                "Invalid target URI. The target URI must be a valid databricks URI."
+            )
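
Usage sketch for the new `ChatDatabricks` class (not part of the diff; assumes a Databricks workspace whose credentials are already configured for `mlflow.deployments`, and a chat-task serving endpoint named "chat" as in the docstring example):

    from langchain.chat_models import ChatDatabricks
    from langchain_core.messages import HumanMessage

    # Routed through mlflow.deployments' Databricks client;
    # requires `pip install mlflow` and workspace credentials.
    chat = ChatDatabricks(target_uri="databricks", endpoint="chat", temperature=0.1)
    print(chat([HumanMessage(content="What is MLflow?")]).content)
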
diff --git a/libs/langchain/langchain/chat_models/mlflow.py b/libs/langchain/langchain/chat_models/mlflow.py
new file mode 100644
index 00000000000..4b40286c7eb
--- /dev/null
+++ b/libs/langchain/langchain/chat_models/mlflow.py
@@ -0,0 +1,217 @@
+import asyncio
+import logging
+from functools import partial
+from typing import Any, Dict, List, Mapping, Optional
+from urllib.parse import urlparse
+
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    ChatMessage,
+    FunctionMessage,
+    HumanMessage,
+    SystemMessage,
+)
+from langchain_core.outputs import ChatGeneration, ChatResult
+from langchain_core.pydantic_v1 import (
+    Field,
+    PrivateAttr,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ChatMlflow(BaseChatModel):
+    """`MLflow` chat models API.
+
+    To use, you should have the `mlflow[genai]` python package installed.
+    For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chat_models import ChatMlflow
+
+            chat = ChatMlflow(
+                target_uri="http://localhost:5000",
+                endpoint="chat",
+                temperature=0.1,
+            )
+    """
+
+    endpoint: str
+    """The endpoint to use."""
+    target_uri: str
+    """The target URI to use."""
+    temperature: float = 0.0
+    """The sampling temperature."""
+    n: int = 1
+    """The number of completion choices to generate."""
+    stop: Optional[List[str]] = None
+    """The stop sequence."""
+    max_tokens: Optional[int] = None
+    """The maximum number of tokens to generate."""
+    extra_params: dict = Field(default_factory=dict)
+    """Any extra parameters to pass to the endpoint."""
+    _client: Any = PrivateAttr()
+
+    def __init__(self, **kwargs: Any):
+        super().__init__(**kwargs)
+        self._validate_uri()
+        try:
+            from mlflow.deployments import get_deploy_client
+
+            self._client = get_deploy_client(self.target_uri)
+        except ImportError as e:
+            raise ImportError(
+                "Failed to create the client. "
+                f"Please run `pip install mlflow{self._mlflow_extras}` to install "
+                "required dependencies."
+            ) from e
+
+    @property
+    def _mlflow_extras(self) -> str:
+        return "[genai]"
+
+    def _validate_uri(self) -> None:
+        if self.target_uri == "databricks":
+            return
+        allowed = ["http", "https", "databricks"]
+        if urlparse(self.target_uri).scheme not in allowed:
+            raise ValueError(
+                f"Invalid target URI: {self.target_uri}. "
+                f"The scheme must be one of {allowed}."
+            )
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        params: Dict[str, Any] = {
+            "target_uri": self.target_uri,
+            "endpoint": self.endpoint,
+            "temperature": self.temperature,
+            "n": self.n,
+            "stop": self.stop,
+            "max_tokens": self.max_tokens,
+            "extra_params": self.extra_params,
+        }
+        return params
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        message_dicts = [
+            ChatMlflow._convert_message_to_dict(message) for message in messages
+        ]
+        data: Dict[str, Any] = {
+            "messages": message_dicts,
+            "temperature": self.temperature,
+            "n": self.n,
+            "stop": stop or self.stop,
+            "max_tokens": self.max_tokens,
+            **self.extra_params,
+        }
+
+        resp = self._client.predict(endpoint=self.endpoint, inputs=data)
+        return ChatMlflow._create_chat_result(resp)
+
+    async def _agenerate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        func = partial(
+            self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
+        )
+        return await asyncio.get_event_loop().run_in_executor(None, func)
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        return self._default_params
+
+    def _get_invocation_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> Dict[str, Any]:
+        """Get the parameters used to invoke the model FOR THE CALLBACKS."""
+        return {
+            **self._default_params,
+            **super()._get_invocation_params(stop=stop, **kwargs),
+        }
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of chat model."""
+        return "mlflow-chat"
+
+    @staticmethod
+    def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
+        role = _dict["role"]
+        content = _dict["content"]
+        if role == "user":
+            return HumanMessage(content=content)
+        elif role == "assistant":
+            return AIMessage(content=content)
+        elif role == "system":
+            return SystemMessage(content=content)
+        else:
+            return ChatMessage(content=content, role=role)
+
+    @staticmethod
+    def _raise_functions_not_supported() -> None:
+        raise ValueError(
+            "Function messages are not supported by Databricks. Please"
+            " create a feature request at https://github.com/mlflow/mlflow/issues."
+        )
+
+    @staticmethod
+    def _convert_message_to_dict(message: BaseMessage) -> dict:
+        if isinstance(message, ChatMessage):
+            message_dict = {"role": message.role, "content": message.content}
+        elif isinstance(message, HumanMessage):
+            message_dict = {"role": "user", "content": message.content}
+        elif isinstance(message, AIMessage):
+            message_dict = {"role": "assistant", "content": message.content}
+        elif isinstance(message, SystemMessage):
+            message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, FunctionMessage):
+            raise ValueError(
+                "Function messages are not supported by Databricks. Please"
+                " create a feature request at https://github.com/mlflow/mlflow/issues."
+            )
+        else:
+            raise ValueError(f"Got unknown message type: {message}")
+
+        if "function_call" in message.additional_kwargs:
+            ChatMlflow._raise_functions_not_supported()
+        if message.additional_kwargs:
+            logger.warning(
+                "Additional message arguments are unsupported by Databricks"
+                " and will be ignored: %s",
+                message.additional_kwargs,
+            )
+        return message_dict
+
+    @staticmethod
+    def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
+        generations = []
+        for choice in response["choices"]:
+            message = ChatMlflow._convert_dict_to_message(choice["message"])
+            usage = choice.get("usage", {})
+            gen = ChatGeneration(
+                message=message,
+                generation_info=usage,
+            )
+            generations.append(gen)
+
+        usage = response.get("usage", {})
+        return ChatResult(generations=generations, llm_output=usage)
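
Usage sketch for `ChatMlflow` (not part of the diff; assumes `mlflow[genai]` is installed and an MLflow deployments server with a chat endpoint named "chat" is running at http://localhost:5000, per the docstring example):

    from langchain.chat_models import ChatMlflow
    from langchain_core.messages import HumanMessage, SystemMessage

    chat = ChatMlflow(target_uri="http://localhost:5000", endpoint="chat")
    # Messages are converted to {"role": ..., "content": ...} dicts by
    # _convert_message_to_dict before being sent to the endpoint.
    response = chat(
        [
            SystemMessage(content="You are a helpful assistant."),
            HumanMessage(content="Hello!"),
        ]
    )
    print(response.content)
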
diff --git a/libs/langchain/langchain/chat_models/mlflow_ai_gateway.py b/libs/langchain/langchain/chat_models/mlflow_ai_gateway.py
index 131eb8b49be..7704cefa8ee 100644
--- a/libs/langchain/langchain/chat_models/mlflow_ai_gateway.py
+++ b/libs/langchain/langchain/chat_models/mlflow_ai_gateway.py
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+import warnings
 from functools import partial
 from typing import Any, Dict, List, Mapping, Optional
 
@@ -59,6 +60,11 @@ class ChatMLflowAIGateway(BaseChatModel):
     """
 
     def __init__(self, **kwargs: Any):
+        warnings.warn(
+            "`ChatMLflowAIGateway` is deprecated. Use `ChatMlflow` or "
+            "`ChatDatabricks` instead.",
+            DeprecationWarning,
+        )
         try:
             import mlflow.gateway
         except ImportError as e:
diff --git a/libs/langchain/langchain/embeddings/__init__.py b/libs/langchain/langchain/embeddings/__init__.py
index dd573f8ebd1..8f288794256 100644
--- a/libs/langchain/langchain/embeddings/__init__.py
+++ b/libs/langchain/langchain/embeddings/__init__.py
@@ -26,6 +26,7 @@ from langchain.embeddings.cache import CacheBackedEmbeddings
 from langchain.embeddings.clarifai import ClarifaiEmbeddings
 from langchain.embeddings.cohere import CohereEmbeddings
 from langchain.embeddings.dashscope import DashScopeEmbeddings
+from langchain.embeddings.databricks import DatabricksEmbeddings
 from langchain.embeddings.deepinfra import DeepInfraEmbeddings
 from langchain.embeddings.edenai import EdenAiEmbeddings
 from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
@@ -50,6 +51,7 @@ from langchain.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings
 from langchain.embeddings.llamacpp import LlamaCppEmbeddings
 from langchain.embeddings.localai import LocalAIEmbeddings
 from langchain.embeddings.minimax import MiniMaxEmbeddings
+from langchain.embeddings.mlflow import MlflowEmbeddings
 from langchain.embeddings.mlflow_gateway import MlflowAIGatewayEmbeddings
 from langchain.embeddings.modelscope_hub import ModelScopeEmbeddings
 from langchain.embeddings.mosaicml import MosaicMLInstructorEmbeddings
@@ -78,6 +80,7 @@ __all__ = [
     "CacheBackedEmbeddings",
     "ClarifaiEmbeddings",
     "CohereEmbeddings",
+    "DatabricksEmbeddings",
     "ElasticsearchEmbeddings",
     "FastEmbedEmbeddings",
     "HuggingFaceEmbeddings",
@@ -87,6 +90,7 @@ __all__ = [
     "JinaEmbeddings",
     "LlamaCppEmbeddings",
     "HuggingFaceHubEmbeddings",
+    "MlflowEmbeddings",
     "MlflowAIGatewayEmbeddings",
     "ModelScopeEmbeddings",
     "TensorflowHubEmbeddings",
diff --git a/libs/langchain/langchain/embeddings/databricks.py b/libs/langchain/langchain/embeddings/databricks.py
new file mode 100644
index 00000000000..6a1e29a86cc
--- /dev/null
+++ b/libs/langchain/langchain/embeddings/databricks.py
@@ -0,0 +1,45 @@
+from __future__ import annotations
+
+from typing import Iterator, List
+from urllib.parse import urlparse
+
+from langchain.embeddings.mlflow import MlflowEmbeddings
+
+
+def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
+    for i in range(0, len(texts), size):
+        yield texts[i : i + size]
+
+
+class DatabricksEmbeddings(MlflowEmbeddings):
+    """Wrapper around embeddings LLMs in Databricks.
+
+    To use, you should have the ``mlflow`` python package installed.
+    For more information, see https://mlflow.org/docs/latest/llms/deployments/databricks.html.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.embeddings import DatabricksEmbeddings
+
+            embeddings = DatabricksEmbeddings(
+                target_uri="databricks",
+                endpoint="embeddings",
+            )
+    """
+
+    target_uri: str = "databricks"
+    """The target URI to use. Defaults to ``databricks``."""
+
+    @property
+    def _mlflow_extras(self) -> str:
+        return ""
+
+    def _validate_uri(self) -> None:
+        if self.target_uri == "databricks":
+            return
+
+        if urlparse(self.target_uri).scheme != "databricks":
+            raise ValueError(
+                "Invalid target URI. The target URI must be a valid databricks URI."
+            )
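
Usage sketch for `DatabricksEmbeddings` (not part of the diff; the endpoint name comes from the docstring example and is assumed to exist in the workspace):

    from langchain.embeddings import DatabricksEmbeddings

    embeddings = DatabricksEmbeddings(target_uri="databricks", endpoint="embeddings")
    vector = embeddings.embed_query("MLflow deployments")
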
diff --git a/libs/langchain/langchain/embeddings/mlflow.py b/libs/langchain/langchain/embeddings/mlflow.py
new file mode 100644
index 00000000000..42499f50a0f
--- /dev/null
+++ b/libs/langchain/langchain/embeddings/mlflow.py
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+from typing import Any, Iterator, List
+from urllib.parse import urlparse
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, PrivateAttr
+
+
+def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
+    for i in range(0, len(texts), size):
+        yield texts[i : i + size]
+
+
+class MlflowEmbeddings(Embeddings, BaseModel):
+    """Wrapper around embeddings LLMs in MLflow.
+
+    To use, you should have the `mlflow[genai]` python package installed.
+    For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.embeddings import MlflowEmbeddings
+
+            embeddings = MlflowEmbeddings(
+                target_uri="http://localhost:5000",
+                endpoint="embeddings",
+            )
+    """
+
+    endpoint: str
+    """The endpoint to use."""
+    target_uri: str
+    """The target URI to use."""
+    _client: Any = PrivateAttr()
+
+    def __init__(self, **kwargs: Any):
+        super().__init__(**kwargs)
+        self._validate_uri()
+        try:
+            from mlflow.deployments import get_deploy_client
+
+            self._client = get_deploy_client(self.target_uri)
+        except ImportError as e:
+            raise ImportError(
+                "Failed to create the client. "
+                f"Please run `pip install mlflow{self._mlflow_extras}` to install "
+                "required dependencies."
+            ) from e
+
+    @property
+    def _mlflow_extras(self) -> str:
+        return "[genai]"
+
+    def _validate_uri(self) -> None:
+        if self.target_uri == "databricks":
+            return
+        allowed = ["http", "https", "databricks"]
+        if urlparse(self.target_uri).scheme not in allowed:
+            raise ValueError(
+                f"Invalid target URI: {self.target_uri}. "
+                f"The scheme must be one of {allowed}."
+            )
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        embeddings: List[List[float]] = []
+        for txt in _chunk(texts, 20):
+            resp = self._client.predict(endpoint=self.endpoint, inputs={"input": txt})
+            embeddings.extend(r["embedding"] for r in resp["data"])
+        return embeddings
+
+    def embed_query(self, text: str) -> List[float]:
+        return self.embed_documents([text])[0]
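
Note that `embed_documents` batches inputs 20 texts at a time through `_chunk` before calling the endpoint. A sketch (not part of the diff; assumes a local deployments server with an "embeddings" endpoint):

    from langchain.embeddings import MlflowEmbeddings

    embeddings = MlflowEmbeddings(
        target_uri="http://localhost:5000", endpoint="embeddings"
    )
    # 45 texts -> three predict() calls with 20, 20, and 5 inputs each.
    vectors = embeddings.embed_documents([f"document {i}" for i in range(45)])
    query_vector = embeddings.embed_query("a single text")
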
diff --git a/libs/langchain/langchain/embeddings/mlflow_gateway.py b/libs/langchain/langchain/embeddings/mlflow_gateway.py
index 56f5eee1bb4..8bbbc2e3842 100644
--- a/libs/langchain/langchain/embeddings/mlflow_gateway.py
+++ b/libs/langchain/langchain/embeddings/mlflow_gateway.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import warnings
 from typing import Any, Iterator, List, Optional
 
 from langchain_core.embeddings import Embeddings
@@ -35,6 +36,11 @@ class MlflowAIGatewayEmbeddings(Embeddings, BaseModel):
     """The URI for the MLflow AI Gateway API."""
 
     def __init__(self, **kwargs: Any):
+        warnings.warn(
+            "`MlflowAIGatewayEmbeddings` is deprecated. Use `MlflowEmbeddings` or "
+            "`DatabricksEmbeddings` instead.",
+            DeprecationWarning,
+        )
         try:
             import mlflow.gateway
         except ImportError as e:
diff --git a/libs/langchain/langchain/llms/__init__.py b/libs/langchain/langchain/llms/__init__.py
index 407849ee624..81b1764f13d 100644
--- a/libs/langchain/langchain/llms/__init__.py
+++ b/libs/langchain/langchain/llms/__init__.py
@@ -148,6 +148,12 @@ def _import_databricks() -> Any:
     return Databricks
 
 
+def _import_databricks_chat() -> Any:
+    from langchain.chat_models.databricks import ChatDatabricks
+
+    return ChatDatabricks
+
+
 def _import_deepinfra() -> Any:
     from langchain.llms.deepinfra import DeepInfra
 
@@ -276,6 +282,18 @@ def _import_minimax() -> Any:
     return Minimax
 
 
+def _import_mlflow() -> Any:
+    from langchain.llms.mlflow import Mlflow
+
+    return Mlflow
+
+
+def _import_mlflow_chat() -> Any:
+    from langchain.chat_models.mlflow import ChatMlflow
+
+    return ChatMlflow
+
+
 def _import_mlflow_ai_gateway() -> Any:
     from langchain.llms.mlflow_ai_gateway import MlflowAIGateway
 
@@ -595,6 +613,8 @@ def __getattr__(name: str) -> Any:
         return _import_manifest()
     elif name == "Minimax":
        return _import_minimax()
+    elif name == "Mlflow":
+        return _import_mlflow()
     elif name == "MlflowAIGateway":
         return _import_mlflow_ai_gateway()
     elif name == "Modal":
@@ -789,6 +809,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
         "ctransformers": _import_ctransformers,
         "ctranslate2": _import_ctranslate2,
         "databricks": _import_databricks,
+        "databricks-chat": _import_databricks_chat,
         "deepinfra": _import_deepinfra,
         "deepsparse": _import_deepsparse,
         "edenai": _import_edenai,
@@ -808,6 +829,8 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
         "llamacpp": _import_llamacpp,
         "textgen": _import_textgen,
         "minimax": _import_minimax,
+        "mlflow": _import_mlflow,
+        "mlflow-chat": _import_mlflow_chat,
         "mlflow-ai-gateway": _import_mlflow_ai_gateway,
         "modal": _import_modal,
         "mosaic": _import_mosaicml,
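
The new registry keys can be exercised via `get_type_to_cls_dict` (sketch, not part of the diff):

    from langchain.llms import get_type_to_cls_dict

    # "mlflow", "mlflow-chat", and "databricks-chat" are the keys added above;
    # each maps to a lazy import of the corresponding class.
    llm_cls = get_type_to_cls_dict()["mlflow"]()
    print(llm_cls.__name__)  # -> "Mlflow"
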
diff --git a/libs/langchain/langchain/llms/databricks.py b/libs/langchain/langchain/llms/databricks.py
index 94eaedceec8..55dff843917 100644
--- a/libs/langchain/langchain/llms/databricks.py
+++ b/libs/langchain/langchain/llms/databricks.py
@@ -1,8 +1,11 @@
 import os
+import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Mapping, Optional
 
 import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import LLM
 from langchain_core.pydantic_v1 import (
     BaseModel,
     Extra,
@@ -12,9 +15,6 @@ from langchain_core.pydantic_v1 import (
     validator,
 )
 
-from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
-
 __all__ = ["Databricks"]
 
 
@@ -24,24 +24,66 @@ class _DatabricksClientBase(BaseModel, ABC):
     api_url: str
     api_token: str
 
-    def post_raw(self, request: Any) -> Any:
+    def request(self, method: str, url: str, request: Any) -> Any:
         headers = {"Authorization": f"Bearer {self.api_token}"}
-        response = requests.post(self.api_url, headers=headers, json=request)
+        response = requests.request(
+            method=method, url=url, headers=headers, json=request
+        )
         # TODO: error handling and automatic retries
         if not response.ok:
             raise ValueError(f"HTTP {response.status_code} error: {response.text}")
         return response.json()
 
+    def _get(self, url: str) -> Any:
+        return self.request("GET", url, None)
+
+    def _post(self, url: str, request: Any) -> Any:
+        return self.request("POST", url, request)
+
     @abstractmethod
-    def post(self, request: Any) -> Any:
+    def post(
+        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
+    ) -> Any:
         ...
 
 
+def _transform_completions(response: Dict[str, Any]) -> str:
+    return response["choices"][0]["text"]
+
+
+def _transform_chat(response: Dict[str, Any]) -> str:
+    return response["choices"][0]["message"]["content"]
+
+
 class _DatabricksServingEndpointClient(_DatabricksClientBase):
     """An API client that talks to a Databricks serving endpoint."""
 
     host: str
     endpoint_name: str
+    databricks_uri: str
+    client: Any = None
+    external_or_foundation: bool = False
+    task: Optional[str] = None
+
+    def __init__(self, **data: Any):
+        super().__init__(**data)
+
+        try:
+            from mlflow.deployments import get_deploy_client
+
+            self.client = get_deploy_client(self.databricks_uri)
+        except ImportError as e:
+            raise ImportError(
+                "Failed to create the client. "
+                "Please install mlflow with `pip install mlflow`."
+            ) from e
+
+        endpoint = self.client.get_endpoint(self.endpoint_name)
+        self.external_or_foundation = endpoint.get("endpoint_type", "").lower() in (
+            "external_model",
+            "foundation_model_api",
+        )
+        self.task = endpoint.get("task")
 
     @root_validator(pre=True)
     def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@@ -52,14 +94,30 @@ class _DatabricksServingEndpointClient(_DatabricksClientBase):
         values["api_url"] = api_url
         return values
 
-    def post(self, request: Any) -> Any:
-        # See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
-        wrapped_request = {"dataframe_records": [request]}
-        response = self.post_raw(wrapped_request)["predictions"]
-        # For a single-record query, the result is not a list.
-        if isinstance(response, list):
-            response = response[0]
-        return response
+    def post(
+        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
+    ) -> Any:
+        if self.external_or_foundation:
+            resp = self.client.predict(endpoint=self.endpoint_name, inputs=request)
+            if transform_output_fn:
+                return transform_output_fn(resp)
+
+            if self.task == "llm/v1/chat":
+                return _transform_chat(resp)
+            elif self.task == "llm/v1/completions":
+                return _transform_completions(resp)
+
+            return resp
+        else:
+            # See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
+            wrapped_request = {"dataframe_records": [request]}
+            response = self.client.predict(
+                endpoint=self.endpoint_name, inputs=wrapped_request
+            )
+            preds = response["predictions"]
+            # For a single-record query, the result is not a list.
+            pred = preds[0] if isinstance(preds, list) else preds
+            return transform_output_fn(pred) if transform_output_fn else pred
 
 
 class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
@@ -79,8 +137,8 @@ class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
         values["api_url"] = api_url
         return values
 
-    def post(self, request: Any) -> Any:
-        return self.post_raw(request)
+    def post(self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None) -> Any:
+        return self._post(self.api_url, request)
 
 
 def get_repl_context() -> Any:
@@ -137,16 +195,19 @@ def get_default_api_token() -> str:
 
 
 class Databricks(LLM):
+
     """Databricks serving endpoint or a cluster driver proxy app for LLM.
 
     It supports two endpoint types:
 
     * **Serving endpoint** (recommended for both production and development).
-      We assume that an LLM was registered and deployed to a serving endpoint.
+      We assume that an LLM was deployed to a serving endpoint.
       To wrap it as an LLM you must have "Can Query" permission to the endpoint.
       Set ``endpoint_name`` accordingly and do not set ``cluster_id`` and
       ``cluster_driver_port``.
-      The expected model signature is:
+
+      If the underlying model is a model registered by MLflow, the expected model
+      signature is:
 
       * inputs::
 
@@ -155,6 +216,10 @@ class Databricks(LLM):
 
     * outputs: ``[{"type": "string"}]``
 
+    If the underlying model is an external or foundation model, the response from the
+    endpoint is automatically transformed to the expected format unless
+    ``transform_output_fn`` is provided.
+
    * **Cluster driver proxy app** (recommended for interactive development).
       One can load an LLM on a Databricks interactive cluster and start a local HTTP
       server on the driver node to serve the model at ``/`` using HTTP POST method
@@ -220,9 +285,14 @@ class Databricks(LLM):
       We recommend the server using a port number between ``[3000, 8000]``.
     """
 
-    model_kwargs: Optional[Dict[str, Any]] = None
+    params: Optional[Dict[str, Any]] = None
     """Extra parameters to pass to the endpoint."""
 
+    model_kwargs: Optional[Dict[str, Any]] = None
+    """
+    Deprecated. Please use ``params`` instead. Extra parameters to pass to the endpoint.
+    """
+
     transform_input_fn: Optional[Callable] = None
     """A function that transforms ``{prompt, stop, **kwargs}`` into a JSON-compatible
     request object that the endpoint accepts.
@@ -233,6 +303,9 @@ class Databricks(LLM):
     """A function that transforms the output from the endpoint to the generated text.
     """
 
+    databricks_uri: str = "databricks"
+    """The databricks URI. Only used when using a serving endpoint."""
+
     _client: _DatabricksClientBase = PrivateAttr()
 
     class Config:
@@ -283,11 +356,19 @@ class Databricks(LLM):
 
     def __init__(self, **data: Any):
         super().__init__(**data)
+        if self.model_kwargs is not None and self.params is not None:
+            raise ValueError("Cannot set both model_kwargs and params.")
+        elif self.model_kwargs is not None:
+            warnings.warn(
+                "model_kwargs is deprecated. Please use params instead.",
+                DeprecationWarning,
+            )
         if self.endpoint_name:
             self._client = _DatabricksServingEndpointClient(
                 host=self.host,
                 api_token=self.api_token,
                 endpoint_name=self.endpoint_name,
+                databricks_uri=self.databricks_uri,
             )
         elif self.cluster_id and self.cluster_driver_port:
             self._client = _DatabricksClusterDriverProxyClient(
@@ -301,6 +382,31 @@ class Databricks(LLM):
                 "Must specify either endpoint_name or cluster_id/cluster_driver_port."
             )
 
+    @property
+    def _params(self) -> Optional[Dict[str, Any]]:
+        return self.model_kwargs or self.params
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Return default params."""
+        return {
+            "host": self.host,
+            # "api_token": self.api_token,  # Never save the token
+            "endpoint_name": self.endpoint_name,
+            "cluster_id": self.cluster_id,
+            "cluster_driver_port": self.cluster_driver_port,
+            "databricks_uri": self.databricks_uri,
+            "model_kwargs": self.model_kwargs,
+            "params": self.params,
+            # TODO: Support saving transform_input_fn and transform_output_fn
+            # "transform_input_fn": self.transform_input_fn,
+            # "transform_output_fn": self.transform_output_fn,
+        }
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return self._default_params
+
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
@@ -319,8 +425,8 @@ class Databricks(LLM):
 
         request = {"prompt": prompt, "stop": stop}
         request.update(kwargs)
-        if self.model_kwargs:
-            request.update(self.model_kwargs)
+        if self._params:
+            request.update(self._params)
 
         if self.transform_input_fn:
             request = self.transform_input_fn(**request)
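
With this change `params` supersedes the deprecated `model_kwargs` (setting both raises `ValueError`). A sketch of the new surface (not part of the diff; assumes it runs in a Databricks notebook, where `host` and `api_token` are auto-detected, and that a serving endpoint named "my-endpoint" exists; the name is hypothetical):

    from langchain.llms import Databricks

    # `params` is merged into every request body. For external/foundation-model
    # endpoints the response is also unwrapped automatically (for example,
    # choices[0]["text"] for llm/v1/completions) unless transform_output_fn
    # is provided.
    llm = Databricks(endpoint_name="my-endpoint", params={"temperature": 0.1})
    print(llm("How are you?"))
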
diff --git a/libs/langchain/langchain/llms/mlflow.py b/libs/langchain/langchain/llms/mlflow.py
new file mode 100644
index 00000000000..7f77fe12eff
--- /dev/null
+++ b/libs/langchain/langchain/llms/mlflow.py
@@ -0,0 +1,104 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Mapping, Optional
+from urllib.parse import urlparse
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import LLM
+from langchain_core.pydantic_v1 import BaseModel, Extra, PrivateAttr
+
+
+# Ignoring type because below is valid pydantic code
+# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
+class Params(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
+    """Parameters for MLflow"""
+
+    temperature: float = 0.0
+    n: int = 1
+    stop: Optional[List[str]] = None
+    max_tokens: Optional[int] = None
+
+
+class Mlflow(LLM):
+    """Wrapper around completions LLMs in MLflow.
+
+    To use, you should have the `mlflow[genai]` python package installed.
+    For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import Mlflow
+
+            completions = Mlflow(
+                target_uri="http://localhost:5000",
+                endpoint="test",
+                params={"temperature": 0.1}
+            )
+    """
+
+    endpoint: str
+    """The endpoint to use."""
+    target_uri: str
+    """The target URI to use."""
+    params: Optional[Params] = None
+    """Extra parameters such as `temperature`."""
+    _client: Any = PrivateAttr()
+
+    def __init__(self, **kwargs: Any):
+        super().__init__(**kwargs)
+        self._validate_uri()
+        try:
+            from mlflow.deployments import get_deploy_client
+
+            self._client = get_deploy_client(self.target_uri)
+        except ImportError as e:
+            raise ImportError(
+                "Failed to create the client. "
+                "Please run `pip install mlflow[genai]` to install "
+                "required dependencies."
+            ) from e
+
+    def _validate_uri(self) -> None:
+        if self.target_uri == "databricks":
+            return
+        allowed = ["http", "https", "databricks"]
+        if urlparse(self.target_uri).scheme not in allowed:
+            raise ValueError(
+                f"Invalid target URI: {self.target_uri}. "
+                f"The scheme must be one of {allowed}."
+            )
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        params: Dict[str, Any] = {
+            "target_uri": self.target_uri,
+            "endpoint": self.endpoint,
+        }
+        if self.params:
+            params["params"] = self.params.dict()
+        return params
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return self._default_params
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        data: Dict[str, Any] = {
+            "prompt": prompt,
+            **(self.params.dict() if self.params else {}),
+        }
+        if s := (stop or (self.params.stop if self.params else None)):
+            data["stop"] = s
+        resp = self._client.predict(endpoint=self.endpoint, inputs=data)
+        return resp["choices"][0]["text"]
+
+    @property
+    def _llm_type(self) -> str:
+        return "mlflow"
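
Usage sketch for the new `Mlflow` completions wrapper (not part of the diff; assumes a local deployments server with a completions endpoint named "completions", where the docstring uses "test"):

    from langchain.llms import Mlflow

    llm = Mlflow(
        target_uri="http://localhost:5000",
        endpoint="completions",
        params={"temperature": 0.1, "max_tokens": 64},
    )
    # A `stop` list passed at call time overrides params.stop inside _call.
    print(llm("The capital of France is"))
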
diff --git a/libs/langchain/langchain/llms/mlflow_ai_gateway.py b/libs/langchain/langchain/llms/mlflow_ai_gateway.py
index 62b025d2155..c98529dd95f 100644
--- a/libs/langchain/langchain/llms/mlflow_ai_gateway.py
+++ b/libs/langchain/langchain/llms/mlflow_ai_gateway.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import warnings
 from typing import Any, Dict, List, Mapping, Optional
 
 from langchain_core.pydantic_v1 import BaseModel, Extra
@@ -46,6 +47,10 @@ class MlflowAIGateway(LLM):
     params: Optional[Params] = None
 
     def __init__(self, **kwargs: Any):
+        warnings.warn(
+            "`MlflowAIGateway` is deprecated. Use `Mlflow` or `Databricks` instead.",
+            DeprecationWarning,
+        )
         try:
             import mlflow.gateway
         except ImportError as e:
diff --git a/libs/langchain/tests/unit_tests/chat_models/test_imports.py b/libs/langchain/tests/unit_tests/chat_models/test_imports.py
index 2adb8ec65a5..7afded1805f 100644
--- a/libs/langchain/tests/unit_tests/chat_models/test_imports.py
+++ b/libs/langchain/tests/unit_tests/chat_models/test_imports.py
@@ -9,7 +9,9 @@ EXPECTED_ALL = [
     "ChatEverlyAI",
     "ChatAnthropic",
     "ChatCohere",
+    "ChatDatabricks",
     "ChatGooglePalm",
+    "ChatMlflow",
     "ChatMLflowAIGateway",
     "ChatOllama",
     "ChatVertexAI",
diff --git a/libs/langchain/tests/unit_tests/embeddings/test_imports.py b/libs/langchain/tests/unit_tests/embeddings/test_imports.py
index 7a72b32ec43..9de69602dc6 100644
--- a/libs/langchain/tests/unit_tests/embeddings/test_imports.py
+++ b/libs/langchain/tests/unit_tests/embeddings/test_imports.py
@@ -6,6 +6,7 @@ EXPECTED_ALL = [
     "CacheBackedEmbeddings",
     "ClarifaiEmbeddings",
     "CohereEmbeddings",
+    "DatabricksEmbeddings",
     "ElasticsearchEmbeddings",
     "FastEmbedEmbeddings",
     "HuggingFaceEmbeddings",
@@ -16,6 +17,7 @@ EXPECTED_ALL = [
     "LlamaCppEmbeddings",
     "HuggingFaceHubEmbeddings",
     "MlflowAIGatewayEmbeddings",
+    "MlflowEmbeddings",
     "ModelScopeEmbeddings",
     "TensorflowHubEmbeddings",
     "SagemakerEndpointEmbeddings",