Mirror of https://github.com/hwchase17/langchain.git, synced 2025-10-26 13:21:40 +00:00.
			
		
		
		
	- Any direct use of `ThreadPoolExecutor` or `asyncio.run_in_executor` requires manual handling of context variables.
		
			
				
	
	
		
			205 lines
		
	
	
		
			6.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			205 lines
		
	
	
		
			6.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import logging
 | |
| from typing import Any, Dict, List, Mapping, Optional
 | |
| from urllib.parse import urlparse
 | |
| 
 | |
| from langchain_core.callbacks import (
 | |
|     CallbackManagerForLLMRun,
 | |
| )
 | |
| from langchain_core.language_models import BaseChatModel
 | |
| from langchain_core.messages import (
 | |
|     AIMessage,
 | |
|     BaseMessage,
 | |
|     ChatMessage,
 | |
|     FunctionMessage,
 | |
|     HumanMessage,
 | |
|     SystemMessage,
 | |
| )
 | |
| from langchain_core.outputs import ChatGeneration, ChatResult
 | |
| from langchain_core.pydantic_v1 import (
 | |
|     Field,
 | |
|     PrivateAttr,
 | |
| )
 | |
| 
 | |
# Module-level logger. ChatMlflow uses it to warn when a message's
# ``additional_kwargs`` are ignored during conversion.
logger = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
class ChatMlflow(BaseChatModel):
    """`MLflow` chat models API.

    To use, you should have the `mlflow[genai]` python package installed.
    For more information, see https://mlflow.org/docs/latest/llms/deployments/server.html.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatMlflow

            chat = ChatMlflow(
                target_uri="http://localhost:5000",
                endpoint="chat",
                temperature=0.1,
            )
    """

    endpoint: str
    """The endpoint to use."""
    target_uri: str
    """The target URI to use."""
    temperature: float = 0.0
    """The sampling temperature."""
    n: int = 1
    """The number of completion choices to generate."""
    stop: Optional[List[str]] = None
    """Default stop sequences; a non-empty ``stop`` passed at call time wins."""
    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate."""
    extra_params: dict = Field(default_factory=dict)
    """Any extra parameters to pass to the endpoint."""
    _client: Any = PrivateAttr()

    def __init__(self, **kwargs: Any):
        """Validate ``target_uri`` and create the MLflow deployments client.

        Raises:
            ValueError: If ``target_uri`` has an unsupported scheme.
            ImportError: If ``mlflow.deployments`` cannot be imported.
        """
        super().__init__(**kwargs)
        self._validate_uri()
        try:
            from mlflow.deployments import get_deploy_client

            self._client = get_deploy_client(self.target_uri)
        except ImportError as e:
            raise ImportError(
                "Failed to create the client. "
                f"Please run `pip install mlflow{self._mlflow_extras}` to install "
                "required dependencies."
            ) from e

    @property
    def _mlflow_extras(self) -> str:
        """The pip extras suggested in the install hint of ``__init__``."""
        return "[genai]"

    def _validate_uri(self) -> None:
        """Ensure ``target_uri`` is ``"databricks"`` or has an allowed scheme.

        Raises:
            ValueError: If the URI scheme is not http, https or databricks.
        """
        if self.target_uri == "databricks":
            return
        allowed = ["http", "https", "databricks"]
        if urlparse(self.target_uri).scheme not in allowed:
            raise ValueError(
                f"Invalid target URI: {self.target_uri}. "
                f"The scheme must be one of {allowed}."
            )

    @property
    def _default_params(self) -> Dict[str, Any]:
        """The configured model parameters, as reported to callbacks."""
        params: Dict[str, Any] = {
            "target_uri": self.target_uri,
            "endpoint": self.endpoint,
            "temperature": self.temperature,
            "n": self.n,
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "extra_params": self.extra_params,
        }
        return params

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Send ``messages`` to the MLflow endpoint and parse the response.

        Request payload precedence (lowest to highest): the ``temperature`` and
        ``n`` fields, ``extra_params``, the ``max_tokens`` field, then call-time
        ``kwargs``. A non-empty call-time ``stop`` overrides the ``stop`` field.

        Raises:
            ValueError: If a message cannot be converted (see
                ``_convert_message_to_dict``).
        """
        message_dicts = [
            ChatMlflow._convert_message_to_dict(message) for message in messages
        ]
        data: Dict[str, Any] = {
            "messages": message_dicts,
            "temperature": self.temperature,
            "n": self.n,
            **self.extra_params,
        }
        if self.max_tokens is not None:
            data["max_tokens"] = self.max_tokens
        # Call-time kwargs (e.g. from ``.bind(max_tokens=...)``) must be able to
        # override every configured default, including ``max_tokens``.
        data.update(kwargs)
        # A stop list supplied for this call takes precedence over the default.
        if stop := stop or self.stop:
            data["stop"] = stop
        resp = self._client.predict(endpoint=self.endpoint, inputs=data)
        return ChatMlflow._create_chat_result(resp)

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Same as ``_default_params``."""
        return self._default_params

    def _get_invocation_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> Dict[str, Any]:
        """Get the parameters used to invoke the model FOR THE CALLBACKS."""
        return {
            **self._default_params,
            **super()._get_invocation_params(stop=stop, **kwargs),
        }

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "mlflow-chat"

    @staticmethod
    def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
        """Convert a response message dict (``role``/``content``) to a message.

        ``user``, ``assistant`` and ``system`` map to the matching LangChain
        message classes; any other role becomes a ``ChatMessage``.
        """
        role = _dict["role"]
        content = _dict["content"]
        if role == "user":
            return HumanMessage(content=content)
        elif role == "assistant":
            return AIMessage(content=content)
        elif role == "system":
            return SystemMessage(content=content)
        else:
            return ChatMessage(content=content, role=role)

    @staticmethod
    def _raise_functions_not_supported() -> None:
        """Always raise: function messages/calls are not supported."""
        raise ValueError(
            "Function messages are not supported by Databricks. Please"
            " create a feature request at https://github.com/mlflow/mlflow/issues."
        )

    @staticmethod
    def _convert_message_to_dict(message: BaseMessage) -> dict:
        """Convert a LangChain message to a ``{"role", "content"}`` dict.

        Any ``additional_kwargs`` are dropped and a warning is logged.

        Raises:
            ValueError: For ``FunctionMessage`` instances, for messages whose
                ``additional_kwargs`` contain ``function_call``, and for
                unknown message types.
        """
        if isinstance(message, FunctionMessage):
            ChatMlflow._raise_functions_not_supported()

        role: str
        if isinstance(message, ChatMessage):
            role = message.role
        elif isinstance(message, HumanMessage):
            role = "user"
        elif isinstance(message, AIMessage):
            role = "assistant"
        elif isinstance(message, SystemMessage):
            role = "system"
        else:
            raise ValueError(f"Got unknown message type: {message}")
        message_dict = {"role": role, "content": message.content}

        if "function_call" in message.additional_kwargs:
            ChatMlflow._raise_functions_not_supported()
        if message.additional_kwargs:
            logger.warning(
                "Additional message arguments are unsupported by Databricks"
                " and will be ignored: %s",
                message.additional_kwargs,
            )
        return message_dict

    @staticmethod
    def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
        """Build a ``ChatResult`` from an endpoint response.

        Each entry of ``response["choices"]`` becomes one generation. Its
        ``generation_info`` is the choice's own ``usage`` dict, or ``{}``. The
        top-level ``usage`` (or ``{}``) becomes ``llm_output``.
        """
        generations = []
        for choice in response["choices"]:
            message = ChatMlflow._convert_dict_to_message(choice["message"])
            usage = choice.get("usage", {})
            gen = ChatGeneration(
                message=message,
                generation_info=usage,
            )
            generations.append(gen)

        usage = response.get("usage", {})
        return ChatResult(generations=generations, llm_output=usage)
 |