diff --git a/libs/partners/mistralai/langchain_mistralai/chat_models.py b/libs/partners/mistralai/langchain_mistralai/chat_models.py index 71d9ef79841..bf5ca6b8673 100644 --- a/libs/partners/mistralai/langchain_mistralai/chat_models.py +++ b/libs/partners/mistralai/langchain_mistralai/chat_models.py @@ -42,8 +42,8 @@ from langchain_core.outputs import ( ChatGenerationChunk, ChatResult, ) -from langchain_core.pydantic_v1 import root_validator -from langchain_core.utils import get_from_dict_or_env +from langchain_core.pydantic_v1 import SecretStr, root_validator +from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env # TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker. from mistralai.async_client import MistralAsyncClient # type: ignore[import] @@ -111,18 +111,11 @@ async def acompletion_with_retry( @retry_decorator async def _completion_with_retry(**kwargs: Any) -> Any: - client = MistralAsyncClient( - api_key=llm.mistral_api_key, - endpoint=llm.endpoint, - max_retries=llm.max_retries, - timeout=llm.timeout, - max_concurrent_requests=llm.max_concurrent_requests, - ) stream = kwargs.pop("stream", False) if stream: - return client.chat_stream(**kwargs) + return llm.async_client.chat_stream(**kwargs) else: - return await client.chat(**kwargs) + return await llm.async_client.chat(**kwargs) return await _completion_with_retry(**kwargs) @@ -163,8 +156,9 @@ def _convert_message_to_mistral_chat_message( class ChatMistralAI(BaseChatModel): """A chat model that uses the MistralAI API.""" - client: Any #: :meta private: - mistral_api_key: Optional[str] = None + client: MistralClient = None #: :meta private: + async_client: MistralAsyncClient = None #: :meta private: + mistral_api_key: Optional[SecretStr] = None endpoint: str = DEFAULT_MISTRAL_ENDPOINT max_retries: int = 5 timeout: int = 120 @@ -224,15 +218,24 @@ class ChatMistralAI(BaseChatModel): "Please install it with `pip install mistralai`" ) - values["mistral_api_key"] = get_from_dict_or_env( - values, "mistral_api_key", "MISTRAL_API_KEY", default="" + values["mistral_api_key"] = convert_to_secret_str( + get_from_dict_or_env( + values, "mistral_api_key", "MISTRAL_API_KEY", default="" + ) ) values["client"] = MistralClient( - api_key=values["mistral_api_key"], + api_key=values["mistral_api_key"].get_secret_value(), endpoint=values["endpoint"], max_retries=values["max_retries"], timeout=values["timeout"], ) + values["async_client"] = MistralAsyncClient( + api_key=values["mistral_api_key"].get_secret_value(), + endpoint=values["endpoint"], + max_retries=values["max_retries"], + timeout=values["timeout"], + max_concurrent_requests=values["max_concurrent_requests"], + ) if values["temperature"] is not None and not 0 <= values["temperature"] <= 1: raise ValueError("temperature must be in the range [0.0, 1.0]")