Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-13 08:27:03 +00:00
mistralai[patch]: ruff fixes and rules (#31918)
* bump ruff deps
* add more thorough ruff rules
* fix said rules
This commit is contained in: parent ae210c1590, commit cbb418b4bf
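One mechanical pattern repeats throughout the diff below: ruff's EM rules (flake8-errmsg) want the exception message bound to a name before raising, so tracebacks show raise ValueError(msg) rather than a duplicated f-string. A minimal sketch of the before/after shape (the check_role helper is hypothetical, mirroring the hunks that follow):

    def check_role(role: str) -> None:
        # Before (flagged by EM102):
        #     raise ValueError(f"Expected role to be 'assistant', got {role}")
        # After: bind the message first, then raise it.
        if role != "assistant":
            msg = f"Expected role to be 'assistant', got {role}"
            raise ValueError(msg)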
@@ -94,8 +94,7 @@ def _create_retry_decorator(
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
 ) -> Callable[[Any], Any]:
-    """Returns a tenacity retry decorator, preconfigured to handle exceptions"""
-
+    """Return a tenacity retry decorator, preconfigured to handle exceptions."""
     errors = [httpx.RequestError, httpx.StreamError]
     return create_base_retry_decorator(
         error_types=errors, max_retries=llm.max_retries, run_manager=run_manager
@@ -103,12 +102,12 @@ def _create_retry_decorator(


 def _is_valid_mistral_tool_call_id(tool_call_id: str) -> bool:
-    """Check if tool call ID is nine character string consisting of a-z, A-Z, 0-9"""
+    """Check if tool call ID is nine character string consisting of a-z, A-Z, 0-9."""
     return bool(TOOL_CALL_ID_PATTERN.match(tool_call_id))


 def _base62_encode(num: int) -> str:
-    """Encodes a number in base62 and ensures result is of a specified length."""
+    """Encode a number in base62 and ensures result is of a specified length."""
     base62 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
     if num == 0:
         return base62[0]
@@ -122,17 +121,15 @@ def _base62_encode(num: int) -> str:


 def _convert_tool_call_id_to_mistral_compatible(tool_call_id: str) -> str:
-    """Convert a tool call ID to a Mistral-compatible format"""
+    """Convert a tool call ID to a Mistral-compatible format."""
     if _is_valid_mistral_tool_call_id(tool_call_id):
         return tool_call_id
-    else:
-        hash_bytes = hashlib.sha256(tool_call_id.encode()).digest()
-        hash_int = int.from_bytes(hash_bytes, byteorder="big")
-        base62_str = _base62_encode(hash_int)
-        if len(base62_str) >= 9:
-            return base62_str[:9]
-        else:
-            return base62_str.rjust(9, "0")
+    hash_bytes = hashlib.sha256(tool_call_id.encode()).digest()
+    hash_int = int.from_bytes(hash_bytes, byteorder="big")
+    base62_str = _base62_encode(hash_int)
+    if len(base62_str) >= 9:
+        return base62_str[:9]
+    return base62_str.rjust(9, "0")


 def _convert_mistral_chat_message_to_message(
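The hunk above flattens the conversion to nine-character Mistral tool call IDs: hash the incoming ID with SHA-256, base62-encode the digest, then truncate or zero-pad to nine characters. A standalone sketch of the same idea (TOOL_CALL_ID_PATTERN is assumed from the validator's docstring above, and to_mistral_id is a hypothetical name):

    import hashlib
    import re

    TOOL_CALL_ID_PATTERN = re.compile(r"^[a-zA-Z0-9]{9}$")  # assumed pattern
    BASE62 = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

    def to_mistral_id(tool_call_id: str) -> str:
        if TOOL_CALL_ID_PATTERN.match(tool_call_id):
            return tool_call_id  # already valid, pass through unchanged
        num = int.from_bytes(hashlib.sha256(tool_call_id.encode()).digest(), "big")
        encoded = ""
        while num:
            num, rem = divmod(num, 62)
            encoded = BASE62[rem] + encoded
        return encoded[:9] if len(encoded) >= 9 else encoded.rjust(9, "0")

    print(to_mistral_id("call_abc123xyz"))  # same 9-char ID on every run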
@@ -140,7 +137,8 @@ def _convert_mistral_chat_message_to_message(
 ) -> BaseMessage:
     role = _message["role"]
     if role != "assistant":
-        raise ValueError(f"Expected role to be 'assistant', got {role}")
+        msg = f"Expected role to be 'assistant', got {role}"
+        raise ValueError(msg)
     content = cast(str, _message["content"])

     additional_kwargs: dict = {}
@@ -170,9 +168,12 @@ def _raise_on_error(response: httpx.Response) -> None:
     """Raise an error if the response is an error."""
     if httpx.codes.is_error(response.status_code):
         error_message = response.read().decode("utf-8")
-        raise httpx.HTTPStatusError(
+        msg = (
             f"Error response {response.status_code} "
-            f"while fetching {response.url}: {error_message}",
+            f"while fetching {response.url}: {error_message}"
+        )
+        raise httpx.HTTPStatusError(
+            msg,
             request=response.request,
             response=response,
         )
@@ -182,9 +183,12 @@ async def _araise_on_error(response: httpx.Response) -> None:
     """Raise an error if the response is an error."""
     if httpx.codes.is_error(response.status_code):
         error_message = (await response.aread()).decode("utf-8")
-        raise httpx.HTTPStatusError(
+        msg = (
             f"Error response {response.status_code} "
-            f"while fetching {response.url}: {error_message}",
+            f"while fetching {response.url}: {error_message}"
+        )
+        raise httpx.HTTPStatusError(
+            msg,
             request=response.request,
             response=response,
         )
@@ -220,10 +224,9 @@ async def acompletion_with_retry(
                 llm.async_client, "POST", "/chat/completions", json=kwargs
             )
             return _aiter_sse(event_source)
-        else:
-            response = await llm.async_client.post(url="/chat/completions", json=kwargs)
-            await _araise_on_error(response)
-            return response.json()
+        response = await llm.async_client.post(url="/chat/completions", json=kwargs)
+        await _araise_on_error(response)
+        return response.json()

     return await _completion_with_retry(**kwargs)

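The else: removals in this hunk (and several below) are ruff's RET505, superfluous-else-return: once a branch ends in return, the alternative path can simply be dedented. A tiny illustration:

    def fetch(stream: bool) -> str:
        if stream:
            return "streaming response"
        return "buffered response"  # no else: needed after a return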
@@ -237,7 +240,7 @@ def _convert_chunk_to_message_chunk(
     content = _delta.get("content") or ""
     if role == "user" or default_class == HumanMessageChunk:
         return HumanMessageChunk(content=content)
-    elif role == "assistant" or default_class == AIMessageChunk:
+    if role == "assistant" or default_class == AIMessageChunk:
         additional_kwargs: dict = {}
         response_metadata = {}
         if raw_tool_calls := _delta.get("tool_calls"):
@@ -281,12 +284,11 @@ def _convert_chunk_to_message_chunk(
             usage_metadata=usage_metadata,  # type: ignore[arg-type]
             response_metadata=response_metadata,
         )
-    elif role == "system" or default_class == SystemMessageChunk:
+    if role == "system" or default_class == SystemMessageChunk:
         return SystemMessageChunk(content=content)
-    elif role or default_class == ChatMessageChunk:
+    if role or default_class == ChatMessageChunk:
         return ChatMessageChunk(content=content, role=role)
-    else:
-        return default_class(content=content)  # type: ignore[call-arg]
+    return default_class(content=content)  # type: ignore[call-arg]


 def _format_tool_call_for_mistral(tool_call: ToolCall) -> dict:
@@ -321,18 +323,24 @@ def _convert_message_to_mistral_chat_message(
     message: BaseMessage,
 ) -> dict:
     if isinstance(message, ChatMessage):
-        return dict(role=message.role, content=message.content)
-    elif isinstance(message, HumanMessage):
-        return dict(role="user", content=message.content)
-    elif isinstance(message, AIMessage):
+        return {"role": message.role, "content": message.content}
+    if isinstance(message, HumanMessage):
+        return {"role": "user", "content": message.content}
+    if isinstance(message, AIMessage):
         message_dict: dict[str, Any] = {"role": "assistant"}
         tool_calls = []
         if message.tool_calls or message.invalid_tool_calls:
-            for tool_call in message.tool_calls:
-                tool_calls.append(_format_tool_call_for_mistral(tool_call))
+            tool_calls.extend(
+                [
+                    _format_tool_call_for_mistral(tool_call)
+                    for tool_call in message.tool_calls
+                ]
+            )
-            for invalid_tool_call in message.invalid_tool_calls:
-                tool_calls.append(
-                    _format_invalid_tool_call_for_mistral(invalid_tool_call)
-                )
+            tool_calls.extend(
+                _format_invalid_tool_call_for_mistral(invalid_tool_call)
+                for invalid_tool_call in message.invalid_tool_calls
+            )
         elif "tool_calls" in message.additional_kwargs:
             for tc in message.additional_kwargs["tool_calls"]:
@@ -359,9 +367,9 @@ def _convert_message_to_mistral_chat_message(
         if "prefix" in message.additional_kwargs:
             message_dict["prefix"] = message.additional_kwargs["prefix"]
         return message_dict
-    elif isinstance(message, SystemMessage):
-        return dict(role="system", content=message.content)
-    elif isinstance(message, ToolMessage):
+    if isinstance(message, SystemMessage):
+        return {"role": "system", "content": message.content}
+    if isinstance(message, ToolMessage):
         return {
             "role": "tool",
             "content": message.content,
@@ -370,8 +378,8 @@ def _convert_message_to_mistral_chat_message(
                 message.tool_call_id
             ),
         }
-    else:
-        raise ValueError(f"Got unknown type {message}")
+    msg = f"Got unknown type {message}"
+    raise ValueError(msg)


 class ChatMistralAI(BaseChatModel):
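The dict(...) to {...} rewrites in this function come from flake8-comprehensions (C408): a dict literal skips a builtin lookup and call, and reads the same. For example:

    legacy = dict(role="user", content="Hello")       # flagged by C408
    idiomatic = {"role": "user", "content": "Hello"}  # preferred literal
    assert legacy == idiomatic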
@@ -380,10 +388,10 @@ class ChatMistralAI(BaseChatModel):
     # The type for client and async_client is ignored because the type is not
     # an Optional after the model is initialized and the model_validator
     # is run.
-    client: httpx.Client = Field(  # type: ignore # : meta private:
+    client: httpx.Client = Field(  # type: ignore[assignment] # : meta private:
         default=None, exclude=True
     )
-    async_client: httpx.AsyncClient = Field(  # type: ignore # : meta private:
+    async_client: httpx.AsyncClient = Field(  # type: ignore[assignment] # : meta private:
         default=None, exclude=True
     )  #: :meta private:
     mistral_api_key: Optional[SecretStr] = Field(
@@ -417,8 +425,7 @@ class ChatMistralAI(BaseChatModel):
     def build_extra(cls, values: dict[str, Any]) -> Any:
         """Build extra kwargs from additional params that were passed in."""
         all_required_field_names = get_pydantic_field_names(cls)
-        values = _build_model_kwargs(values, all_required_field_names)
-        return values
+        return _build_model_kwargs(values, all_required_field_names)

     @property
     def _default_params(self) -> dict[str, Any]:
@@ -432,8 +439,7 @@ class ChatMistralAI(BaseChatModel):
             "safe_prompt": self.safe_mode,
             **self.model_kwargs,
         }
-        filtered = {k: v for k, v in defaults.items() if v is not None}
-        return filtered
+        return {k: v for k, v in defaults.items() if v is not None}

     def _get_ls_params(
         self, stop: Optional[list[str]] = None, **kwargs: Any
@@ -481,13 +487,11 @@ class ChatMistralAI(BaseChatModel):
                     yield event.json()

             return iter_sse()
-        else:
-            response = self.client.post(url="/chat/completions", json=kwargs)
-            _raise_on_error(response)
-            return response.json()
+        response = self.client.post(url="/chat/completions", json=kwargs)
+        _raise_on_error(response)
+        return response.json()

-        rtn = _completion_with_retry(**kwargs)
-        return rtn
+        return _completion_with_retry(**kwargs)

     def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
         overall_token_usage: dict = {}
@@ -502,8 +506,7 @@ class ChatMistralAI(BaseChatModel):
                     overall_token_usage[k] += v
                 else:
                     overall_token_usage[k] = v
-        combined = {"token_usage": overall_token_usage, "model_name": self.model}
-        return combined
+        return {"token_usage": overall_token_usage, "model_name": self.model}

     @model_validator(mode="after")
     def validate_environment(self) -> Self:
@@ -545,10 +548,12 @@ class ChatMistralAI(BaseChatModel):
             )

         if self.temperature is not None and not 0 <= self.temperature <= 1:
-            raise ValueError("temperature must be in the range [0.0, 1.0]")
+            msg = "temperature must be in the range [0.0, 1.0]"
+            raise ValueError(msg)

         if self.top_p is not None and not 0 <= self.top_p <= 1:
-            raise ValueError("top_p must be in the range [0.0, 1.0]")
+            msg = "top_p must be in the range [0.0, 1.0]"
+            raise ValueError(msg)

         return self
@@ -557,7 +562,7 @@ class ChatMistralAI(BaseChatModel):
         messages: list[BaseMessage],
         stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-        stream: Optional[bool] = None,
+        stream: Optional[bool] = None,  # noqa: FBT001
         **kwargs: Any,
     ) -> ChatResult:
         should_stream = stream if stream is not None else self.streaming
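The noqa: FBT001 markers here acknowledge flake8-boolean-trap: a positional boolean like stream makes call sites such as _generate(msgs, None, None, True) hard to read, but this parameter is part of the existing chat-model interface, so the diff suppresses the rule rather than break the API. A fresh API would make the flag keyword-only; a sketch with a hypothetical generate:

    def generate(messages: list, *, stream: bool = False) -> str:
        """Keyword-only boolean: callers must spell out generate(msgs, stream=True)."""
        return "streamed" if stream else "buffered"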
@@ -669,7 +674,7 @@ class ChatMistralAI(BaseChatModel):
         messages: list[BaseMessage],
         stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        stream: Optional[bool] = None,
+        stream: Optional[bool] = None,  # noqa: FBT001
         **kwargs: Any,
     ) -> ChatResult:
         should_stream = stream if stream is not None else self.streaming
@@ -689,7 +694,7 @@ class ChatMistralAI(BaseChatModel):
     def bind_tools(
         self,
         tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
-        tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None,
+        tool_choice: Optional[Union[dict, str, Literal["auto", "any"]]] = None,  # noqa: PYI051
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
         """Bind tool-like objects to this chat model.
@@ -707,15 +712,15 @@ class ChatMistralAI(BaseChatModel):
                 {"type": "function", "function": {"name": <<tool_name>>}}.
             kwargs: Any additional parameters are passed directly to
                 ``self.bind(**kwargs)``.
-        """

+        """
         formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
         if tool_choice:
             tool_names = []
             for tool in formatted_tools:
-                if "function" in tool and (name := tool["function"].get("name")):
-                    tool_names.append(name)
-                elif name := tool.get("name"):
+                if ("function" in tool and (name := tool["function"].get("name"))) or (
+                    name := tool.get("name")
+                ):
                     tool_names.append(name)
                 else:
                     pass
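The rewrite above merges an if/elif pair whose branches were identical: both arms bound name via the walrus operator and appended it, so one condition joined with or suffices. A runnable sketch of the merged shape (first_name is a hypothetical helper):

    from typing import Optional

    def first_name(tool: dict) -> Optional[str]:
        if ("function" in tool and (name := tool["function"].get("name"))) or (
            name := tool.get("name")
        ):
            return name
        return None

    print(first_name({"function": {"name": "search"}}))  # search
    print(first_name({"name": "lookup"}))                # lookup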
@@ -738,7 +743,7 @@ class ChatMistralAI(BaseChatModel):
         include_raw: bool = False,
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
-        """Model wrapper that returns outputs formatted to match the given schema.
+        r"""Model wrapper that returns outputs formatted to match the given schema.

         Args:
             schema:
@@ -785,6 +790,12 @@ class ChatMistralAI(BaseChatModel):
                 will be caught and returned as well. The final output is always a dict
                 with keys "raw", "parsed", and "parsing_error".

+            kwargs: Any additional parameters are passed directly to
+                ``self.bind(**kwargs)``. This is useful for passing in
+                parameters such as ``tool_choice`` or ``tools`` to control
+                which tool the model should call, or to pass in parameters such as
+                ``stop`` to control when the model should stop generating output.
+
         Returns:
             A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.

@@ -968,14 +979,16 @@ class ChatMistralAI(BaseChatModel):
         """  # noqa: E501
         _ = kwargs.pop("strict", None)
         if kwargs:
-            raise ValueError(f"Received unsupported arguments {kwargs}")
+            msg = f"Received unsupported arguments {kwargs}"
+            raise ValueError(msg)
         is_pydantic_schema = isinstance(schema, type) and is_basemodel_subclass(schema)
         if method == "function_calling":
             if schema is None:
-                raise ValueError(
+                msg = (
                     "schema must be specified when method is 'function_calling'. "
                     "Received None."
                 )
+                raise ValueError(msg)
             # TODO: Update to pass in tool name as tool_choice if/when Mistral supports
             # specifying a tool.
             llm = self.bind_tools(
@@ -1014,10 +1027,11 @@ class ChatMistralAI(BaseChatModel):
             )
         elif method == "json_schema":
             if schema is None:
-                raise ValueError(
+                msg = (
                     "schema must be specified when method is 'json_schema'. "
                     "Received None."
                 )
+                raise ValueError(msg)
             response_format = _convert_to_openai_response_format(schema, strict=True)
             llm = self.bind(
                 response_format=response_format,
@@ -1041,8 +1055,7 @@ class ChatMistralAI(BaseChatModel):
                 [parser_none], exception_key="parsing_error"
             )
             return RunnableMap(raw=llm) | parser_with_fallback
-        else:
-            return llm | output_parser
+        return llm | output_parser

     @property
     def _identifying_params(self) -> dict[str, Any]:
@@ -1072,7 +1085,7 @@ class ChatMistralAI(BaseChatModel):
 def _convert_to_openai_response_format(
     schema: Union[dict[str, Any], type], *, strict: Optional[bool] = None
 ) -> dict:
-    """Same as in ChatOpenAI, but don't pass through Pydantic BaseModels."""
+    """Perform same op as in ChatOpenAI, but do not pass through Pydantic BaseModels."""
     if (
         isinstance(schema, dict)
         and "json_schema" in schema
@@ -17,20 +17,20 @@ from pydantic import (
     model_validator,
 )
 from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
-from tokenizers import Tokenizer  # type: ignore
+from tokenizers import Tokenizer  # type: ignore[import]
 from typing_extensions import Self

 logger = logging.getLogger(__name__)

 MAX_TOKENS = 16_000
 """A batching parameter for the Mistral API. This is NOT the maximum number of tokens
 accepted by the embedding model for each document/chunk, but rather the maximum number
 of tokens that can be sent in a single request to the Mistral API (across multiple
 documents/chunks)"""


 class DummyTokenizer:
-    """Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)"""
+    """Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)."""

     @staticmethod
     def encode_batch(texts: list[str]) -> list[list[str]]:
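The MAX_TOKENS note in this hunk spells out the batching contract: the cap is per request, not per document. A minimal sketch of the greedy packing it implies, with get_batches as a hypothetical stand-in for the module's _get_batches and len() standing in for real token counts the way DummyTokenizer does:

    MAX_TOKENS = 16_000  # per-request cap described above

    def get_batches(texts: list[str]) -> list[list[str]]:
        batches: list[list[str]] = []
        batch: list[str] = []
        batch_tokens = 0
        for text in texts:
            tokens = len(text)  # dummy count; the real path tokenizes
            if batch and batch_tokens + tokens > MAX_TOKENS:
                batches.append(batch)  # flush before exceeding the cap
                batch, batch_tokens = [], 0
            batch.append(text)
            batch_tokens += tokens
        if batch:
            batches.append(batch)
        return batches

    print([len(b) for b in get_batches(["a" * 9_000, "b" * 9_000, "c" * 100])])  # [1, 2]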
@@ -126,9 +126,9 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
     # The type for client and async_client is ignored because the type is not
     # an Optional after the model is initialized and the model_validator
     # is run.
-    client: httpx.Client = Field(default=None)  # type: ignore # : :meta private:
+    client: httpx.Client = Field(default=None)  # type: ignore[assignment] # :meta private:

-    async_client: httpx.AsyncClient = Field(  # type: ignore # : meta private:
+    async_client: httpx.AsyncClient = Field(  # type: ignore[assignment] # :meta private:
         default=None
     )
     mistral_api_key: SecretStr = Field(
@@ -153,7 +153,6 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate configuration."""
-
         api_key_str = self.mistral_api_key.get_secret_value()
         # todo: handle retries
         if not self.client:
@@ -187,14 +186,14 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
                 "Could not download mistral tokenizer from Huggingface for "
                 "calculating batch sizes. Set a Huggingface token via the "
                 "HF_TOKEN environment variable to download the real tokenizer. "
-                "Falling back to a dummy tokenizer that uses `len()`."
+                "Falling back to a dummy tokenizer that uses `len()`.",
+                stacklevel=2,
             )
             self.tokenizer = DummyTokenizer()
         return self

     def _get_batches(self, texts: list[str]) -> Iterable[list[str]]:
-        """Split a list of texts into batches of less than 16k tokens for Mistral
-        API."""
+        """Split list of texts into batches of less than 16k tokens for Mistral API."""
         batch: list[str] = []
         batch_tokens = 0

@@ -224,6 +223,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):

         Returns:
             List of embeddings, one for each text.
+
         """
         try:
             batch_responses = []
@@ -238,16 +238,17 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
         def _embed_batch(batch: list[str]) -> Response:
             response = self.client.post(
                 url="/embeddings",
-                json=dict(
-                    model=self.model,
-                    input=batch,
-                ),
+                json={
+                    "model": self.model,
+                    "input": batch,
+                },
             )
             response.raise_for_status()
             return response

-        for batch in self._get_batches(texts):
-            batch_responses.append(_embed_batch(batch))
+        batch_responses = [
+            _embed_batch(batch) for batch in self._get_batches(texts)
+        ]
         return [
             list(map(float, embedding_obj["embedding"]))
             for response in batch_responses
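The loop-to-comprehension change above is ruff's PERF401 (manual list comprehension): building a list by calling append in a for loop is slower and noisier than a comprehension. The shape of the fix, with hypothetical names:

    from typing import Callable

    def embed_all(batches: list, embed: Callable) -> list:
        # Before (PERF401):
        #     responses = []
        #     for batch in batches:
        #         responses.append(embed(batch))
        return [embed(batch) for batch in batches]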
@@ -265,16 +266,17 @@ class MistralAIEmbeddings(BaseModel, Embeddings):

         Returns:
             List of embeddings, one for each text.
+
         """
         try:
             batch_responses = await asyncio.gather(
                 *[
                     self.async_client.post(
                         url="/embeddings",
-                        json=dict(
-                            model=self.model,
-                            input=batch,
-                        ),
+                        json={
+                            "model": self.model,
+                            "input": batch,
+                        },
                     )
                     for batch in self._get_batches(texts)
                 ]
@@ -296,6 +298,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):

         Returns:
             Embedding for the text.
+
         """
         return self.embed_documents([text])[0]

@@ -307,5 +310,6 @@ class MistralAIEmbeddings(BaseModel, Embeddings):

         Returns:
             Embedding for the text.
+
         """
         return (await self.aembed_documents([text]))[0]
@@ -48,8 +48,62 @@ disallow_untyped_defs = "True"
 target-version = "py39"

 [tool.ruff.lint]
-select = ["E", "F", "I", "T201", "UP", "S"]
-ignore = [ "UP007", ]
+select = [
+    "A",      # flake8-builtins
+    "B",      # flake8-bugbear
+    "ASYNC",  # flake8-async
+    "C4",     # flake8-comprehensions
+    "COM",    # flake8-commas
+    "D",      # pydocstyle
+    "DOC",    # pydoclint
+    "E",      # pycodestyle error
+    "EM",     # flake8-errmsg
+    "F",      # pyflakes
+    "FA",     # flake8-future-annotations
+    "FBT",    # flake8-boolean-trap
+    "FLY",    # flake8-flynt
+    "I",      # isort
+    "ICN",    # flake8-import-conventions
+    "INT",    # flake8-gettext
+    "ISC",    # isort-comprehensions
+    "PGH",    # pygrep-hooks
+    "PIE",    # flake8-pie
+    "PERF",   # flake8-perf
+    "PYI",    # flake8-pyi
+    "Q",      # flake8-quotes
+    "RET",    # flake8-return
+    "RSE",    # flake8-rst-docstrings
+    "RUF",    # ruff
+    "S",      # flake8-bandit
+    "SLF",    # flake8-self
+    "SLOT",   # flake8-slots
+    "SIM",    # flake8-simplify
+    "T10",    # flake8-debugger
+    "T20",    # flake8-print
+    "TID",    # flake8-tidy-imports
+    "UP",     # pyupgrade
+    "W",      # pycodestyle warning
+    "YTT",    # flake8-2020
+]
+ignore = [
+    "D100",    # pydocstyle: Missing docstring in public module
+    "D101",    # pydocstyle: Missing docstring in public class
+    "D102",    # pydocstyle: Missing docstring in public method
+    "D103",    # pydocstyle: Missing docstring in public function
+    "D104",    # pydocstyle: Missing docstring in public package
+    "D105",    # pydocstyle: Missing docstring in magic method
+    "D107",    # pydocstyle: Missing docstring in __init__
+    "D203",    # Messes with the formatter
+    "D407",    # pydocstyle: Missing-dashed-underline-after-section
+    "COM812",  # Messes with the formatter
+    "ISC001",  # Messes with the formatter
+    "PERF203", # Rarely useful
+    "S112",    # Rarely useful
+    "RUF012",  # Doesn't play well with Pydantic
+    "SLF001",  # Private member access
+    "UP007",   # pyupgrade: non-pep604-annotation-union
+    "UP045",   # pyupgrade: non-pep604-annotation-optional
+]

 [tool.coverage.run]
 omit = ["tests/*"]
@@ -1,5 +1,7 @@
 """Test ChatMistral chat model."""

+from __future__ import annotations
+
 import json
 import logging
 import time
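The added from __future__ import annotations comes from ruff's FA rules (flake8-future-annotations): with deferred annotation evaluation, newer annotation syntax stays importable on Python 3.9, which this package targets. For instance:

    from __future__ import annotations

    def pick(values: list[int] | None) -> int | None:
        # The PEP 604 union is legal here on 3.9 only because annotations
        # are never evaluated at runtime after the future import.
        return values[0] if values else None

    print(pick([1, 2, 3]))  # 1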
@@ -43,11 +45,12 @@ async def test_astream() -> None:
         if token.response_metadata:
             chunks_with_response_metadata += 1
     if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
-        raise AssertionError(
+        msg = (
             "Expected exactly one chunk with token counts or response_metadata. "
             "AIMessageChunk aggregation adds / appends counts and metadata. Check that "
             "this is behaving properly."
         )
+        raise AssertionError(msg)
     assert isinstance(full, AIMessageChunk)
     assert full.usage_metadata is not None
     assert full.usage_metadata["input_tokens"] > 0
@@ -61,7 +64,7 @@ async def test_astream() -> None:


 async def test_abatch() -> None:
-    """Test streaming tokens from ChatMistralAI"""
+    """Test streaming tokens from ChatMistralAI."""
     llm = ChatMistralAI()

     result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
@@ -70,7 +73,7 @@ async def test_abatch() -> None:


 async def test_abatch_tags() -> None:
-    """Test batch tokens from ChatMistralAI"""
+    """Test batch tokens from ChatMistralAI."""
     llm = ChatMistralAI()

     result = await llm.abatch(
@@ -81,7 +84,7 @@ async def test_abatch_tags() -> None:


 def test_batch() -> None:
-    """Test batch tokens from ChatMistralAI"""
+    """Test batch tokens from ChatMistralAI."""
     llm = ChatMistralAI()

     result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
@@ -90,7 +93,7 @@ def test_batch() -> None:


 async def test_ainvoke() -> None:
-    """Test invoke tokens from ChatMistralAI"""
+    """Test invoke tokens from ChatMistralAI."""
     llm = ChatMistralAI()

     result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
@@ -99,10 +102,10 @@ async def test_ainvoke() -> None:


 def test_invoke() -> None:
-    """Test invoke tokens from ChatMistralAI"""
+    """Test invoke tokens from ChatMistralAI."""
     llm = ChatMistralAI()

-    result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
+    result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
     assert isinstance(result.content, str)

@@ -178,13 +181,11 @@ def test_streaming_structured_output() -> None:

     structured_llm = llm.with_structured_output(Person)
     strm = structured_llm.stream("Erick, 27 years old")
-    chunk_num = 0
-    for chunk in strm:
+    for chunk_num, chunk in enumerate(strm):
         assert chunk_num == 0, "should only have one chunk with model"
         assert isinstance(chunk, Person)
         assert chunk.name == "Erick"
         assert chunk.age == 27
-        chunk_num += 1


 class Book(BaseModel):
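The counter removal above follows ruff's SIM113 guidance: a manually incremented index duplicates what enumerate already provides and can drift out of sync with the loop body. The reduced pattern:

    chunks = ["only-chunk"]  # stand-in for the structured-output stream
    for chunk_num, chunk in enumerate(chunks):
        assert chunk_num == 0, "should only have one chunk with model"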
@@ -201,7 +202,7 @@ def _check_parsed_result(result: Any, schema: Any) -> None:
     if schema == Book:
         assert isinstance(result, Book)
     else:
-        assert all(key in ["name", "authors"] for key in result.keys())
+        assert all(key in ["name", "authors"] for key in result)


 @pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()])
@@ -4,4 +4,3 @@ import pytest
 @pytest.mark.compile
 def test_placeholder() -> None:
     """Used for compiling integration tests without running any real tests."""
-    pass
@@ -1,4 +1,4 @@
-"""Test MistralAI Embedding"""
+"""Test MistralAI Embedding."""

 from langchain_mistralai import MistralAIEmbeddings

@@ -1,4 +1,4 @@
-"""Standard LangChain interface tests"""
+"""Standard LangChain interface tests."""

 from langchain_core.language_models import BaseChatModel
 from langchain_tests.integration_tests import (  # type: ignore[import-not-found]
@@ -84,23 +84,23 @@ def test_mistralai_initialization_baseurl_env(env_var_name: str) -> None:
     [
         (
             SystemMessage(content="Hello"),
-            dict(role="system", content="Hello"),
+            {"role": "system", "content": "Hello"},
         ),
         (
             HumanMessage(content="Hello"),
-            dict(role="user", content="Hello"),
+            {"role": "user", "content": "Hello"},
         ),
         (
             AIMessage(content="Hello"),
-            dict(role="assistant", content="Hello"),
+            {"role": "assistant", "content": "Hello"},
         ),
         (
             AIMessage(content="{", additional_kwargs={"prefix": True}),
-            dict(role="assistant", content="{", prefix=True),
+            {"role": "assistant", "content": "{", "prefix": True},
         ),
         (
             ChatMessage(role="assistant", content="Hello"),
-            dict(role="assistant", content="Hello"),
+            {"role": "assistant", "content": "Hello"},
         ),
     ],
 )
@@ -112,17 +112,17 @@ def test_convert_message_to_mistral_chat_message(


 def _make_completion_response_from_token(token: str) -> dict:
-    return dict(
-        id="abc123",
-        model="fake_model",
-        choices=[
-            dict(
-                index=0,
-                delta=dict(content=token),
-                finish_reason=None,
-            )
+    return {
+        "id": "abc123",
+        "model": "fake_model",
+        "choices": [
+            {
+                "index": 0,
+                "delta": {"content": token},
+                "finish_reason": None,
+            }
         ],
-    )
+    }


 def mock_chat_stream(*args: Any, **kwargs: Any) -> Generator:
@@ -275,8 +275,7 @@ def test_extra_kwargs() -> None:


 def test_retry_with_failure_then_success() -> None:
-    """Test that retry mechanism works correctly when
-    first request fails and second succeeds."""
+    """Test retry mechanism works correctly when first request fails, second succeeds."""
     # Create a real ChatMistralAI instance
     chat = ChatMistralAI(max_retries=3)

@@ -289,7 +288,8 @@ def test_retry_with_failure_then_success() -> None:
         call_count += 1

         if call_count == 1:
-            raise httpx.RequestError("Connection error", request=MagicMock())
+            msg = "Connection error"
+            raise httpx.RequestError(msg, request=MagicMock())

         mock_response = MagicMock()
         mock_response.status_code = 200
@@ -1,4 +1,4 @@
-"""Standard LangChain interface tests"""
+"""Standard LangChain interface tests."""

 from langchain_core.language_models import BaseChatModel
 from langchain_tests.unit_tests import (  # type: ignore[import-not-found]
@@ -379,7 +379,7 @@ dev = [
     { name = "jupyter", specifier = ">=1.0.0,<2.0.0" },
     { name = "setuptools", specifier = ">=67.6.1,<68.0.0" },
 ]
-lint = [{ name = "ruff", specifier = ">=0.11.2,<0.12.0" }]
+lint = [{ name = "ruff", specifier = ">=0.12.2,<0.13" }]
 test = [
     { name = "blockbuster", specifier = "~=1.5.18" },
     { name = "freezegun", specifier = ">=1.2.2,<2.0.0" },
|
|||||||
[package.metadata.requires-dev]
|
[package.metadata.requires-dev]
|
||||||
codespell = [{ name = "codespell", specifier = ">=2.2.0,<3.0.0" }]
|
codespell = [{ name = "codespell", specifier = ">=2.2.0,<3.0.0" }]
|
||||||
dev = [{ name = "langchain-core", editable = "../../core" }]
|
dev = [{ name = "langchain-core", editable = "../../core" }]
|
||||||
lint = [{ name = "ruff", specifier = ">=0.5,<1.0" }]
|
lint = [{ name = "ruff", specifier = ">=0.12.2,<0.13" }]
|
||||||
test = [
|
test = [
|
||||||
{ name = "langchain-core", editable = "../../core" },
|
{ name = "langchain-core", editable = "../../core" },
|
||||||
{ name = "langchain-tests", editable = "../../standard-tests" },
|
{ name = "langchain-tests", editable = "../../standard-tests" },
|
||||||
@ -1370,27 +1370,27 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "ruff"
|
name = "ruff"
|
||||||
version = "0.9.4"
|
version = "0.12.2"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/c0/17/529e78f49fc6f8076f50d985edd9a2cf011d1dbadb1cdeacc1d12afc1d26/ruff-0.9.4.tar.gz", hash = "sha256:6907ee3529244bb0ed066683e075f09285b38dd5b4039370df6ff06041ca19e7", size = 3599458, upload-time = "2025-01-30T18:09:51.03Z" }
|
sdist = { url = "https://files.pythonhosted.org/packages/6c/3d/d9a195676f25d00dbfcf3cf95fdd4c685c497fcfa7e862a44ac5e4e96480/ruff-0.12.2.tar.gz", hash = "sha256:d7b4f55cd6f325cb7621244f19c873c565a08aff5a4ba9c69aa7355f3f7afd3e", size = 4432239, upload-time = "2025-07-03T16:40:19.566Z" }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/b6/f8/3fafb7804d82e0699a122101b5bee5f0d6e17c3a806dcbc527bb7d3f5b7a/ruff-0.9.4-py3-none-linux_armv6l.whl", hash = "sha256:64e73d25b954f71ff100bb70f39f1ee09e880728efb4250c632ceed4e4cdf706", size = 11668400, upload-time = "2025-01-30T18:08:46.508Z" },
|
{ url = "https://files.pythonhosted.org/packages/74/b6/2098d0126d2d3318fd5bec3ad40d06c25d377d95749f7a0c5af17129b3b1/ruff-0.12.2-py3-none-linux_armv6l.whl", hash = "sha256:093ea2b221df1d2b8e7ad92fc6ffdca40a2cb10d8564477a987b44fd4008a7be", size = 10369761, upload-time = "2025-07-03T16:39:38.847Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/2e/a6/2efa772d335da48a70ab2c6bb41a096c8517ca43c086ea672d51079e3d1f/ruff-0.9.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6ce6743ed64d9afab4fafeaea70d3631b4d4b28b592db21a5c2d1f0ef52934bf", size = 11628395, upload-time = "2025-01-30T18:08:50.87Z" },
|
{ url = "https://files.pythonhosted.org/packages/b1/4b/5da0142033dbe155dc598cfb99262d8ee2449d76920ea92c4eeb9547c208/ruff-0.12.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:09e4cf27cc10f96b1708100fa851e0daf21767e9709e1649175355280e0d950e", size = 11155659, upload-time = "2025-07-03T16:39:42.294Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/dc/d7/cd822437561082f1c9d7225cc0d0fbb4bad117ad7ac3c41cd5d7f0fa948c/ruff-0.9.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:54499fb08408e32b57360f6f9de7157a5fec24ad79cb3f42ef2c3f3f728dfe2b", size = 11090052, upload-time = "2025-01-30T18:08:54.498Z" },
|
{ url = "https://files.pythonhosted.org/packages/3e/21/967b82550a503d7c5c5c127d11c935344b35e8c521f52915fc858fb3e473/ruff-0.12.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:8ae64755b22f4ff85e9c52d1f82644abd0b6b6b6deedceb74bd71f35c24044cc", size = 10537769, upload-time = "2025-07-03T16:39:44.75Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/9e/67/3660d58e893d470abb9a13f679223368ff1684a4ef40f254a0157f51b448/ruff-0.9.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37c892540108314a6f01f105040b5106aeb829fa5fb0561d2dcaf71485021137", size = 11882221, upload-time = "2025-01-30T18:08:57.784Z" },
|
{ url = "https://files.pythonhosted.org/packages/33/91/00cff7102e2ec71a4890fb7ba1803f2cdb122d82787c7d7cf8041fe8cbc1/ruff-0.12.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3eb3a6b2db4d6e2c77e682f0b988d4d61aff06860158fdb413118ca133d57922", size = 10717602, upload-time = "2025-07-03T16:39:47.652Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/79/d1/757559995c8ba5f14dfec4459ef2dd3fcea82ac43bc4e7c7bf47484180c0/ruff-0.9.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:de9edf2ce4b9ddf43fd93e20ef635a900e25f622f87ed6e3047a664d0e8f810e", size = 11424862, upload-time = "2025-01-30T18:09:01.167Z" },
|
{ url = "https://files.pythonhosted.org/packages/9b/eb/928814daec4e1ba9115858adcda44a637fb9010618721937491e4e2283b8/ruff-0.12.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:73448de992d05517170fc37169cbca857dfeaeaa8c2b9be494d7bcb0d36c8f4b", size = 10198772, upload-time = "2025-07-03T16:39:49.641Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/c0/96/7915a7c6877bb734caa6a2af424045baf6419f685632469643dbd8eb2958/ruff-0.9.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:87c90c32357c74f11deb7fbb065126d91771b207bf9bfaaee01277ca59b574ec", size = 12626735, upload-time = "2025-01-30T18:09:05.312Z" },
|
{ url = "https://files.pythonhosted.org/packages/50/fa/f15089bc20c40f4f72334f9145dde55ab2b680e51afb3b55422effbf2fb6/ruff-0.12.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b8b94317cbc2ae4a2771af641739f933934b03555e51515e6e021c64441532d", size = 11845173, upload-time = "2025-07-03T16:39:52.069Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/0e/cc/dadb9b35473d7cb17c7ffe4737b4377aeec519a446ee8514123ff4a26091/ruff-0.9.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:56acd6c694da3695a7461cc55775f3a409c3815ac467279dfa126061d84b314b", size = 13255976, upload-time = "2025-01-30T18:09:09.425Z" },
|
{ url = "https://files.pythonhosted.org/packages/43/9f/1f6f98f39f2b9302acc161a4a2187b1e3a97634fe918a8e731e591841cf4/ruff-0.12.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:45fc42c3bf1d30d2008023a0a9a0cfb06bf9835b147f11fe0679f21ae86d34b1", size = 12553002, upload-time = "2025-07-03T16:39:54.551Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/5f/c3/ad2dd59d3cabbc12df308cced780f9c14367f0321e7800ca0fe52849da4c/ruff-0.9.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0c93e7d47ed951b9394cf352d6695b31498e68fd5782d6cbc282425655f687a", size = 12752262, upload-time = "2025-01-30T18:09:13.112Z" },
|
{ url = "https://files.pythonhosted.org/packages/d8/70/08991ac46e38ddd231c8f4fd05ef189b1b94be8883e8c0c146a025c20a19/ruff-0.12.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ce48f675c394c37e958bf229fb5c1e843e20945a6d962cf3ea20b7a107dcd9f4", size = 12171330, upload-time = "2025-07-03T16:39:57.55Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/c7/17/5f1971e54bd71604da6788efd84d66d789362b1105e17e5ccc53bba0289b/ruff-0.9.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1d4c8772670aecf037d1bf7a07c39106574d143b26cfe5ed1787d2f31e800214", size = 14401648, upload-time = "2025-01-30T18:09:17.086Z" },
|
{ url = "https://files.pythonhosted.org/packages/88/a9/5a55266fec474acfd0a1c73285f19dd22461d95a538f29bba02edd07a5d9/ruff-0.12.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:793d8859445ea47591272021a81391350205a4af65a9392401f418a95dfb75c9", size = 11774717, upload-time = "2025-07-03T16:39:59.78Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/30/24/6200b13ea611b83260501b6955b764bb320e23b2b75884c60ee7d3f0b68e/ruff-0.9.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfc5f1d7afeda8d5d37660eeca6d389b142d7f2b5a1ab659d9214ebd0e025231", size = 12414702, upload-time = "2025-01-30T18:09:21.672Z" },
|
{ url = "https://files.pythonhosted.org/packages/87/e5/0c270e458fc73c46c0d0f7cf970bb14786e5fdb88c87b5e423a4bd65232b/ruff-0.12.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6932323db80484dda89153da3d8e58164d01d6da86857c79f1961934354992da", size = 11646659, upload-time = "2025-07-03T16:40:01.934Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/34/cb/f5d50d0c4ecdcc7670e348bd0b11878154bc4617f3fdd1e8ad5297c0d0ba/ruff-0.9.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:faa935fc00ae854d8b638c16a5f1ce881bc3f67446957dd6f2af440a5fc8526b", size = 11859608, upload-time = "2025-01-30T18:09:25.663Z" },
|
{ url = "https://files.pythonhosted.org/packages/b7/b6/45ab96070c9752af37f0be364d849ed70e9ccede07675b0ec4e3ef76b63b/ruff-0.12.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6aa7e623a3a11538108f61e859ebf016c4f14a7e6e4eba1980190cacb57714ce", size = 10604012, upload-time = "2025-07-03T16:40:04.363Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/d6/f4/9c8499ae8426da48363bbb78d081b817b0f64a9305f9b7f87eab2a8fb2c1/ruff-0.9.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a6c634fc6f5a0ceae1ab3e13c58183978185d131a29c425e4eaa9f40afe1e6d6", size = 11485702, upload-time = "2025-01-30T18:09:28.903Z" },
|
{ url = "https://files.pythonhosted.org/packages/86/91/26a6e6a424eb147cc7627eebae095cfa0b4b337a7c1c413c447c9ebb72fd/ruff-0.12.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2a4a20aeed74671b2def096bdf2eac610c7d8ffcbf4fb0e627c06947a1d7078d", size = 10176799, upload-time = "2025-07-03T16:40:06.514Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/18/59/30490e483e804ccaa8147dd78c52e44ff96e1c30b5a95d69a63163cdb15b/ruff-0.9.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:433dedf6ddfdec7f1ac7575ec1eb9844fa60c4c8c2f8887a070672b8d353d34c", size = 12067782, upload-time = "2025-01-30T18:09:32.371Z" },
|
{ url = "https://files.pythonhosted.org/packages/f5/0c/9f344583465a61c8918a7cda604226e77b2c548daf8ef7c2bfccf2b37200/ruff-0.12.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:71a4c550195612f486c9d1f2b045a600aeba851b298c667807ae933478fcef04", size = 11241507, upload-time = "2025-07-03T16:40:08.708Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/3d/8c/893fa9551760b2f8eb2a351b603e96f15af167ceaf27e27ad873570bc04c/ruff-0.9.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d612dbd0f3a919a8cc1d12037168bfa536862066808960e0cc901404b77968f0", size = 12483087, upload-time = "2025-01-30T18:09:36.124Z" },
|
{ url = "https://files.pythonhosted.org/packages/1c/b7/99c34ded8fb5f86c0280278fa89a0066c3760edc326e935ce0b1550d315d/ruff-0.12.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:4987b8f4ceadf597c927beee65a5eaf994c6e2b631df963f86d8ad1bdea99342", size = 11717609, upload-time = "2025-07-03T16:40:10.836Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/23/15/f6751c07c21ca10e3f4a51ea495ca975ad936d780c347d9808bcedbd7182/ruff-0.9.4-py3-none-win32.whl", hash = "sha256:db1192ddda2200671f9ef61d9597fcef89d934f5d1705e571a93a67fb13a4402", size = 9852302, upload-time = "2025-01-30T18:09:40.013Z" },
|
{ url = "https://files.pythonhosted.org/packages/51/de/8589fa724590faa057e5a6d171e7f2f6cffe3287406ef40e49c682c07d89/ruff-0.12.2-py3-none-win32.whl", hash = "sha256:369ffb69b70cd55b6c3fc453b9492d98aed98062db9fec828cdfd069555f5f1a", size = 10523823, upload-time = "2025-07-03T16:40:13.203Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/12/41/2d2d2c6a72e62566f730e49254f602dfed23019c33b5b21ea8f8917315a1/ruff-0.9.4-py3-none-win_amd64.whl", hash = "sha256:05bebf4cdbe3ef75430d26c375773978950bbf4ee3c95ccb5448940dc092408e", size = 10850051, upload-time = "2025-01-30T18:09:43.42Z" },
|
{ url = "https://files.pythonhosted.org/packages/94/47/8abf129102ae4c90cba0c2199a1a9b0fa896f6f806238d6f8c14448cc748/ruff-0.12.2-py3-none-win_amd64.whl", hash = "sha256:dca8a3b6d6dc9810ed8f328d406516bf4d660c00caeaef36eb831cf4871b0639", size = 11629831, upload-time = "2025-07-03T16:40:15.478Z" },
|
||||||
{ url = "https://files.pythonhosted.org/packages/c6/e6/3d6ec3bc3d254e7f005c543a661a41c3e788976d0e52a1ada195bd664344/ruff-0.9.4-py3-none-win_arm64.whl", hash = "sha256:585792f1e81509e38ac5123492f8875fbc36f3ede8185af0a26df348e5154f41", size = 10078251, upload-time = "2025-01-30T18:09:48.01Z" },
|
{ url = "https://files.pythonhosted.org/packages/e2/1f/72d2946e3cc7456bb837e88000eb3437e55f80db339c840c04015a11115d/ruff-0.12.2-py3-none-win_arm64.whl", hash = "sha256:48d6c6bfb4761df68bc05ae630e24f506755e702d4fb08f08460be778c7ccb12", size = 10735334, upload-time = "2025-07-03T16:40:17.677Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|