mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-18 21:09:00 +00:00
langchain[minor]: Migrate mlflow and databricks classes to deployments APIs. (#13699)
## Description

Related to https://github.com/mlflow/mlflow/pull/10420. MLflow AI gateway will be deprecated and replaced by the `mlflow.deployments` module. Happy to split this PR if it's too large.

```
pip install git+https://github.com/langchain-ai/langchain.git@refs/pull/13699/merge#subdirectory=libs/langchain
```

## Dependencies

Install mlflow from https://github.com/mlflow/mlflow/pull/10420:

```
pip install git+https://github.com/mlflow/mlflow.git@refs/pull/10420/merge
```

## Testing plan

The following code works fine on local and databricks:

<details><summary>Click</summary>
<p>

```python
"""
Setup
-----
mlflow deployments start-server --config-path examples/gateway/openai/config.yaml
databricks secrets create-scope <scope>
databricks secrets put-secret <scope> openai-api-key --string-value $OPENAI_API_KEY

Run
---
python /path/to/this/file.py secrets/<scope>/openai-api-key
"""
from langchain.chat_models import ChatMlflow, ChatDatabricks
from langchain.embeddings import MlflowEmbeddings, DatabricksEmbeddings
from langchain.llms import Databricks, Mlflow
from langchain.schema.messages import HumanMessage
from langchain.chains.loading import load_chain
from mlflow.deployments import get_deploy_client
import uuid
import sys
import tempfile
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

###############################
# MLflow
###############################
chat = ChatMlflow(
    target_uri="http://127.0.0.1:5000", endpoint="chat", params={"temperature": 0.1}
)
print(chat([HumanMessage(content="hello")]))

embeddings = MlflowEmbeddings(target_uri="http://127.0.0.1:5000", endpoint="embeddings")
print(embeddings.embed_query("hello")[:3])
print(embeddings.embed_documents(["hello", "world"])[0][:3])

llm = Mlflow(
    target_uri="http://127.0.0.1:5000",
    endpoint="completions",
    params={"temperature": 0.1},
)
print(llm("I am"))

llm_chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate(
        input_variables=["adjective"],
        template="Tell me a {adjective} joke",
    ),
)
print(llm_chain.run(adjective="funny"))

# serialization/deserialization
with tempfile.TemporaryDirectory() as tmpdir:
    print(tmpdir)
    path = f"{tmpdir}/llm.yaml"
    llm_chain.save(path)
    loaded_chain = load_chain(path)
    print(loaded_chain("funny"))

###############################
# Databricks
###############################
secret = sys.argv[1]
client = get_deploy_client("databricks")

# External - chat
name = f"chat-{uuid.uuid4()}"
client.create_endpoint(
    name=name,
    config={
        "served_entities": [
            {
                "name": "test",
                "external_model": {
                    "name": "gpt-4",
                    "provider": "openai",
                    "task": "llm/v1/chat",
                    "openai_config": {
                        "openai_api_key": "{{" + secret + "}}",
                    },
                },
            }
        ],
    },
)
try:
    chat = ChatDatabricks(
        target_uri="databricks", endpoint=name, params={"temperature": 0.1}
    )
    print(chat([HumanMessage(content="hello")]))
finally:
    client.delete_endpoint(endpoint=name)

# External - embeddings
name = f"embeddings-{uuid.uuid4()}"
client.create_endpoint(
    name=name,
    config={
        "served_entities": [
            {
                "name": "test",
                "external_model": {
                    "name": "text-embedding-ada-002",
                    "provider": "openai",
                    "task": "llm/v1/embeddings",
                    "openai_config": {
                        "openai_api_key": "{{" + secret + "}}",
                    },
                },
            }
        ],
    },
)
try:
    embeddings = DatabricksEmbeddings(target_uri="databricks", endpoint=name)
    print(embeddings.embed_query("hello")[:3])
    print(embeddings.embed_documents(["hello", "world"])[0][:3])
finally:
    client.delete_endpoint(endpoint=name)

# External - completions
name = f"completions-{uuid.uuid4()}"
client.create_endpoint(
    name=name,
    config={
        "served_entities": [
            {
                "name": "test",
                "external_model": {
                    "name": "gpt-3.5-turbo-instruct",
                    "provider": "openai",
                    "task": "llm/v1/completions",
                    "openai_config": {
                        "openai_api_key": "{{" + secret + "}}",
                    },
                },
            }
        ],
    },
)
try:
    llm = Databricks(
        endpoint_name=name,
        model_kwargs={"temperature": 0.1},
    )
    print(llm("I am"))
finally:
    client.delete_endpoint(endpoint=name)

# Foundation model - chat
chat = ChatDatabricks(
    endpoint="databricks-llama-2-70b-chat", params={"temperature": 0.1}
)
print(chat([HumanMessage(content="hello")]))

# Foundation model - embeddings
embeddings = DatabricksEmbeddings(endpoint="databricks-bge-large-en")
print(embeddings.embed_query("hello")[:3])

# Foundation model - completions
llm = Databricks(
    endpoint_name="databricks-mpt-7b-instruct", model_kwargs={"temperature": 0.1}
)
print(llm("hello"))
llm_chain = LLMChain(
    llm=llm,
    prompt=PromptTemplate(
        input_variables=["adjective"],
        template="Tell me a {adjective} joke",
    ),
)
print(llm_chain.run(adjective="funny"))

# serialization/deserialization
with tempfile.TemporaryDirectory() as tmpdir:
    print(tmpdir)
    path = f"{tmpdir}/llm.yaml"
    llm_chain.save(path)
    loaded_chain = load_chain(path)
    print(loaded_chain("funny"))
```

Output:

```
content='Hello! How can I assist you today?'
[-0.025058426, -0.01938856, -0.027781019]
[-0.025058426, -0.01938856, -0.027781019]
 sorry, but I cannot continue the sentence as it is incomplete. Can you please provide more information or context?
Sure, here's a classic one for you:

Why don't scientists trust atoms?

Because they make up everything!
/var/folders/dz/cd_nvlf14g9g__n3ph0d_0pm0000gp/T/tmpx_4no6ad
{'adjective': 'funny', 'text': "Sure, here's a classic one for you:\n\nWhy don't scientists trust atoms?\n\nBecause they make up everything!"}
content='Hello! How can I assist you today?'
[-0.025058426, -0.01938856, -0.027781019]
[-0.025058426, -0.01938856, -0.027781019]
 a 23 year old female and I am currently studying for my master's degree
content="\nHello! It's nice to meet you. Is there something I can help you with or would you like to chat for a bit?"
[0.051055908203125, 0.007221221923828125, 0.003879547119140625]
[0.051055908203125, 0.007221221923828125, 0.003879547119140625]
 hello back
 Well, I don't really know many jokes, but I do know this funny story...
/var/folders/dz/cd_nvlf14g9g__n3ph0d_0pm0000gp/T/tmp7_ds72ex
{'adjective': 'funny', 'text': " Well, I don't really know many jokes, but I do know this funny story..."}
```

</p>
</details>

The existing workflow doesn't break:

<details><summary>click</summary>
<p>

```python
import uuid

import mlflow
from mlflow.models import ModelSignature
from mlflow.types.schema import ColSpec, Schema


class MyModel(mlflow.pyfunc.PythonModel):
    def predict(self, context, model_input):
        return str(uuid.uuid4())


with mlflow.start_run():
    mlflow.pyfunc.log_model(
        "model",
        python_model=MyModel(),
        pip_requirements=["mlflow==2.8.1", "cloudpickle<3"],
        signature=ModelSignature(
            inputs=Schema(
                [
                    ColSpec("string", "prompt"),
                    ColSpec("string", "stop"),
                ]
            ),
            outputs=Schema(
                [
                    ColSpec(name=None, type="string"),
                ]
            ),
        ),
        registered_model_name=f"lang-{uuid.uuid4()}",
    )

# Manually create a serving endpoint with the registered model and run
from langchain.llms import Databricks

llm = Databricks(endpoint_name="<name>")
llm("hello")  # 9d0b2491-3d13-487c-bc02-1287f06ecae7
```

</p>
</details>

## Follow-up tasks

(This PR is too large. I'll file a separate one for follow-up tasks.)

- Update `docs/docs/integrations/providers/mlflow_ai_gateway.mdx` and `docs/docs/integrations/providers/databricks.md`.
---------

Signed-off-by: harupy <17039389+harupy@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent dc31714ec5
commit 0d08a692a3
Changes to `libs/langchain/langchain/chat_models/__init__.py`:

```diff
@@ -24,6 +24,7 @@ from langchain.chat_models.baichuan import ChatBaichuan
 from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
 from langchain.chat_models.bedrock import BedrockChat
 from langchain.chat_models.cohere import ChatCohere
+from langchain.chat_models.databricks import ChatDatabricks
 from langchain.chat_models.ernie import ErnieBotChat
 from langchain.chat_models.everlyai import ChatEverlyAI
 from langchain.chat_models.fake import FakeListChatModel
@@ -37,6 +38,7 @@ from langchain.chat_models.jinachat import JinaChat
 from langchain.chat_models.konko import ChatKonko
 from langchain.chat_models.litellm import ChatLiteLLM
 from langchain.chat_models.minimax import MiniMaxChat
+from langchain.chat_models.mlflow import ChatMlflow
 from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
 from langchain.chat_models.ollama import ChatOllama
 from langchain.chat_models.openai import ChatOpenAI
@@ -52,10 +54,12 @@ __all__ = [
     "AzureChatOpenAI",
     "FakeListChatModel",
     "PromptLayerChatOpenAI",
+    "ChatDatabricks",
     "ChatEverlyAI",
     "ChatAnthropic",
     "ChatCohere",
     "ChatGooglePalm",
+    "ChatMlflow",
     "ChatMLflowAIGateway",
     "ChatOllama",
     "ChatVertexAI",
```
`libs/langchain/langchain/chat_models/databricks.py` (new file, 46 lines):

```python
import logging
from urllib.parse import urlparse

from langchain.chat_models.mlflow import ChatMlflow

logger = logging.getLogger(__name__)


class ChatDatabricks(ChatMlflow):
    """`Databricks` chat models API.

    To use, you should have the ``mlflow`` python package installed.
    For more information, see https://mlflow.org/docs/latest/llms/deployments/databricks.html.

    Example:
        .. code-block:: python

            from langchain.chat_models import ChatDatabricks

            chat = ChatDatabricks(
                target_uri="databricks",
                endpoint="chat",
                temperature=0.1,
            )
    """

    target_uri: str = "databricks"
    """The target URI to use. Defaults to ``databricks``."""

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "databricks-chat"

    @property
    def _mlflow_extras(self) -> str:
        return ""

    def _validate_uri(self) -> None:
        if self.target_uri == "databricks":
            return

        if urlparse(self.target_uri).scheme != "databricks":
            raise ValueError(
                "Invalid target URI. The target URI must be a valid databricks URI."
            )
```
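A quick sketch of what `_validate_uri` accepts, mirroring the check above; it is not part of the diff, and the profile-style URI is an assumption about how a `databricks://...` target would be written:

```python
# Illustrative only: mirrors ChatDatabricks._validate_uri.
from urllib.parse import urlparse


def accepts(target_uri: str) -> bool:
    # "databricks" itself or any URI whose scheme is "databricks" passes;
    # everything else raises ValueError in the real class.
    return target_uri == "databricks" or urlparse(target_uri).scheme == "databricks"


assert accepts("databricks")
assert accepts("databricks://my-profile")  # assumed profile-style URI
assert not accepts("http://127.0.0.1:5000")
```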
`libs/langchain/langchain/chat_models/mlflow.py` (new file, 217 lines):

```python
import asyncio
import logging
from functools import partial
from typing import Any, Dict, List, Mapping, Optional
from urllib.parse import urlparse

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    ChatMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import (
    Field,
    PrivateAttr,
)

logger = logging.getLogger(__name__)


class ChatMlflow(BaseChatModel):
    """`MLflow` chat models API.

    To use, you should have the `mlflow[genai]` python package installed.
    For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.

    Example:
        .. code-block:: python

            from langchain.chat_models import ChatMlflow

            chat = ChatMlflow(
                target_uri="http://localhost:5000",
                endpoint="chat",
                temperature=0.1,
            )
    """

    endpoint: str
    """The endpoint to use."""
    target_uri: str
    """The target URI to use."""
    temperature: float = 0.0
    """The sampling temperature."""
    n: int = 1
    """The number of completion choices to generate."""
    stop: Optional[List[str]] = None
    """The stop sequence."""
    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""
    extra_params: dict = Field(default_factory=dict)
    """Any extra parameters to pass to the endpoint."""
    _client: Any = PrivateAttr()

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self._validate_uri()
        try:
            from mlflow.deployments import get_deploy_client

            self._client = get_deploy_client(self.target_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                f"Please run `pip install mlflow{self._mlflow_extras}` to install "
                "required dependencies."
            ) from e

    @property
    def _mlflow_extras(self) -> str:
        return "[genai]"

    def _validate_uri(self) -> None:
        if self.target_uri == "databricks":
            return
        allowed = ["http", "https", "databricks"]
        if urlparse(self.target_uri).scheme not in allowed:
            raise ValueError(
                f"Invalid target URI: {self.target_uri}. "
                f"The scheme must be one of {allowed}."
            )

    @property
    def _default_params(self) -> Dict[str, Any]:
        params: Dict[str, Any] = {
            "target_uri": self.target_uri,
            "endpoint": self.endpoint,
            "temperature": self.temperature,
            "n": self.n,
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "extra_params": self.extra_params,
        }
        return params

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        message_dicts = [
            ChatMlflow._convert_message_to_dict(message) for message in messages
        ]
        data: Dict[str, Any] = {
            "messages": message_dicts,
            "temperature": self.temperature,
            "n": self.n,
            "stop": stop or self.stop,
            "max_tokens": self.max_tokens,
            **self.extra_params,
        }

        resp = self._client.predict(endpoint=self.endpoint, inputs=data)
        return ChatMlflow._create_chat_result(resp)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        func = partial(
            self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
        )
        return await asyncio.get_event_loop().run_in_executor(None, func)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return self._default_params

    def _get_invocation_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Get the parameters used to invoke the model FOR THE CALLBACKS."""
        return {
            **self._default_params,
            **super()._get_invocation_params(stop=stop, **kwargs),
        }

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "mlflow-chat"

    @staticmethod
    def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
        role = _dict["role"]
        content = _dict["content"]
        if role == "user":
            return HumanMessage(content=content)
        elif role == "assistant":
            return AIMessage(content=content)
        elif role == "system":
            return SystemMessage(content=content)
        else:
            return ChatMessage(content=content, role=role)

    @staticmethod
    def _raise_functions_not_supported() -> None:
        raise ValueError(
            "Function messages are not supported by Databricks. Please"
            " create a feature request at https://github.com/mlflow/mlflow/issues."
        )

    @staticmethod
    def _convert_message_to_dict(message: BaseMessage) -> dict:
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, FunctionMessage):
            raise ValueError(
                "Function messages are not supported by Databricks. Please"
                " create a feature request at https://github.com/mlflow/mlflow/issues."
            )
        else:
            raise ValueError(f"Got unknown message type: {message}")

        if "function_call" in message.additional_kwargs:
            ChatMlflow._raise_functions_not_supported()
        if message.additional_kwargs:
            logger.warning(
                "Additional message arguments are unsupported by Databricks"
                " and will be ignored: %s",
                message.additional_kwargs,
            )
        return message_dict

    @staticmethod
    def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
        generations = []
        for choice in response["choices"]:
            message = ChatMlflow._convert_dict_to_message(choice["message"])
            usage = choice.get("usage", {})
            gen = ChatGeneration(
                message=message,
                generation_info=usage,
            )
            generations.append(gen)

        usage = response.get("usage", {})
        return ChatResult(generations=generations, llm_output=usage)
```
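For reference, a sketch of the request that `_generate` builds and of the response shape that `_create_chat_result` consumes; the keys follow the code above, but the concrete values are invented:

```python
# Illustrative payload/response shapes (values are made up for the example).
request = {
    "messages": [{"role": "user", "content": "hello"}],
    "temperature": 0.0,
    "n": 1,
    "stop": None,
    "max_tokens": None,
    # ...plus anything passed via extra_params
}

response = {
    "choices": [
        {
            "message": {"role": "assistant", "content": "Hello! How can I help?"},
            "usage": {},  # optional per-choice usage, stored as generation_info
        }
    ],
    "usage": {"prompt_tokens": 1, "completion_tokens": 7},  # becomes llm_output
}
```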
Changes to `libs/langchain/langchain/chat_models/mlflow_ai_gateway.py`:

```diff
@@ -1,5 +1,6 @@
 import asyncio
 import logging
+import warnings
 from functools import partial
 from typing import Any, Dict, List, Mapping, Optional
@@ -59,6 +60,11 @@ class ChatMLflowAIGateway(BaseChatModel):
     """

     def __init__(self, **kwargs: Any):
+        warnings.warn(
+            "`ChatMLflowAIGateway` is deprecated. Use `ChatMlflow` or "
+            "`ChatDatabricks` instead.",
+            DeprecationWarning,
+        )
         try:
             import mlflow.gateway
         except ImportError as e:
```
Changes to `libs/langchain/langchain/embeddings/__init__.py`:

```diff
@@ -26,6 +26,7 @@ from langchain.embeddings.cache import CacheBackedEmbeddings
 from langchain.embeddings.clarifai import ClarifaiEmbeddings
 from langchain.embeddings.cohere import CohereEmbeddings
 from langchain.embeddings.dashscope import DashScopeEmbeddings
+from langchain.embeddings.databricks import DatabricksEmbeddings
 from langchain.embeddings.deepinfra import DeepInfraEmbeddings
 from langchain.embeddings.edenai import EdenAiEmbeddings
 from langchain.embeddings.elasticsearch import ElasticsearchEmbeddings
@@ -50,6 +51,7 @@ from langchain.embeddings.johnsnowlabs import JohnSnowLabsEmbeddings
 from langchain.embeddings.llamacpp import LlamaCppEmbeddings
 from langchain.embeddings.localai import LocalAIEmbeddings
 from langchain.embeddings.minimax import MiniMaxEmbeddings
+from langchain.embeddings.mlflow import MlflowEmbeddings
 from langchain.embeddings.mlflow_gateway import MlflowAIGatewayEmbeddings
 from langchain.embeddings.modelscope_hub import ModelScopeEmbeddings
 from langchain.embeddings.mosaicml import MosaicMLInstructorEmbeddings
@@ -78,6 +80,7 @@ __all__ = [
     "CacheBackedEmbeddings",
     "ClarifaiEmbeddings",
     "CohereEmbeddings",
+    "DatabricksEmbeddings",
     "ElasticsearchEmbeddings",
     "FastEmbedEmbeddings",
     "HuggingFaceEmbeddings",
@@ -87,6 +90,7 @@ __all__ = [
     "JinaEmbeddings",
     "LlamaCppEmbeddings",
     "HuggingFaceHubEmbeddings",
+    "MlflowEmbeddings",
    "MlflowAIGatewayEmbeddings",
     "ModelScopeEmbeddings",
     "TensorflowHubEmbeddings",
```
`libs/langchain/langchain/embeddings/databricks.py` (new file, 45 lines):

```python
from __future__ import annotations

from typing import Iterator, List
from urllib.parse import urlparse

from langchain.embeddings.mlflow import MlflowEmbeddings


def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
    for i in range(0, len(texts), size):
        yield texts[i : i + size]


class DatabricksEmbeddings(MlflowEmbeddings):
    """Wrapper around embeddings LLMs in Databricks.

    To use, you should have the ``mlflow`` python package installed.
    For more information, see https://mlflow.org/docs/latest/llms/deployments/databricks.html.

    Example:
        .. code-block:: python

            from langchain.embeddings import DatabricksEmbeddings

            embeddings = DatabricksEmbeddings(
                target_uri="databricks",
                endpoint="embeddings",
            )
    """

    target_uri: str = "databricks"
    """The target URI to use. Defaults to ``databricks``."""

    @property
    def _mlflow_extras(self) -> str:
        return ""

    def _validate_uri(self) -> None:
        if self.target_uri == "databricks":
            return

        if urlparse(self.target_uri).scheme != "databricks":
            raise ValueError(
                "Invalid target URI. The target URI must be a valid databricks URI."
            )
```
`libs/langchain/langchain/embeddings/mlflow.py` (new file, 74 lines):

```python
from __future__ import annotations

from typing import Any, Iterator, List
from urllib.parse import urlparse

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, PrivateAttr


def _chunk(texts: List[str], size: int) -> Iterator[List[str]]:
    for i in range(0, len(texts), size):
        yield texts[i : i + size]


class MlflowEmbeddings(Embeddings, BaseModel):
    """Wrapper around embeddings LLMs in MLflow.

    To use, you should have the `mlflow[genai]` python package installed.
    For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.

    Example:
        .. code-block:: python

            from langchain.embeddings import MlflowEmbeddings

            embeddings = MlflowEmbeddings(
                target_uri="http://localhost:5000",
                endpoint="embeddings",
            )
    """

    endpoint: str
    """The endpoint to use."""
    target_uri: str
    """The target URI to use."""
    _client: Any = PrivateAttr()

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self._validate_uri()
        try:
            from mlflow.deployments import get_deploy_client

            self._client = get_deploy_client(self.target_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                f"Please run `pip install mlflow{self._mlflow_extras}` to install "
                "required dependencies."
            ) from e

    @property
    def _mlflow_extras(self) -> str:
        return "[genai]"

    def _validate_uri(self) -> None:
        if self.target_uri == "databricks":
            return
        allowed = ["http", "https", "databricks"]
        if urlparse(self.target_uri).scheme not in allowed:
            raise ValueError(
                f"Invalid target URI: {self.target_uri}. "
                f"The scheme must be one of {allowed}."
            )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        embeddings: List[List[float]] = []
        for txt in _chunk(texts, 20):
            resp = self._client.predict(endpoint=self.endpoint, inputs={"input": txt})
            embeddings.extend(r["embedding"] for r in resp["data"])
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        return self.embed_documents([text])[0]
```
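`embed_documents` batches its inputs through `_chunk` in groups of 20 before hitting the endpoint; a small sketch of that behavior (illustrative only, no client involved):

```python
# Minimal sketch of the batching done by MlflowEmbeddings.embed_documents.
texts = [f"doc-{i}" for i in range(45)]

batches = [texts[i : i + 20] for i in range(0, len(texts), 20)]
assert [len(b) for b in batches] == [20, 20, 5]

# Each batch is sent as {"input": batch} to client.predict(endpoint=...),
# and embeddings are collected from resp["data"][i]["embedding"].
```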
Changes to `libs/langchain/langchain/embeddings/mlflow_gateway.py`:

```diff
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import warnings
 from typing import Any, Iterator, List, Optional

 from langchain_core.embeddings import Embeddings
@@ -35,6 +36,11 @@ class MlflowAIGatewayEmbeddings(Embeddings, BaseModel):
     """The URI for the MLflow AI Gateway API."""

     def __init__(self, **kwargs: Any):
+        warnings.warn(
+            "`MlflowAIGatewayEmbeddings` is deprecated. Use `MlflowEmbeddings` or "
+            "`DatabricksEmbeddings` instead.",
+            DeprecationWarning,
+        )
         try:
             import mlflow.gateway
         except ImportError as e:
```
Changes to `libs/langchain/langchain/llms/__init__.py`:

```diff
@@ -148,6 +148,12 @@ def _import_databricks() -> Any:
     return Databricks


+def _import_databricks_chat() -> Any:
+    from langchain.chat_models.databricks import ChatDatabricks
+
+    return ChatDatabricks
+
+
 def _import_deepinfra() -> Any:
     from langchain.llms.deepinfra import DeepInfra

@@ -276,6 +282,18 @@ def _import_minimax() -> Any:
     return Minimax


+def _import_mlflow() -> Any:
+    from langchain.llms.mlflow import Mlflow
+
+    return Mlflow
+
+
+def _import_mlflow_chat() -> Any:
+    from langchain.chat_models.mlflow import ChatMlflow
+
+    return ChatMlflow
+
+
 def _import_mlflow_ai_gateway() -> Any:
     from langchain.llms.mlflow_ai_gateway import MlflowAIGateway

@@ -595,6 +613,8 @@ def __getattr__(name: str) -> Any:
         return _import_manifest()
     elif name == "Minimax":
         return _import_minimax()
+    elif name == "Mlflow":
+        return _import_mlflow()
     elif name == "MlflowAIGateway":
         return _import_mlflow_ai_gateway()
     elif name == "Modal":
@@ -789,6 +809,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
         "ctransformers": _import_ctransformers,
         "ctranslate2": _import_ctranslate2,
         "databricks": _import_databricks,
+        "databricks-chat": _import_databricks_chat,
         "deepinfra": _import_deepinfra,
         "deepsparse": _import_deepsparse,
         "edenai": _import_edenai,
@@ -808,6 +829,8 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
         "llamacpp": _import_llamacpp,
         "textgen": _import_textgen,
         "minimax": _import_minimax,
+        "mlflow": _import_mlflow,
+        "mlflow-chat": _import_mlflow_chat,
         "mlflow-ai-gateway": _import_mlflow_ai_gateway,
         "modal": _import_modal,
         "mosaic": _import_mosaicml,
```
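The new `"mlflow"` registry key is what lets a chain built on `Mlflow` round-trip through `save`/`load_chain`, as exercised in the testing plan. A rough sketch of the lookup, assuming the saved YAML carries `_type: mlflow` (the URI and endpoint are placeholders):

```python
# Hypothetical sketch of resolving a saved Mlflow LLM via the type registry.
from langchain.llms import get_type_to_cls_dict

llm_cls = get_type_to_cls_dict()["mlflow"]()  # -> langchain.llms.mlflow.Mlflow
llm = llm_cls(target_uri="http://127.0.0.1:5000", endpoint="completions")
```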
Changes to `libs/langchain/langchain/llms/databricks.py`:

```diff
@@ -1,8 +1,11 @@
 import os
+import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Mapping, Optional

 import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import LLM
 from langchain_core.pydantic_v1 import (
     BaseModel,
     Extra,
@@ -12,9 +15,6 @@ from langchain_core.pydantic_v1 import (
     validator,
 )

-from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
-
 __all__ = ["Databricks"]
@@ -24,24 +24,66 @@ class _DatabricksClientBase(BaseModel, ABC):
     api_url: str
     api_token: str

-    def post_raw(self, request: Any) -> Any:
+    def request(self, method: str, url: str, request: Any) -> Any:
         headers = {"Authorization": f"Bearer {self.api_token}"}
-        response = requests.post(self.api_url, headers=headers, json=request)
+        response = requests.request(
+            method=method, url=url, headers=headers, json=request
+        )
         # TODO: error handling and automatic retries
         if not response.ok:
             raise ValueError(f"HTTP {response.status_code} error: {response.text}")
         return response.json()

+    def _get(self, url: str) -> Any:
+        return self.request("GET", url, None)
+
+    def _post(self, url: str, request: Any) -> Any:
+        return self.request("POST", url, request)
+
     @abstractmethod
-    def post(self, request: Any) -> Any:
+    def post(
+        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
+    ) -> Any:
         ...

+
+def _transform_completions(response: Dict[str, Any]) -> str:
+    return response["choices"][0]["text"]
+
+
+def _transform_chat(response: Dict[str, Any]) -> str:
+    return response["choices"][0]["message"]["content"]
+
+
 class _DatabricksServingEndpointClient(_DatabricksClientBase):
     """An API client that talks to a Databricks serving endpoint."""

     host: str
     endpoint_name: str
+    databricks_uri: str
+    client: Any = None
+    external_or_foundation: bool = False
+    task: Optional[str] = None
+
+    def __init__(self, **data: Any):
+        super().__init__(**data)
+
+        try:
+            from mlflow.deployments import get_deploy_client
+
+            self.client = get_deploy_client(self.databricks_uri)
+        except ImportError as e:
+            raise ImportError(
+                "Failed to create the client. "
+                "Please install mlflow with `pip install mlflow`."
+            ) from e
+
+        endpoint = self.client.get_endpoint(self.endpoint_name)
+        self.external_or_foundation = endpoint.get("endpoint_type", "").lower() in (
+            "external_model",
+            "foundation_model_api",
+        )
+        self.task = endpoint.get("task")
+
     @root_validator(pre=True)
     def set_api_url(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@@ -52,14 +94,30 @@ class _DatabricksServingEndpointClient(_DatabricksClientBase):
         values["api_url"] = api_url
         return values

-    def post(self, request: Any) -> Any:
-        # See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
-        wrapped_request = {"dataframe_records": [request]}
-        response = self.post_raw(wrapped_request)["predictions"]
-        # For a single-record query, the result is not a list.
-        if isinstance(response, list):
-            response = response[0]
-        return response
+    def post(
+        self, request: Any, transform_output_fn: Optional[Callable[..., str]] = None
+    ) -> Any:
+        if self.external_or_foundation:
+            resp = self.client.predict(endpoint=self.endpoint_name, inputs=request)
+            if transform_output_fn:
+                return transform_output_fn(resp)
+
+            if self.task == "llm/v1/chat":
+                return _transform_chat(resp)
+            elif self.task == "llm/v1/completions":
+                return _transform_completions(resp)
+
+            return resp
+        else:
+            # See https://docs.databricks.com/machine-learning/model-serving/score-model-serving-endpoints.html
+            wrapped_request = {"dataframe_records": [request]}
+            response = self.client.predict(
+                endpoint=self.endpoint_name, inputs=wrapped_request
+            )
+            preds = response["predictions"]
+            # For a single-record query, the result is not a list.
+            pred = preds[0] if isinstance(preds, list) else preds
+            return transform_output_fn(pred) if transform_output_fn else pred
@@ -79,8 +137,8 @@ class _DatabricksClusterDriverProxyClient(_DatabricksClientBase):
         values["api_url"] = api_url
         return values

-    def post(self, request: Any) -> Any:
-        return self.post_raw(request)
+    def post(self, request: Any, transform: Optional[Callable[..., str]] = None) -> Any:
+        return self._post(self.api_url, request)
@@ -137,16 +195,19 @@ def get_default_api_token() -> str:

 class Databricks(LLM):
     """Databricks serving endpoint or a cluster driver proxy app for LLM.

     It supports two endpoint types:

     * **Serving endpoint** (recommended for both production and development).
-      We assume that an LLM was registered and deployed to a serving endpoint.
+      We assume that an LLM was deployed to a serving endpoint.
       To wrap it as an LLM you must have "Can Query" permission to the endpoint.
       Set ``endpoint_name`` accordingly and do not set ``cluster_id`` and
       ``cluster_driver_port``.
-      The expected model signature is:
+
+      If the underlying model is a model registered by MLflow, the expected model
+      signature is:

       * inputs::
@@ -155,6 +216,10 @@ class Databricks(LLM):

       * outputs: ``[{"type": "string"}]``

+      If the underlying model is an external or foundation model, the response from the
+      endpoint is automatically transformed to the expected format unless
+      ``transform_output_fn`` is provided.
+
     * **Cluster driver proxy app** (recommended for interactive development).
       One can load an LLM on a Databricks interactive cluster and start a local HTTP
       server on the driver node to serve the model at ``/`` using HTTP POST method
@@ -220,9 +285,14 @@ class Databricks(LLM):
     We recommend the server using a port number between ``[3000, 8000]``.
     """

-    model_kwargs: Optional[Dict[str, Any]] = None
+    params: Optional[Dict[str, Any]] = None
     """Extra parameters to pass to the endpoint."""

+    model_kwargs: Optional[Dict[str, Any]] = None
+    """
+    Deprecated. Please use ``params`` instead. Extra parameters to pass to the endpoint.
+    """
+
     transform_input_fn: Optional[Callable] = None
     """A function that transforms ``{prompt, stop, **kwargs}`` into a JSON-compatible
     request object that the endpoint accepts.
@@ -233,6 +303,9 @@ class Databricks(LLM):
     """A function that transforms the output from the endpoint to the generated text.
     """

+    databricks_uri: str = "databricks"
+    """The databricks URI. Only used when using a serving endpoint."""
+
     _client: _DatabricksClientBase = PrivateAttr()

     class Config:
@@ -283,11 +356,19 @@ class Databricks(LLM):

     def __init__(self, **data: Any):
         super().__init__(**data)
+        if self.model_kwargs is not None and self.params is not None:
+            raise ValueError("Cannot set both model_kwargs and params.")
+        elif self.model_kwargs is not None:
+            warnings.warn(
+                "model_kwargs is deprecated. Please use params instead.",
+                DeprecationWarning,
+            )
         if self.endpoint_name:
             self._client = _DatabricksServingEndpointClient(
                 host=self.host,
                 api_token=self.api_token,
                 endpoint_name=self.endpoint_name,
+                databricks_uri=self.databricks_uri,
             )
         elif self.cluster_id and self.cluster_driver_port:
             self._client = _DatabricksClusterDriverProxyClient(
@@ -301,6 +382,31 @@ class Databricks(LLM):
                 "Must specify either endpoint_name or cluster_id/cluster_driver_port."
             )

+    @property
+    def _params(self) -> Optional[Dict[str, Any]]:
+        return self.model_kwargs or self.params
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Return default params."""
+        return {
+            "host": self.host,
+            # "api_token": self.api_token,  # Never save the token
+            "endpoint_name": self.endpoint_name,
+            "cluster_id": self.cluster_id,
+            "cluster_driver_port": self.cluster_driver_port,
+            "databricks_uri": self.databricks_uri,
+            "model_kwargs": self.model_kwargs,
+            "params": self.params,
+            # TODO: Support saving transform_input_fn and transform_output_fn
+            # "transform_input_fn": self.transform_input_fn,
+            # "transform_output_fn": self.transform_output_fn,
+        }
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        return self._default_params
+
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
@@ -319,8 +425,8 @@ class Databricks(LLM):

         request = {"prompt": prompt, "stop": stop}
         request.update(kwargs)
-        if self.model_kwargs:
-            request.update(self.model_kwargs)
+        if self._params:
+            request.update(self._params)

         if self.transform_input_fn:
             request = self.transform_input_fn(**request)
```
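To summarize the new dispatch in `post()`: external and foundation-model endpoints are queried directly through the deployments client and their responses are reduced according to the endpoint task, while classic serving endpoints keep the `dataframe_records` wrapping. A sketch of the two response shapes the transform helpers expect (values invented for illustration):

```python
# Illustrative response shapes for _transform_chat / _transform_completions.
chat_resp = {"choices": [{"message": {"role": "assistant", "content": "hi"}}]}
completions_resp = {"choices": [{"text": "hi"}]}

assert chat_resp["choices"][0]["message"]["content"] == "hi"   # _transform_chat
assert completions_resp["choices"][0]["text"] == "hi"          # _transform_completions
```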
`libs/langchain/langchain/llms/mlflow.py` (new file, 104 lines):

```python
from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional
from urllib.parse import urlparse

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM
from langchain_core.pydantic_v1 import BaseModel, Extra, PrivateAttr


# Ignoring type because below is valid pydantic code
# Unexpected keyword argument "extra" for "__init_subclass__" of "object"
class Params(BaseModel, extra=Extra.allow):  # type: ignore[call-arg]
    """Parameters for MLflow"""

    temperature: float = 0.0
    n: int = 1
    stop: Optional[List[str]] = None
    max_tokens: Optional[int] = None


class Mlflow(LLM):
    """Wrapper around completions LLMs in MLflow.

    To use, you should have the `mlflow[genai]` python package installed.
    For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.

    Example:
        .. code-block:: python

            from langchain.llms import Mlflow

            completions = Mlflow(
                target_uri="http://localhost:5000",
                endpoint="test",
                params={"temperature": 0.1}
            )
    """

    endpoint: str
    """The endpoint to use."""
    target_uri: str
    """The target URI to use."""
    params: Optional[Params] = None
    """Extra parameters such as `temperature`."""
    _client: Any = PrivateAttr()

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        self._validate_uri()
        try:
            from mlflow.deployments import get_deploy_client

            self._client = get_deploy_client(self.target_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                "Please run `pip install mlflow[genai]` to install "
                "required dependencies."
            ) from e

    def _validate_uri(self) -> None:
        if self.target_uri == "databricks":
            return
        allowed = ["http", "https", "databricks"]
        if urlparse(self.target_uri).scheme not in allowed:
            raise ValueError(
                f"Invalid target URI: {self.target_uri}. "
                f"The scheme must be one of {allowed}."
            )

    @property
    def _default_params(self) -> Dict[str, Any]:
        params: Dict[str, Any] = {
            "target_uri": self.target_uri,
            "endpoint": self.endpoint,
        }
        if self.params:
            params["params"] = self.params.dict()
        return params

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return self._default_params

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        data: Dict[str, Any] = {
            "prompt": prompt,
            **(self.params.dict() if self.params else {}),
        }
        if s := (stop or (self.params.stop if self.params else None)):
            data["stop"] = s
        resp = self._client.predict(endpoint=self.endpoint, inputs=data)
        return resp["choices"][0]["text"]

    @property
    def _llm_type(self) -> str:
        return "mlflow"
```
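One behavior worth noting in `_call`: an explicit `stop` argument takes precedence over `params.stop`, and `stop` is only sent when non-empty. A minimal sketch of that precedence (not part of the diff; requires the new module to be importable):

```python
# Sketch of the stop-sequence precedence in Mlflow._call (illustrative).
from langchain.llms.mlflow import Params

params = Params(temperature=0.1, stop=["\n"])


def resolve_stop(call_stop, p):
    # Mirrors: s := (stop or (self.params.stop if self.params else None))
    return call_stop or (p.stop if p else None)


assert resolve_stop(None, params) == ["\n"]      # falls back to params.stop
assert resolve_stop(["END"], params) == ["END"]  # a per-call stop wins
assert resolve_stop(None, None) is None          # nothing is sent
```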
Changes to `libs/langchain/langchain/llms/mlflow_ai_gateway.py`:

```diff
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import warnings
 from typing import Any, Dict, List, Mapping, Optional

 from langchain_core.pydantic_v1 import BaseModel, Extra
@@ -46,6 +47,10 @@ class MlflowAIGateway(LLM):
     params: Optional[Params] = None

     def __init__(self, **kwargs: Any):
+        warnings.warn(
+            "`MlflowAIGateway` is deprecated. Use `Mlflow` or `Databricks` instead.",
+            DeprecationWarning,
+        )
         try:
             import mlflow.gateway
         except ImportError as e:
```
Unit tests for the public import surface (`EXPECTED_ALL` lists for chat models and embeddings):

```diff
@@ -9,7 +9,9 @@ EXPECTED_ALL = [
     "ChatEverlyAI",
     "ChatAnthropic",
     "ChatCohere",
+    "ChatDatabricks",
     "ChatGooglePalm",
+    "ChatMlflow",
     "ChatMLflowAIGateway",
     "ChatOllama",
     "ChatVertexAI",
@@ -6,6 +6,7 @@ EXPECTED_ALL = [
     "CacheBackedEmbeddings",
     "ClarifaiEmbeddings",
     "CohereEmbeddings",
+    "DatabricksEmbeddings",
     "ElasticsearchEmbeddings",
     "FastEmbedEmbeddings",
     "HuggingFaceEmbeddings",
@@ -16,6 +17,7 @@ EXPECTED_ALL = [
     "LlamaCppEmbeddings",
     "HuggingFaceHubEmbeddings",
     "MlflowAIGatewayEmbeddings",
+    "MlflowEmbeddings",
     "ModelScopeEmbeddings",
     "TensorflowHubEmbeddings",
     "SagemakerEndpointEmbeddings",
```