mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
@@ -13,8 +13,6 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
@@ -93,9 +91,7 @@ global_ssl_context = ssl.create_default_context(cafile=certifi.where())
|
||||
|
||||
def _create_retry_decorator(
|
||||
llm: ChatMistralAI,
|
||||
run_manager: Optional[
|
||||
Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
|
||||
] = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | CallbackManagerForLLMRun | None = None,
|
||||
) -> Callable[[Any], Any]:
|
||||
"""Return a tenacity retry decorator, preconfigured to handle exceptions."""
|
||||
errors = [httpx.RequestError, httpx.StreamError]
|
||||
@@ -211,7 +207,7 @@ async def _aiter_sse(
|
||||
|
||||
async def acompletion_with_retry(
|
||||
llm: ChatMistralAI,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the async completion call."""
|
||||
@@ -397,23 +393,23 @@ class ChatMistralAI(BaseChatModel):
|
||||
async_client: httpx.AsyncClient = Field( # type: ignore[assignment] # : meta private:
|
||||
default=None, exclude=True
|
||||
) #: :meta private:
|
||||
mistral_api_key: Optional[SecretStr] = Field(
|
||||
mistral_api_key: SecretStr | None = Field(
|
||||
alias="api_key",
|
||||
default_factory=secret_from_env("MISTRAL_API_KEY", default=None),
|
||||
)
|
||||
endpoint: Optional[str] = Field(default=None, alias="base_url")
|
||||
endpoint: str | None = Field(default=None, alias="base_url")
|
||||
max_retries: int = 5
|
||||
timeout: int = 120
|
||||
max_concurrent_requests: int = 64
|
||||
model: str = Field(default="mistral-small", alias="model_name")
|
||||
temperature: float = 0.7
|
||||
max_tokens: Optional[int] = None
|
||||
max_tokens: int | None = None
|
||||
top_p: float = 1
|
||||
"""Decode using nucleus sampling: consider the smallest set of tokens whose
|
||||
probability sum is at least ``top_p``. Must be in the closed interval
|
||||
``[0.0, 1.0]``."""
|
||||
random_seed: Optional[int] = None
|
||||
safe_mode: Optional[bool] = None
|
||||
random_seed: int | None = None
|
||||
safe_mode: bool | None = None
|
||||
streaming: bool = False
|
||||
model_kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
"""Holds any invocation parameters not explicitly specified."""
|
||||
@@ -445,7 +441,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
return {k: v for k, v in defaults.items() if v is not None}
|
||||
|
||||
def _get_ls_params(
|
||||
self, stop: Optional[list[str]] = None, **kwargs: Any
|
||||
self, stop: list[str] | None = None, **kwargs: Any
|
||||
) -> LangSmithParams:
|
||||
"""Get standard params for tracing."""
|
||||
params = self._get_invocation_params(stop=stop, **kwargs)
|
||||
@@ -467,7 +463,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
return self._default_params
|
||||
|
||||
def completion_with_retry(
|
||||
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
|
||||
self, run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Use tenacity to retry the completion call."""
|
||||
retry_decorator = _create_retry_decorator(self, run_manager=run_manager)
|
||||
@@ -496,7 +492,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
return _completion_with_retry(**kwargs)
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
|
||||
def _combine_llm_outputs(self, llm_outputs: list[dict | None]) -> dict:
|
||||
overall_token_usage: dict = {}
|
||||
for output in llm_outputs:
|
||||
if output is None:
|
||||
@@ -515,7 +511,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
def validate_environment(self) -> Self:
|
||||
"""Validate api key, python package exists, temperature, and top_p."""
|
||||
if isinstance(self.mistral_api_key, SecretStr):
|
||||
api_key_str: Optional[str] = self.mistral_api_key.get_secret_value()
|
||||
api_key_str: str | None = self.mistral_api_key.get_secret_value()
|
||||
else:
|
||||
api_key_str = self.mistral_api_key
|
||||
|
||||
@@ -563,9 +559,9 @@ class ChatMistralAI(BaseChatModel):
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None, # noqa: FBT001
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
stream: bool | None = None, # noqa: FBT001
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
@@ -608,7 +604,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: list[BaseMessage], stop: Optional[list[str]]
|
||||
self, messages: list[BaseMessage], stop: list[str] | None
|
||||
) -> tuple[list[dict], dict[str, Any]]:
|
||||
params = self._client_params
|
||||
if stop is not None or "stop" in params:
|
||||
@@ -623,8 +619,8 @@ class ChatMistralAI(BaseChatModel):
|
||||
def _stream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
stop: list[str] | None = None,
|
||||
run_manager: CallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
@@ -649,8 +645,8 @@ class ChatMistralAI(BaseChatModel):
|
||||
async def _astream(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||
@@ -675,9 +671,9 @@ class ChatMistralAI(BaseChatModel):
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
stop: Optional[list[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None, # noqa: FBT001
|
||||
stop: list[str] | None = None,
|
||||
run_manager: AsyncCallbackManagerForLLMRun | None = None,
|
||||
stream: bool | None = None, # noqa: FBT001
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
should_stream = stream if stream is not None else self.streaming
|
||||
@@ -696,8 +692,8 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
|
||||
tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None, # noqa: PYI051
|
||||
tools: Sequence[dict[str, Any] | type | Callable | BaseTool],
|
||||
tool_choice: dict | str | Literal["auto", "any"] | None = None, # noqa: PYI051
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, AIMessage]:
|
||||
"""Bind tool-like objects to this chat model.
|
||||
@@ -710,7 +706,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
`langchain_core.utils.function_calling.convert_to_openai_tool`.
|
||||
tool_choice: Which tool to require the model to call.
|
||||
Must be the name of the single provided function or
|
||||
``'auto'`` to automatically determine which function to call
|
||||
`'auto'` to automatically determine which function to call
|
||||
(if any), or a dict of the form:
|
||||
{"type": "function", "function": {"name": <<tool_name>>}}.
|
||||
kwargs: Any additional parameters are passed directly to
|
||||
@@ -738,14 +734,14 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
def with_structured_output(
|
||||
self,
|
||||
schema: Optional[Union[dict, type]] = None,
|
||||
schema: dict | type | None = None,
|
||||
*,
|
||||
method: Literal[
|
||||
"function_calling", "json_mode", "json_schema"
|
||||
] = "function_calling",
|
||||
include_raw: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
|
||||
) -> Runnable[LanguageModelInput, dict | BaseModel]:
|
||||
r"""Model wrapper that returns outputs formatted to match the given schema.
|
||||
|
||||
Args:
|
||||
@@ -1085,7 +1081,7 @@ class ChatMistralAI(BaseChatModel):
|
||||
|
||||
|
||||
def _convert_to_openai_response_format(
|
||||
schema: Union[dict[str, Any], type], *, strict: Optional[bool] = None
|
||||
schema: dict[str, Any] | type, *, strict: bool | None = None
|
||||
) -> dict:
|
||||
"""Perform same op as in ChatOpenAI, but do not pass through Pydantic BaseModels."""
|
||||
if (
|
||||
|
||||
Reference in New Issue
Block a user