mistralai[patch]: persist async client (#15786)

This commit is contained in:
Erick Friis 2024-01-09 16:21:39 -08:00 committed by GitHub
parent 3e0cd11f51
commit 323941a90a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -42,8 +42,8 @@ from langchain_core.outputs import (
ChatGenerationChunk, ChatGenerationChunk,
ChatResult, ChatResult,
) )
from langchain_core.pydantic_v1 import root_validator from langchain_core.pydantic_v1 import SecretStr, root_validator
from langchain_core.utils import get_from_dict_or_env from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
# TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker. # TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker.
from mistralai.async_client import MistralAsyncClient # type: ignore[import] from mistralai.async_client import MistralAsyncClient # type: ignore[import]
@ -111,18 +111,11 @@ async def acompletion_with_retry(
@retry_decorator @retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any: async def _completion_with_retry(**kwargs: Any) -> Any:
client = MistralAsyncClient(
api_key=llm.mistral_api_key,
endpoint=llm.endpoint,
max_retries=llm.max_retries,
timeout=llm.timeout,
max_concurrent_requests=llm.max_concurrent_requests,
)
stream = kwargs.pop("stream", False) stream = kwargs.pop("stream", False)
if stream: if stream:
return client.chat_stream(**kwargs) return llm.async_client.chat_stream(**kwargs)
else: else:
return await client.chat(**kwargs) return await llm.async_client.chat(**kwargs)
return await _completion_with_retry(**kwargs) return await _completion_with_retry(**kwargs)
@ -163,8 +156,9 @@ def _convert_message_to_mistral_chat_message(
class ChatMistralAI(BaseChatModel): class ChatMistralAI(BaseChatModel):
"""A chat model that uses the MistralAI API.""" """A chat model that uses the MistralAI API."""
client: Any #: :meta private: client: MistralClient = None #: :meta private:
mistral_api_key: Optional[str] = None async_client: MistralAsyncClient = None #: :meta private:
mistral_api_key: Optional[SecretStr] = None
endpoint: str = DEFAULT_MISTRAL_ENDPOINT endpoint: str = DEFAULT_MISTRAL_ENDPOINT
max_retries: int = 5 max_retries: int = 5
timeout: int = 120 timeout: int = 120
@ -224,15 +218,24 @@ class ChatMistralAI(BaseChatModel):
"Please install it with `pip install mistralai`" "Please install it with `pip install mistralai`"
) )
values["mistral_api_key"] = get_from_dict_or_env( values["mistral_api_key"] = convert_to_secret_str(
values, "mistral_api_key", "MISTRAL_API_KEY", default="" get_from_dict_or_env(
values, "mistral_api_key", "MISTRAL_API_KEY", default=""
)
) )
values["client"] = MistralClient( values["client"] = MistralClient(
api_key=values["mistral_api_key"], api_key=values["mistral_api_key"].get_secret_value(),
endpoint=values["endpoint"], endpoint=values["endpoint"],
max_retries=values["max_retries"], max_retries=values["max_retries"],
timeout=values["timeout"], timeout=values["timeout"],
) )
values["async_client"] = MistralAsyncClient(
api_key=values["mistral_api_key"].get_secret_value(),
endpoint=values["endpoint"],
max_retries=values["max_retries"],
timeout=values["timeout"],
max_concurrent_requests=values["max_concurrent_requests"],
)
if values["temperature"] is not None and not 0 <= values["temperature"] <= 1: if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
raise ValueError("temperature must be in the range [0.0, 1.0]") raise ValueError("temperature must be in the range [0.0, 1.0]")