mistralai[patch]: persist async client (#15786)
parent 3e0cd11f51
commit 323941a90a
@@ -42,8 +42,8 @@ from langchain_core.outputs import (
     ChatGenerationChunk,
     ChatResult,
 )
-from langchain_core.pydantic_v1 import root_validator
-from langchain_core.utils import get_from_dict_or_env
+from langchain_core.pydantic_v1 import SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
 
 # TODO: Remove 'type: ignore' once mistralai has stubs or py.typed marker.
 from mistralai.async_client import MistralAsyncClient  # type: ignore[import]
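The two new imports support masking the API key: `convert_to_secret_str` wraps a plain string in pydantic's `SecretStr`, so the key no longer appears in reprs or logs. A minimal sketch of that behavior, independent of this package (the key value here is made up):

```python
from langchain_core.pydantic_v1 import SecretStr

# SecretStr masks the value in str()/repr() output; the raw key is only
# reachable through an explicit get_secret_value() call.
key = SecretStr("sk-example")
print(key)                     # **********
print(key.get_secret_value())  # sk-example
```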
@@ -111,18 +111,11 @@ async def acompletion_with_retry(
 
     @retry_decorator
     async def _completion_with_retry(**kwargs: Any) -> Any:
-        client = MistralAsyncClient(
-            api_key=llm.mistral_api_key,
-            endpoint=llm.endpoint,
-            max_retries=llm.max_retries,
-            timeout=llm.timeout,
-            max_concurrent_requests=llm.max_concurrent_requests,
-        )
         stream = kwargs.pop("stream", False)
         if stream:
-            return client.chat_stream(**kwargs)
+            return llm.async_client.chat_stream(**kwargs)
         else:
-            return await client.chat(**kwargs)
+            return await llm.async_client.chat(**kwargs)
 
     return await _completion_with_retry(**kwargs)
 
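This hunk is the core of the patch: `_completion_with_retry` previously constructed a fresh `MistralAsyncClient`, and with it a fresh connection pool, on every async call; it now reuses the client persisted on the model. A self-contained sketch of that reuse pattern, with a hypothetical `FakeAsyncClient` standing in for the real client:

```python
import asyncio


class FakeAsyncClient:
    """Hypothetical stand-in for MistralAsyncClient, for illustration only."""

    def __init__(self) -> None:
        print("opening connection pool")  # expensive setup, should run once

    async def chat(self, **kwargs: object) -> dict:
        return dict(kwargs)


class Model:
    def __init__(self) -> None:
        # Persist one client on the instance instead of building one per call.
        self.async_client = FakeAsyncClient()

    async def acall(self, **kwargs: object) -> dict:
        return await self.async_client.chat(**kwargs)


async def main() -> None:
    model = Model()
    # "opening connection pool" prints once, not once per request.
    await asyncio.gather(model.acall(n=1), model.acall(n=2))


asyncio.run(main())
```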
@@ -163,8 +156,9 @@ def _convert_message_to_mistral_chat_message(
 class ChatMistralAI(BaseChatModel):
     """A chat model that uses the MistralAI API."""
 
-    client: Any  #: :meta private:
-    mistral_api_key: Optional[str] = None
+    client: MistralClient = None  #: :meta private:
+    async_client: MistralAsyncClient = None  #: :meta private:
+    mistral_api_key: Optional[SecretStr] = None
     endpoint: str = DEFAULT_MISTRAL_ENDPOINT
     max_retries: int = 5
     timeout: int = 120
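`ChatMistralAI` now declares typed `client` and `async_client` fields, populated once by the validator in the next hunk, and stores the key as `Optional[SecretStr]`. Assuming `langchain-mistralai` is installed, the user-visible effect is roughly:

```python
from langchain_mistralai.chat_models import ChatMistralAI

llm = ChatMistralAI(mistral_api_key="my-secret-key")
# The raw string is wrapped in SecretStr, so reprs and logs stay masked.
print(llm.mistral_api_key)                     # **********
print(llm.mistral_api_key.get_secret_value())  # my-secret-key
```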
@@ -224,15 +218,24 @@ class ChatMistralAI(BaseChatModel):
                 "Please install it with `pip install mistralai`"
             )
 
-        values["mistral_api_key"] = get_from_dict_or_env(
-            values, "mistral_api_key", "MISTRAL_API_KEY", default=""
-        )
+        values["mistral_api_key"] = convert_to_secret_str(
+            get_from_dict_or_env(
+                values, "mistral_api_key", "MISTRAL_API_KEY", default=""
+            )
+        )
         values["client"] = MistralClient(
-            api_key=values["mistral_api_key"],
+            api_key=values["mistral_api_key"].get_secret_value(),
             endpoint=values["endpoint"],
             max_retries=values["max_retries"],
             timeout=values["timeout"],
         )
+        values["async_client"] = MistralAsyncClient(
+            api_key=values["mistral_api_key"].get_secret_value(),
+            endpoint=values["endpoint"],
+            max_retries=values["max_retries"],
+            timeout=values["timeout"],
+            max_concurrent_requests=values["max_concurrent_requests"],
+        )
 
         if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
             raise ValueError("temperature must be in the range [0.0, 1.0]")
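The validator now builds both the sync and the async client a single time, at model construction, from the same secret-wrapped key. A self-contained sketch of that build-once pattern, with the hypothetical names `ClientHolder` and `MY_API_KEY` and tuples standing in for the real clients:

```python
from typing import Any, Dict, Optional

from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env


class ClientHolder(BaseModel):
    """Toy model mirroring the build-both-clients-once validator pattern."""

    api_key: Optional[SecretStr] = None
    client: Any = None        # stands in for MistralClient
    async_client: Any = None  # stands in for MistralAsyncClient

    @root_validator()
    def build_clients(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Resolve the key from kwargs or the environment, then wrap it.
        values["api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "api_key", "MY_API_KEY", default="")
        )
        raw_key = values["api_key"].get_secret_value()
        # Both clients are constructed here, once, and reused for every call.
        values["client"] = ("sync-client", raw_key)
        values["async_client"] = ("async-client", raw_key)
        return values


holder = ClientHolder(api_key="sk-demo")
print(holder.api_key)       # **********
print(holder.async_client)  # ('async-client', 'sk-demo')
```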