mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-01 04:29:09 +00:00
langchain[minor]: Migrate mlflow and databricks classes to deployments APIs. (#13699)
## Description

Related to https://github.com/mlflow/mlflow/pull/10420. MLflow AI gateway will be deprecated and replaced by the `mlflow.deployments` module. Happy to split this PR if it's too large.

```
pip install git+https://github.com/langchain-ai/langchain.git@refs/pull/13699/merge#subdirectory=libs/langchain
```

## Dependencies

Install mlflow from https://github.com/mlflow/mlflow/pull/10420:

```
pip install git+https://github.com/mlflow/mlflow.git@refs/pull/10420/merge
```

## Testing plan

The following code works against both a local deployments server and Databricks:

<details><summary>Click</summary>
<p>

```python
"""
Setup
-----
mlflow deployments start-server --config-path examples/gateway/openai/config.yaml
databricks secrets create-scope <scope>
databricks secrets put-secret <scope> openai-api-key --string-value $OPENAI_API_KEY

Run
---
python /path/to/this/file.py secrets/<scope>/openai-api-key
"""

from langchain.chat_models import ChatMlflow, ChatDatabricks
from langchain.embeddings import MlflowEmbeddings, DatabricksEmbeddings
from langchain.llms import Databricks, Mlflow
from langchain.schema.messages import HumanMessage
from langchain.chains.loading import load_chain
from mlflow.deployments import get_deploy_client
import uuid
import sys
import tempfile
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

###############################
# MLflow
###############################
chat = ChatMlflow(
    target_uri="http://127.0.0.1:5000", endpoint="chat", params={"temperature": 0.1}
)
print(chat([HumanMessage(content="hello")]))

embeddings = MlflowEmbeddings(target_uri="http://127.0.0.1:5000", endpoint="embeddings")
print(embeddings.embed_query("hello")[:3])
print(embeddings.embed_documents(["hello", "world"])[0][:3])

llm = Mlflow(
    target_uri="http://127.0.0.1:5000",
    endpoint="completions",
    params={"temperature": 0.1},
)
print(llm("I am"))

llm_chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate(
        input_variables=["adjective"],
        template="Tell me a {adjective} joke",
    ),
)
print(llm_chain.run(adjective="funny"))

# serialization/deserialization
with tempfile.TemporaryDirectory() as tmpdir:
    print(tmpdir)
    path = f"{tmpdir}/llm.yaml"
    llm_chain.save(path)
    loaded_chain = load_chain(path)
    print(loaded_chain("funny"))

###############################
# Databricks
###############################
secret = sys.argv[1]
client = get_deploy_client("databricks")

# External - chat
name = f"chat-{uuid.uuid4()}"
client.create_endpoint(
    name=name,
    config={
        "served_entities": [
            {
                "name": "test",
                "external_model": {
                    "name": "gpt-4",
                    "provider": "openai",
                    "task": "llm/v1/chat",
                    "openai_config": {
                        "openai_api_key": "{{" + secret + "}}",
                    },
                },
            }
        ],
    },
)
try:
    chat = ChatDatabricks(
        target_uri="databricks", endpoint=name, params={"temperature": 0.1}
    )
    print(chat([HumanMessage(content="hello")]))
finally:
    client.delete_endpoint(endpoint=name)

# External - embeddings
name = f"embeddings-{uuid.uuid4()}"
client.create_endpoint(
    name=name,
    config={
        "served_entities": [
            {
                "name": "test",
                "external_model": {
                    "name": "text-embedding-ada-002",
                    "provider": "openai",
                    "task": "llm/v1/embeddings",
                    "openai_config": {
                        "openai_api_key": "{{" + secret + "}}",
                    },
                },
            }
        ],
    },
)
try:
    embeddings = DatabricksEmbeddings(target_uri="databricks", endpoint=name)
    print(embeddings.embed_query("hello")[:3])
    print(embeddings.embed_documents(["hello", "world"])[0][:3])
finally:
    client.delete_endpoint(endpoint=name)

# External - completions
name = f"completions-{uuid.uuid4()}"
client.create_endpoint(
    name=name,
    config={
        "served_entities": [
            {
                "name": "test",
                "external_model": {
                    "name": "gpt-3.5-turbo-instruct",
                    "provider": "openai",
                    "task": "llm/v1/completions",
                    "openai_config": {
                        "openai_api_key": "{{" + secret + "}}",
                    },
                },
            }
        ],
    },
)
try:
    llm = Databricks(
        endpoint_name=name,
        model_kwargs={"temperature": 0.1},
    )
    print(llm("I am"))
finally:
    client.delete_endpoint(endpoint=name)

# Foundation model - chat
chat = ChatDatabricks(
    endpoint="databricks-llama-2-70b-chat", params={"temperature": 0.1}
)
print(chat([HumanMessage(content="hello")]))

# Foundation model - embeddings
embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
print(embeddings.embed_query("hello")[:3])

# Foundation model - completions
llm = Databricks(
    endpoint_name="databricks-mpt-7b-instruct", model_kwargs={"temperature": 0.1}
)
print(llm("hello"))

llm_chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate(
        input_variables=["adjective"],
        template="Tell me a {adjective} joke",
    ),
)
print(llm_chain.run(adjective="funny"))

# serialization/deserialization
with tempfile.TemporaryDirectory() as tmpdir:
    print(tmpdir)
    path = f"{tmpdir}/llm.yaml"
    llm_chain.save(path)
    loaded_chain = load_chain(path)
    print(loaded_chain("funny"))
```

Output:

```
content='Hello! How can I assist you today?'
[-0.025058426, -0.01938856, -0.027781019]
[-0.025058426, -0.01938856, -0.027781019]
sorry, but I cannot continue the sentence as it is incomplete. Can you please provide more information or context?
Sure, here's a classic one for you:

Why don't scientists trust atoms?

Because they make up everything!
/var/folders/dz/cd_nvlf14g9g__n3ph0d_0pm0000gp/T/tmpx_4no6ad
{'adjective': 'funny', 'text': "Sure, here's a classic one for you:\n\nWhy don't scientists trust atoms?\n\nBecause they make up everything!"}
content='Hello! How can I assist you today?'
[-0.025058426, -0.01938856, -0.027781019]
[-0.025058426, -0.01938856, -0.027781019]
a 23 year old female and I am currently studying for my master's degree
content="\nHello! It's nice to meet you. Is there something I can help you with or would you like to chat for a bit?"
[0.051055908203125, 0.007221221923828125, 0.003879547119140625]
[0.051055908203125, 0.007221221923828125, 0.003879547119140625]
hello back
 Well, I don't really know many jokes, but I do know this funny story...
/var/folders/dz/cd_nvlf14g9g__n3ph0d_0pm0000gp/T/tmp7_ds72ex
{'adjective': 'funny', 'text': " Well, I don't really know many jokes, but I do know this funny story..."}
```

</p>
</details>

The existing workflow doesn't break:

<details><summary>Click</summary>
<p>

```python
import uuid

import mlflow
from mlflow.models import ModelSignature
from mlflow.types.schema import ColSpec, Schema


class MyModel(mlflow.pyfunc.PythonModel):
    def predict(self, context, model_input):
        return str(uuid.uuid4())


with mlflow.start_run():
    mlflow.pyfunc.log_model(
        "model",
        python_model=MyModel(),
        pip_requirements=["mlflow==2.8.1", "cloudpickle<3"],
        signature=ModelSignature(
            inputs=Schema(
                [
                    ColSpec("string", "prompt"),
                    ColSpec("string", "stop"),
                ]
            ),
            outputs=Schema(
                [
                    ColSpec(name=None, type="string"),
                ]
            ),
        ),
        registered_model_name=f"lang-{uuid.uuid4()}",
    )

# Manually create a serving endpoint with the registered model and run
from langchain.llms import Databricks

llm = Databricks(endpoint_name="<name>")
llm("hello")  # 9d0b2491-3d13-487c-bc02-1287f06ecae7
```

</p>
</details>

## Follow-up tasks

(This PR is too large. I'll file a separate one for follow-up tasks.)

- Update `docs/docs/integrations/providers/mlflow_ai_gateway.mdx` and `docs/docs/integrations/providers/databricks.md`.

---------

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent dc31714ec5
commit 0d08a692a3
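All of the new classes in the diff below (chat models, embeddings, and completion LLMs) delegate to the same `mlflow.deployments` client. For context, here is a minimal sketch of that client on its own, assuming a local deployments server is running at http://127.0.0.1:5000 with a `chat` endpoint as in the testing plan above; the endpoint name and URL are illustrative.

```python
# Minimal sketch of the mlflow.deployments API that the new wrappers sit on.
from mlflow.deployments import get_deploy_client

client = get_deploy_client("http://127.0.0.1:5000")  # or "databricks"
resp = client.predict(
    endpoint="chat",
    inputs={"messages": [{"role": "user", "content": "hello"}], "temperature": 0.1},
)
print(resp["choices"][0]["message"]["content"])
```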
@@ -24,6 +24,7 @@ from langchain.chat_models.baichuan import ChatBaichuan
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain.chat_models.bedrock import BedrockChat
from langchain.chat_models.cohere import ChatCohere
from langchain.chat_models.databricks import ChatDatabricks
from langchain.chat_models.ernie import ErnieBotChat
from langchain.chat_models.everlyai import ChatEverlyAI
from langchain.chat_models.fake import FakeListChatModel
@@ -37,6 +38,7 @@ from langchain.chat_models.jinachat import JinaChat
from langchain.chat_models.konko import ChatKonko
from langchain.chat_models.litellm import ChatLiteLLM
from langchain.chat_models.minimax import MiniMaxChat
from langchain.chat_models.mlflow import ChatMlflow
from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
from langchain.chat_models.ollama import ChatOllama
from langchain.chat_models.openai import ChatOpenAI
@@ -52,10 +54,12 @@ __all__ = [
    "AzureChatOpenAI",
    "FakeListChatModel",
    "PromptLayerChatOpenAI",
    "ChatDatabricks",
    "ChatEverlyAI",
    "ChatAnthropic",
    "ChatCohere",
    "ChatGooglePalm",
    "ChatMlflow",
    "ChatMLflowAIGateway",
    "ChatOllama",
    "ChatVertexAI",
46 libs/langchain/langchain/chat_models/databricks.py Normal file
@@ -0,0 +1,46 @@
import logging
from urllib.parse import urlparse

from langchain.chat_models.mlflow import ChatMlflow

logger = logging.getLogger(__name__)


class ChatDatabricks(ChatMlflow):
    """`Databricks` chat models API.

    To use, you should have the ``mlflow`` python package installed.
    For more information, see https://mlflow.org/docs/latest/llms/deployments/databricks.html.

    Example:
        .. code-block:: python

            from langchain.chat_models import ChatDatabricks

            chat = ChatDatabricks(
                target_uri="databricks",
                endpoint="chat",
                temperature=0.1,
            )
    """

    target_uri: str = "databricks"
    """The target URI to use. Defaults to ``databricks``."""

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "databricks-chat"

    @property
    def _mlflow_extras(self) -> str:
        return ""

    def _validate_uri(self) -> None:
        if self.target_uri == "databricks":
            return

        if urlparse(self.target_uri).scheme != "databricks":
            raise ValueError(
                "Invalid target URI. The target URI must be a valid databricks URI."
            )
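A usage sketch for the new `ChatDatabricks` class above, based on the foundation-model example in the testing plan; the endpoint name is illustrative and assumes the workspace exposes it.

```python
from langchain.chat_models import ChatDatabricks
from langchain.schema.messages import HumanMessage

# target_uri defaults to "databricks"; endpoint name is an example from the PR.
chat = ChatDatabricks(endpoint="databricks-llama-2-70b-chat", temperature=0.1)
print(chat([HumanMessage(content="hello")]))
```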
217 libs/langchain/langchain/chat_models/mlflow.py Normal file
@@ -0,0 +1,217 @@
import asyncio
import logging
from functools import partial
from typing import Any, Dict, List, Mapping, Optional
from urllib.parse import urlparse

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import (
    Field,
    PrivateAttr,
)

logger = logging.getLogger(__name__)


class ChatMlflow(BaseChatModel):
    """`MLflow` chat models API.

    To use, you should have the `mlflow[genai]` python package installed.
    For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.

    Example:
        .. code-block:: python

            from langchain.chat_models import ChatMlflow

            chat = ChatMlflow(
                target_uri="http://localhost:5000",
                endpoint="chat",
                temperature=0.1,
            )
    """

    endpoint: str
    """The endpoint to use."""
    target_uri: str
    """The target URI to use."""
    temperature: float = 0.0
    """The sampling temperature."""
    n: int = 1
    """The number of completion choices to generate."""
    stop: Optional[List[str]] = None
    """The stop sequence."""
    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""
    extra_params: dict = Field(default_factory=dict)
    """Any extra parameters to pass to the endpoint."""
    _client: Any = PrivateAttr()

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self._validate_uri()
        try:
            from mlflow.deployments import get_deploy_client

            self._client = get_deploy_client(self.target_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                f"Please run `pip install mlflow{self._mlflow_extras}` to install "
                "required dependencies."
            ) from e

    @property
    def _mlflow_extras(self) -> str:
        return "[genai]"

    def _validate_uri(self) -> None:
        if self.target_uri == "databricks":
            return
        allowed = ["http", "https", "databricks"]
        if urlparse(self.target_uri).scheme not in allowed:
            raise ValueError(
                f"Invalid target URI: {self.target_uri}. "
                f"The scheme must be one of {allowed}."
            )

    @property
    def _default_params(self) -> Dict[str, Any]:
        params: Dict[str, Any] = {
            "target_uri": self.target_uri,
            "endpoint": self.endpoint,
            "temperature": self.temperature,
            "n": self.n,
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "extra_params": self.extra_params,
        }
        return params

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message_dicts = [
            ChatMlflow._convert_message_to_dict(message) for message in messages
        ]
        data: Dict[str, Any] = {
            "messages": message_dicts,
            "temperature": self.temperature,
            "n": self.n,
            "stop": stop or self.stop,
            "max_tokens": self.max_tokens,
            **self.extra_params,
        }

        resp = self._client.predict(endpoint=self.endpoint, inputs=data)
        return ChatMlflow._create_chat_result(resp)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        func = partial(
            self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
        )
        return await asyncio.get_event_loop().run_in_executor(None, func)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return self._default_params

    def _get_invocation_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Get the parameters used to invoke the model FOR THE CALLBACKS."""
        return {
            **self._default_params,
            **super()._get_invocation_params(stop=stop, **kwargs),
        }

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "mlflow-chat"

    @staticmethod
    def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
        role = _dict["role"]
        content = _dict["content"]
        if role == "user":
            return HumanMessage(content=content)
        elif role == "assistant":
            return AIMessage(content=content)
        elif role == "system":
            return SystemMessage(content=content)
        else:
            return ChatMessage(content=content, role=role)

    @staticmethod
    def _raise_functions_not_supported() -> None:
        raise ValueError(
            "Function messages are not supported by Databricks. Please"
            " create a feature request at https://github.com/mlflow/mlflow/issues."
        )

    @staticmethod
    def _convert_message_to_dict(message: BaseMessage) -> dict:
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, FunctionMessage):
            raise ValueError(
                "Function messages are not supported by Databricks. Please"
                " create a feature request at https://github.com/mlflow/mlflow/issues."
            )
        else:
            raise ValueError(f"Got unknown message type: {message}")

        if "function_call" in message.additional_kwargs:
            ChatMlflow._raise_functions_not_supported()
        if message.additional_kwargs:
            logger.warning(
                "Additional message arguments are unsupported by Databricks"
                " and will be ignored: %s",
                message.additional_kwargs,
            )
        return message_dict

    @staticmethod
    def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
        generations = []
        for choice in response["choices"]:
            message = ChatMlflow._convert_dict_to_message(choice["message"])
            usage = choice.get("usage", {})
            gen = ChatGeneration(
                message=message,
                generation_info=usage,
            )
            generations.append(gen)

        usage = response.get("usage", {})
        return ChatResult(generations=generations, llm_output=usage)
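To make the conversion logic above concrete, here is a sketch of the OpenAI-style payload that `ChatMlflow._create_chat_result` consumes; the field values are illustrative, not taken from a real server response.

```python
from langchain.chat_models import ChatMlflow

# Illustrative payload shaped like the deployments server's chat response.
resp = {
    "choices": [
        {"message": {"role": "assistant", "content": "Hello!"}, "usage": {}},
    ],
    "usage": {"prompt_tokens": 1, "completion_tokens": 3, "total_tokens": 4},
}
result = ChatMlflow._create_chat_result(resp)
print(result.generations[0].message.content)  # "Hello!"
print(result.llm_output)                      # the top-level usage dict
```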
@@ -1,5 +1,6 @@
import asyncio
import logging
import warnings
from functools import partial
from typing import Any, Dict, List, Mapping, Optional

@@ -59,6 +60,11 @@ class ChatMLflowAIGateway(BaseChatModel):
    """

    def __init__(self, **kwargs: Any):
        warnings.warn(
            "`ChatMLflowAIGateway` is deprecated. Use `ChatMlflow` or "
            "`ChatDatabricks` instead.",
            DeprecationWarning,
        )
        try:
            import mlflow.gateway
        except ImportError as e:
@@ -26,6 +26,7 @@ from langchain.embeddings.cache import CacheBackedEmbeddings
from langchain.embeddings.clarifai import ClarifaiEmbeddings
from langchain.embeddings.cohere import CohereEmbeddings
from langchain.embeddings.dashscope import DashScopeEmbeddings
from langchain.embeddings.databricks import DatabricksEmbeddings
from langchain.embeddings.deepinfra import DeepInfraEmbeddings
from langchain.embeddings.edenai import EdenAiEmbeddings
from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
@@ -50,6 +51,7 @@ from langchain.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings
from langchain.embeddings.llamacpp import LlamaCppEmbeddings
from langchain.embeddings.localai import LocalAIEmbeddings
from langchain.embeddings.minimax import MiniMaxEmbeddings
from langchain.embeddings.mlflow import MlflowEmbeddings
from langchain.embeddings.mlflow_gateway import MlflowAIGatewayEmbeddings
from langchain.embeddings.modelscope_hub import ModelScopeEmbeddings
from langchain.embeddings.mosaicml import MosaicMLInstructorEmbeddings
@@ -78,6 +80,7 @@ __all__ = [
    "CacheBackedEmbeddings",
    "ClarifaiEmbeddings",
    "CohereEmbeddings",
    "DatabricksEmbeddings",
    "ElasticsearchEmbeddings",
    "FastEmbedEmbeddings",
    "HuggingFaceEmbeddings",
@@ -87,6 +90,7 @@ __all__ = [
    "JinaEmbeddings",
    "LlamaCppEmbeddings",
    "HuggingFaceHubEmbeddings",
    "MlflowEmbeddings",
    "MlflowAIGatewayEmbeddings",
    "ModelScopeEmbeddings",
    "TensorflowHubEmbeddings",
45 libs/langchain/langchain/embeddings/databricks.py Normal file
@@ -0,0 +1,45 @@
from __future__ import annotations

from typing import Iterator, List
from urllib.parse import urlparse

from langchain.embeddings.mlflow import MlflowEmbeddings


def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
    for i in range(0, len(texts), size):
        yield texts[i : i + size]


class DatabricksEmbeddings(MlflowEmbeddings):
    """Wrapper around embeddings LLMs in Databricks.

    To use, you should have the ``mlflow`` python package installed.
    For more information, see https://mlflow.org/docs/latest/llms/deployments/databricks.html.

    Example:
        .. code-block:: python

            from langchain.embeddings import DatabricksEmbeddings

            embeddings = DatabricksEmbeddings(
                target_uri="databricks",
                endpoint="embeddings",
            )
    """

    target_uri: str = "databricks"
    """The target URI to use. Defaults to ``databricks``."""

    @property
    def _mlflow_extras(self) -> str:
        return ""

    def _validate_uri(self) -> None:
        if self.target_uri == "databricks":
            return

        if urlparse(self.target_uri).scheme != "databricks":
            raise ValueError(
                "Invalid target URI. The target URI must be a valid databricks URI."
            )
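A short usage sketch for `DatabricksEmbeddings`, mirroring the foundation-model example in the testing plan; the endpoint name is illustrative and assumes the workspace exposes it.

```python
from langchain.embeddings import DatabricksEmbeddings

# target_uri defaults to "databricks"; endpoint name is an example from the PR.
embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
print(embeddings.embed_query("hello")[:3])
print(embeddings.embed_documents(["hello", "world"])[0][:3])
```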
74 libs/langchain/langchain/embeddings/mlflow.py Normal file
@@ -0,0 +1,74 @@
from __future__ import annotations

from typing import Any, Iterator, List
from urllib.parse import urlparse

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr


def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
    for i in range(0, len(texts), size):
        yield texts[i : i + size]


class MlflowEmbeddings(Embeddings, BaseModel):
    """Wrapper around embeddings LLMs in MLflow.

    To use, you should have the `mlflow[genai]` python package installed.
    For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.

    Example:
        .. code-block:: python

            from langchain.embeddings import MlflowEmbeddings

            embeddings = MlflowEmbeddings(
                target_uri="http://localhost:5000",
                endpoint="embeddings",
            )
    """

    endpoint: str
    """The endpoint to use."""
    target_uri: str
    """The target URI to use."""
    _client: Any = PrivateAttr()

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self._validate_uri()
        try:
            from mlflow.deployments import get_deploy_client

            self._client = get_deploy_client(self.target_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                f"Please run `pip install mlflow{self._mlflow_extras}` to install "
                "required dependencies."
            ) from e

    @property
    def _mlflow_extras(self) -> str:
        return "[genai]"

    def _validate_uri(self) -> None:
        if self.target_uri == "databricks":
            return
        allowed = ["http", "https", "databricks"]
        if urlparse(self.target_uri).scheme not in allowed:
            raise ValueError(
                f"Invalid target URI: {self.target_uri}. "
                f"The scheme must be one of {allowed}."
            )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        embeddings: List[List[float]] = []
        for txt in _chunk(texts, 20):
            resp = self._client.predict(endpoint=self.endpoint, inputs={"input": txt})
            embeddings.extend(r["embedding"] for r in resp["data"])
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        return self.embed_documents([text])[0]
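`embed_documents` above sends texts to the endpoint in batches of 20 via the `_chunk` helper. A quick standalone illustration of that batching (no server needed):

```python
from typing import Iterator, List


def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
    # Same helper as above: yield fixed-size slices of the input list.
    for i in range(0, len(texts), size):
        yield texts[i : i + size]


texts = [f"document {i}" for i in range(45)]
print([len(batch) for batch in _chunk(texts, 20)])  # [20, 20, 5]
```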
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import Any, Iterator, List, Optional

from langchain_core.embeddings import Embeddings
@@ -35,6 +36,11 @@ class MlflowAIGatewayEmbeddings(Embeddings, BaseModel):
    """The URI for the MLflow AI Gateway API."""

    def __init__(self, **kwargs: Any):
        warnings.warn(
            "`MlflowAIGatewayEmbeddings` is deprecated. Use `MlflowEmbeddings` or "
            "`DatabricksEmbeddings` instead.",
            DeprecationWarning,
        )
        try:
            import mlflow.gateway
        except ImportError as e:
@@ -148,6 +148,12 @@ def _import_databricks() -> Any:
    return Databricks


def _import_databricks_chat() -> Any:
    from langchain.chat_models.databricks import ChatDatabricks

    return ChatDatabricks


def _import_deepinfra() -> Any:
    from langchain.llms.deepinfra import DeepInfra

@@ -276,6 +282,18 @@ def _import_minimax() -> Any:
    return Minimax


def _import_mlflow() -> Any:
    from langchain.llms.mlflow import Mlflow

    return Mlflow


def _import_mlflow_chat() -> Any:
    from langchain.chat_models.mlflow import ChatMlflow

    return ChatMlflow


def _import_mlflow_ai_gateway() -> Any:
    from langchain.llms.mlflow_ai_gateway import MlflowAIGateway

@@ -595,6 +613,8 @@ def __getattr__(name: str) -> Any:
        return _import_manifest()
    elif name == "Minimax":
        return _import_minimax()
    elif name == "Mlflow":
        return _import_mlflow()
    elif name == "MlflowAIGateway":
        return _import_mlflow_ai_gateway()
    elif name == "Modal":
@@ -789,6 +809,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
        "ctransformers": _import_ctransformers,
        "ctranslate2": _import_ctranslate2,
        "databricks": _import_databricks,
        "databricks-chat": _import_databricks_chat,
        "deepinfra": _import_deepinfra,
        "deepsparse": _import_deepsparse,
        "edenai": _import_edenai,
@@ -808,6 +829,8 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
        "llamacpp": _import_llamacpp,
        "textgen": _import_textgen,
        "minimax": _import_minimax,
        "mlflow": _import_mlflow,
        "mlflow-chat": _import_mlflow_chat,
        "mlflow-ai-gateway": _import_mlflow_ai_gateway,
        "modal": _import_modal,
        "mosaic": _import_mosaicml,
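The registry entries above are what chain (de)serialization uses to resolve the new types by key. A hedged sketch of looking them up directly:

```python
# Sketch: resolving the loader keys registered in this diff.
from langchain.llms import get_type_to_cls_dict

type_to_cls = get_type_to_cls_dict()
print(type_to_cls["mlflow"]())           # <class 'langchain.llms.mlflow.Mlflow'>
print(type_to_cls["databricks-chat"]())  # <class 'langchain.chat_models.databricks.ChatDatabricks'>
```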
@@ -1,8 +1,11 @@
import os
import warnings
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Mapping, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM
from langchain_core.pydantic_v1 import (
    BaseModel,
    Extra,
@@ -12,9 +15,6 @@ from langchain_core.pydantic_v1 import (
    validator,
)

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

__all__ = ["Databricks"]


@@ -24,24 +24,66 @@ class _DatabricksClientBase(BaseModel, ABC):
    api_url: str
    api_token: str

    def post_raw(self, request: Any) -> Any:
    def request(self, method: str, url: str, request: Any) -> Any:
        headers = {"Authorization": f"Bearer {self.api_token}"}
        response = requests.post(self.api_url, headers=headers, json=request)
        response = requests.request(
            method=method, url=url, headers=headers, json=request
        )
        # TODO: error handling and automatic retries
        if not response.ok:
            raise ValueError(f"HTTP {response.status_code} error: {response.text}")
        return response.json()

    def _get(self, url: str) -> Any:
        return self.request("GET", url, None)

    def _post(self, url: str, request: Any) -> Any:
        return self.request("POST", url, request)

    @abstractmethod
    def post(self, request: Any) -> Any:
    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        ...


def _transform_completions(response: Dict[str, Any]) -> str:
    return response["choices"][0]["text"]


def _transform_chat(response: Dict[str, Any]) -> str:
    return response["choices"][0]["message"]["content"]


class _DatabricksServingEndpointClient(_DatabricksClientBase):
    """An API client that talks to a Databricks serving endpoint."""

    host: str
    endpoint_name: str
    databricks_uri: str
    client: Any = None
    external_or_foundation: bool = False
    task: Optional[str] = None

    def __init__(self, **data: Any):
        super().__init__(**data)

        try:
            from mlflow.deployments import get_deploy_client

            self.client = get_deploy_client(self.databricks_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                "Please install mlflow with `pip install mlflow`."
            ) from e

        endpoint = self.client.get_endpoint(self.endpoint_name)
        self.external_or_foundation = endpoint.get("endpoint_type", "").lower() in (
            "external_model",
            "foundation_model_api",
        )
        self.task = endpoint.get("task")

    @root_validator(pre=True)
    def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@@ -52,14 +94,30 @@ class _DatabricksServingEndpointClient(_DatabricksClientBase):
        values["api_url"] = api_url
        return values

    def post(self, request: Any) -> Any:
        # See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
        wrapped_request = {"dataframe_records": [request]}
        response = self.post_raw(wrapped_request)["predictions"]
        # For a single-record query, the result is not a list.
        if isinstance(response, list):
            response = response[0]
        return response
    def post(
        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
    ) -> Any:
        if self.external_or_foundation:
            resp = self.client.predict(endpoint=self.endpoint_name, inputs=request)
            if transform_output_fn:
                return transform_output_fn(resp)

            if self.task == "llm/v1/chat":
                return _transform_chat(resp)
            elif self.task == "llm/v1/completions":
                return _transform_completions(resp)

            return resp
        else:
            # See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
            wrapped_request = {"dataframe_records": [request]}
            response = self.client.predict(
                endpoint=self.endpoint_name, inputs=wrapped_request
            )
            preds = response["predictions"]
            # For a single-record query, the result is not a list.
            pred = preds[0] if isinstance(preds, list) else preds
            return transform_output_fn(pred) if transform_output_fn else pred


class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
@@ -79,8 +137,8 @@ class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
        values["api_url"] = api_url
        return values

    def post(self, request: Any) -> Any:
        return self.post_raw(request)
    def post(self, request: Any, transform: Optional[Callable[..., str]] = None) -> Any:
        return self._post(self.api_url, request)


def get_repl_context() -> Any:
@@ -137,16 +195,19 @@ def get_default_api_token() -> str:


class Databricks(LLM):

    """Databricks serving endpoint or a cluster driver proxy app for LLM.

    It supports two endpoint types:

    * **Serving endpoint** (recommended for both production and development).
      We assume that an LLM was registered and deployed to a serving endpoint.
      We assume that an LLM was deployed to a serving endpoint.
      To wrap it as an LLM you must have "Can Query" permission to the endpoint.
      Set ``endpoint_name`` accordingly and do not set ``cluster_id`` and
      ``cluster_driver_port``.
      The expected model signature is:

      If the underlying model is a model registered by MLflow, the expected model
      signature is:

      * inputs::

@@ -155,6 +216,10 @@ class Databricks(LLM):

      * outputs: ``[{"type": "string"}]``

      If the underlying model is an external or foundation model, the response from the
      endpoint is automatically transformed to the expected format unless
      ``transform_output_fn`` is provided.

    * **Cluster driver proxy app** (recommended for interactive development).
      One can load an LLM on a Databricks interactive cluster and start a local HTTP
      server on the driver node to serve the model at ``/`` using HTTP POST method
@@ -220,9 +285,14 @@ class Databricks(LLM):
      We recommend the server using a port number between ``[3000, 8000]``.
    """

    model_kwargs: Optional[Dict[str, Any]] = None
    params: Optional[Dict[str, Any]] = None
    """Extra parameters to pass to the endpoint."""

    model_kwargs: Optional[Dict[str, Any]] = None
    """
    Deprecated. Please use ``params`` instead. Extra parameters to pass to the endpoint.
    """

    transform_input_fn: Optional[Callable] = None
    """A function that transforms ``{prompt, stop, **kwargs}`` into a JSON-compatible
    request object that the endpoint accepts.
@@ -233,6 +303,9 @@ class Databricks(LLM):
    """A function that transforms the output from the endpoint to the generated text.
    """

    databricks_uri: str = "databricks"
    """The databricks URI. Only used when using a serving endpoint."""

    _client: _DatabricksClientBase = PrivateAttr()

    class Config:
@@ -283,11 +356,19 @@ class Databricks(LLM):

    def __init__(self, **data: Any):
        super().__init__(**data)
        if self.model_kwargs is not None and self.params is not None:
            raise ValueError("Cannot set both model_kwargs and params.")
        elif self.model_kwargs is not None:
            warnings.warn(
                "model_kwargs is deprecated. Please use params instead.",
                DeprecationWarning,
            )
        if self.endpoint_name:
            self._client = _DatabricksServingEndpointClient(
                host=self.host,
                api_token=self.api_token,
                endpoint_name=self.endpoint_name,
                databricks_uri=self.databricks_uri,
            )
        elif self.cluster_id and self.cluster_driver_port:
            self._client = _DatabricksClusterDriverProxyClient(
@@ -301,6 +382,31 @@ class Databricks(LLM):
                "Must specify either endpoint_name or cluster_id/cluster_driver_port."
            )

    @property
    def _params(self) -> Optional[Dict[str, Any]]:
        return self.model_kwargs or self.params

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Return default params."""
        return {
            "host": self.host,
            # "api_token": self.api_token,  # Never save the token
            "endpoint_name": self.endpoint_name,
            "cluster_id": self.cluster_id,
            "cluster_driver_port": self.cluster_driver_port,
            "databricks_uri": self.databricks_uri,
            "model_kwargs": self.model_kwargs,
            "params": self.params,
            # TODO: Support saving transform_input_fn and transform_output_fn
            # "transform_input_fn": self.transform_input_fn,
            # "transform_output_fn": self.transform_output_fn,
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return self._default_params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
@@ -319,8 +425,8 @@ class Databricks(LLM):

        request = {"prompt": prompt, "stop": stop}
        request.update(kwargs)
        if self.model_kwargs:
            request.update(self.model_kwargs)
        if self._params:
            request.update(self._params)

        if self.transform_input_fn:
            request = self.transform_input_fn(**request)
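With the changes above, `Databricks` can target external and foundation-model endpoints directly, transforming their chat/completions responses automatically, and `params` supersedes the deprecated `model_kwargs`. A usage sketch based on the testing plan; the endpoint name is illustrative.

```python
from langchain.llms import Databricks

# params replaces model_kwargs (which now emits a DeprecationWarning).
llm = Databricks(
    endpoint_name="databricks-mpt-7b-instruct",  # illustrative endpoint
    params={"temperature": 0.1},
)
print(llm("hello"))
```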
104 libs/langchain/langchain/llms/mlflow.py Normal file
@@ -0,0 +1,104 @@
from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional
from urllib.parse import urlparse

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM
from langchain_core.pydantic_v1 import BaseModel, Extra, PrivateAttr


# Ignoring type because below is valid pydantic code
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class Params(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
    """Parameters for MLflow"""

    temperature: float = 0.0
    n: int = 1
    stop: Optional[List[str]] = None
    max_tokens: Optional[int] = None


class Mlflow(LLM):
    """Wrapper around completions LLMs in MLflow.

    To use, you should have the `mlflow[genai]` python package installed.
    For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.

    Example:
        .. code-block:: python

            from langchain.llms import Mlflow

            completions = Mlflow(
                target_uri="http://localhost:5000",
                endpoint="test",
                params={"temperature": 0.1}
            )
    """

    endpoint: str
    """The endpoint to use."""
    target_uri: str
    """The target URI to use."""
    params: Optional[Params] = None
    """Extra parameters such as `temperature`."""
    _client: Any = PrivateAttr()

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self._validate_uri()
        try:
            from mlflow.deployments import get_deploy_client

            self._client = get_deploy_client(self.target_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                "Please run `pip install mlflow[genai]` to install "
                "required dependencies."
            ) from e

    def _validate_uri(self) -> None:
        if self.target_uri == "databricks":
            return
        allowed = ["http", "https", "databricks"]
        if urlparse(self.target_uri).scheme not in allowed:
            raise ValueError(
                f"Invalid target URI: {self.target_uri}. "
                f"The scheme must be one of {allowed}."
            )

    @property
    def _default_params(self) -> Dict[str, Any]:
        params: Dict[str, Any] = {
            "target_uri": self.target_uri,
            "endpoint": self.endpoint,
        }
        if self.params:
            params["params"] = self.params.dict()
        return params

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return self._default_params

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        data: Dict[str, Any] = {
            "prompt": prompt,
            **(self.params.dict() if self.params else {}),
        }
        if s := (stop or (self.params.stop if self.params else None)):
            data["stop"] = s
        resp = self._client.predict(endpoint=self.endpoint, inputs=data)
        return resp["choices"][0]["text"]

    @property
    def _llm_type(self) -> str:
        return "mlflow"
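A usage sketch for the new `Mlflow` completions wrapper above, pointing at a local deployments server as in the testing plan; the URL and endpoint name are illustrative.

```python
from langchain.llms import Mlflow

llm = Mlflow(
    target_uri="http://127.0.0.1:5000",
    endpoint="completions",  # illustrative endpoint name
    params={"temperature": 0.1},
)
print(llm("The capital of France is"))
```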
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import Any, Dict, List, Mapping, Optional

from langchain_core.pydantic_v1 import BaseModel, Extra
@@ -46,6 +47,10 @@ class MlflowAIGateway(LLM):
    params: Optional[Params] = None

    def __init__(self, **kwargs: Any):
        warnings.warn(
            "`MlflowAIGateway` is deprecated. Use `Mlflow` or `Databricks` instead.",
            DeprecationWarning,
        )
        try:
            import mlflow.gateway
        except ImportError as e:
@@ -9,7 +9,9 @@ EXPECTED_ALL = [
    "ChatEverlyAI",
    "ChatAnthropic",
    "ChatCohere",
    "ChatDatabricks",
    "ChatGooglePalm",
    "ChatMlflow",
    "ChatMLflowAIGateway",
    "ChatOllama",
    "ChatVertexAI",
@@ -6,6 +6,7 @@ EXPECTED_ALL = [
    "CacheBackedEmbeddings",
    "ClarifaiEmbeddings",
    "CohereEmbeddings",
    "DatabricksEmbeddings",
    "ElasticsearchEmbeddings",
    "FastEmbedEmbeddings",
    "HuggingFaceEmbeddings",
@@ -16,6 +17,7 @@ EXPECTED_ALL = [
    "LlamaCppEmbeddings",
    "HuggingFaceHubEmbeddings",
    "MlflowAIGatewayEmbeddings",
    "MlflowEmbeddings",
    "ModelScopeEmbeddings",
    "TensorflowHubEmbeddings",
    "SagemakerEndpointEmbeddings",