mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-27 19:46:55 +00:00
**partners: Enable max_retries in ChatMistralAI** **Description** - This pull request reactivates the retry logic in the `completion_with_retry` method of the `ChatMistralAI` class, restoring the intended functionality of the previously ineffective `max_retries` parameter. It adds a new unit test that mocks a failed call followed by a successful retry, plus an integration test that confirms end-to-end functionality. **Issue** - Closes #30362 **Dependencies** - No additional dependencies required. Co-authored-by: andrasfe <andrasf94@gmail.com>
317 lines
10 KiB
Python
317 lines
10 KiB
Python
"""Test MistralAI Chat API wrapper."""
|
|
|
|
import os
|
|
from typing import Any, AsyncGenerator, Dict, Generator, List, cast
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
from langchain_core.callbacks.base import BaseCallbackHandler
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
BaseMessage,
|
|
ChatMessage,
|
|
HumanMessage,
|
|
InvalidToolCall,
|
|
SystemMessage,
|
|
ToolCall,
|
|
)
|
|
from pydantic import SecretStr
|
|
|
|
from langchain_mistralai.chat_models import ( # type: ignore[import]
|
|
ChatMistralAI,
|
|
_convert_message_to_mistral_chat_message,
|
|
_convert_mistral_chat_message_to_message,
|
|
_convert_tool_call_id_to_mistral_compatible,
|
|
_is_valid_mistral_tool_call_id,
|
|
)
|
|
|
|
# Provide a dummy API key so ChatMistralAI can be constructed without one
# being passed explicitly in every test.
os.environ.update({"MISTRAL_API_KEY": "foo"})
|
|
|
|
|
|
def test_mistralai_model_param() -> None:
    """The ``model`` constructor argument is stored verbatim on the instance."""
    chat_model = ChatMistralAI(model="foo")  # type: ignore[call-arg]
    stored_model = chat_model.model
    assert stored_model == "foo"
|
|
|
|
|
|
def test_mistralai_initialization() -> None:
    """Test ChatMistralAI initialization."""
    # The secret key can be passed directly to the constructor, under either
    # of its two accepted keyword names, instead of via an environment variable.
    candidates = (
        ChatMistralAI(model="test", mistral_api_key="test"),  # type: ignore[call-arg, call-arg]
        ChatMistralAI(model="test", api_key="test"),  # type: ignore[call-arg, arg-type]
    )
    for candidate in candidates:
        secret = cast(SecretStr, candidate.mistral_api_key)
        assert secret.get_secret_value() == "test"
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("model", "expected_url"),
    [
        (ChatMistralAI(model="test"), "https://api.mistral.ai/v1"),  # type: ignore[call-arg, arg-type]
        (ChatMistralAI(model="test", endpoint="baz"), "baz"),  # type: ignore[call-arg, arg-type]
    ],
)
def test_mistralai_initialization_baseurl(
    model: ChatMistralAI, expected_url: str
) -> None:
    """Test ChatMistralAI initialization."""
    # With no ``endpoint`` argument the public Mistral API URL is expected;
    # an explicit ``endpoint`` argument must be kept as given.
    actual_url = model.endpoint
    assert actual_url == expected_url
|
|
|
|
|
|
@pytest.mark.parametrize(
    "env_var_name",
    [
        "MISTRAL_BASE_URL",
    ],
)
def test_mistralai_initialization_baseurl_env(
    env_var_name: str, monkeypatch: pytest.MonkeyPatch
) -> None:
    """Test that ChatMistralAI picks up its endpoint from an environment variable.

    ``monkeypatch.setenv`` is used instead of assigning to ``os.environ``
    directly, so the variable is restored when the test finishes. A direct
    assignment would leak into the rest of the session, and every
    ``ChatMistralAI`` built afterwards would silently point at the bogus
    ``"boo"`` endpoint.
    """
    monkeypatch.setenv(env_var_name, "boo")
    model = ChatMistralAI(model="test")  # type: ignore[call-arg]
    assert model.endpoint == "boo"
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("message", "expected"),
    [
        (SystemMessage(content="Hello"), {"role": "system", "content": "Hello"}),
        (HumanMessage(content="Hello"), {"role": "user", "content": "Hello"}),
        (AIMessage(content="Hello"), {"role": "assistant", "content": "Hello"}),
        (
            # ``prefix`` in additional_kwargs is forwarded as a top-level key.
            AIMessage(content="{", additional_kwargs={"prefix": True}),
            {"role": "assistant", "content": "{", "prefix": True},
        ),
        (
            ChatMessage(role="assistant", content="Hello"),
            {"role": "assistant", "content": "Hello"},
        ),
    ],
)
def test_convert_message_to_mistral_chat_message(
    message: BaseMessage, expected: Dict
) -> None:
    """Each LangChain message type maps to the matching Mistral chat dict."""
    converted = _convert_message_to_mistral_chat_message(message)
    assert converted == expected
|
|
|
|
|
|
def _make_completion_response_from_token(token: str) -> Dict:
|
|
return dict(
|
|
id="abc123",
|
|
model="fake_model",
|
|
choices=[
|
|
dict(
|
|
index=0,
|
|
delta=dict(content=token),
|
|
finish_reason=None,
|
|
)
|
|
],
|
|
)
|
|
|
|
|
|
def mock_chat_stream(*args: Any, **kwargs: Any) -> Generator:
    """Stand-in for ``completion_with_retry`` that streams a fixed reply.

    All arguments are accepted and ignored. The result is a lazy generator of
    fake chunks whose contents spell "Hello how can I help?".
    """
    return (
        _make_completion_response_from_token(piece)
        for piece in ["Hello", " how", " can", " I", " help", "?"]
    )
|
|
|
|
|
|
async def mock_chat_astream(*args: Any, **kwargs: Any) -> AsyncGenerator:
    """Async stand-in for ``acompletion_with_retry`` streaming a fixed reply.

    All arguments are ignored. Awaiting this coroutine produces an async
    generator of fake chunks whose contents spell "Hello how can I help?".
    """
    pieces = ("Hello", " how", " can", " I", " help", "?")

    async def _emit() -> AsyncGenerator:
        for piece in pieces:
            yield _make_completion_response_from_token(piece)

    return _emit()
|
|
|
|
|
|
class MyCustomHandler(BaseCallbackHandler):
    """Callback handler that remembers the most recently streamed token."""

    # Empty until the first ``on_llm_new_token`` call.
    last_token: str = ""

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Record ``token`` so a test can compare it with the chunk just yielded."""
        setattr(self, "last_token", token)
|
|
|
|
|
|
@patch(
    "langchain_mistralai.chat_models.ChatMistralAI.completion_with_retry",
    new=mock_chat_stream,
)
def test_stream_with_callback() -> None:
    """Each streamed chunk has been reported to the callback by the time it is yielded.

    ``completion_with_retry`` is replaced by ``mock_chat_stream``. The chunk
    contents are also collected and checked after the loop, so the test cannot
    pass vacuously when the stream produces nothing.
    """
    callback = MyCustomHandler()
    chat = ChatMistralAI(callbacks=[callback])
    received: List[str] = []
    for token in chat.stream("Hello"):
        assert callback.last_token == token.content
        received.append(cast(str, token.content))
    # The mock stream's pieces concatenate to this reply.
    assert "".join(received) == "Hello how can I help?"
|
|
|
|
|
|
@patch("langchain_mistralai.chat_models.acompletion_with_retry", new=mock_chat_astream)
async def test_astream_with_callback() -> None:
    """Each async-streamed chunk has been reported to the callback by the time it is yielded.

    ``acompletion_with_retry`` is replaced by ``mock_chat_astream``. The chunk
    contents are also collected and checked after the loop, so the test cannot
    pass vacuously when the stream produces nothing.
    """
    callback = MyCustomHandler()
    chat = ChatMistralAI(callbacks=[callback])
    received: List[str] = []
    async for token in chat.astream("Hello"):
        assert callback.last_token == token.content
        received.append(cast(str, token.content))
    # The mock stream's pieces concatenate to this reply.
    assert "".join(received) == "Hello how can I help?"
|
|
|
|
|
|
def test__convert_dict_to_message_tool_call() -> None:
    """Round-trip Mistral tool-call dicts through both message converters.

    The first case is a single well-formed tool call. The second is a pair in
    which the second call's arguments are not valid JSON, so that call must
    surface as an ``InvalidToolCall``.
    """
    # --- A single well-formed tool call ------------------------------------
    valid_call = {
        "id": "ssAbar4Dr",
        "function": {
            "arguments": '{"name": "Sally", "hair_color": "green"}',
            "name": "GenerateUsername",
        },
    }
    mistral_msg = {"role": "assistant", "content": "", "tool_calls": [valid_call]}
    converted = _convert_mistral_chat_message_to_message(mistral_msg)
    want = AIMessage(
        content="",
        additional_kwargs={"tool_calls": [valid_call]},
        tool_calls=[
            ToolCall(
                name="GenerateUsername",
                args={"name": "Sally", "hair_color": "green"},
                id="ssAbar4Dr",
                type="tool_call",
            )
        ],
    )
    assert converted == want
    assert _convert_message_to_mistral_chat_message(want) == mistral_msg

    # --- One valid call plus one with malformed (non-JSON) arguments --------
    good_entry = {
        "id": "pL5rEGzxe",
        "function": {
            "arguments": '{"name": "Sally", "hair_color": "green"}',
            "name": "GenerateUsername",
        },
    }
    broken_entry = {
        "id": "ssAbar4Dr",
        "function": {
            "arguments": "oops",
            "name": "GenerateUsername",
        },
    }
    mixed_calls = [good_entry, broken_entry]
    mistral_msg = {"role": "assistant", "content": "", "tool_calls": mixed_calls}
    converted = _convert_mistral_chat_message_to_message(mistral_msg)
    parse_error = (
        "Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. "
        "Received JSONDecodeError Expecting value: line 1 column 1 (char 0)\n"
        "For troubleshooting, visit: "
        "https://python.langchain.com/docs/troubleshooting/errors/"
        "OUTPUT_PARSING_FAILURE "
    )
    want = AIMessage(
        content="",
        additional_kwargs={"tool_calls": mixed_calls},
        invalid_tool_calls=[
            InvalidToolCall(
                name="GenerateUsername",
                args="oops",
                error=parse_error,
                id="ssAbar4Dr",
                type="invalid_tool_call",
            ),
        ],
        tool_calls=[
            ToolCall(
                name="GenerateUsername",
                args={"name": "Sally", "hair_color": "green"},
                id="pL5rEGzxe",
                type="tool_call",
            ),
        ],
    )
    assert converted == want
    assert _convert_message_to_mistral_chat_message(want) == mistral_msg
|
|
|
|
|
|
def test_custom_token_counting() -> None:
    """A user-supplied tokenizer is used by ``get_token_ids``."""

    def fixed_ids(text: str) -> List[int]:
        # Ignores the input and always returns the same ids.
        return [1, 2, 3]

    chat_model = ChatMistralAI(custom_get_token_ids=fixed_ids)
    token_ids = chat_model.get_token_ids("foo")
    assert token_ids == [1, 2, 3]
|
|
|
|
|
|
def test_tool_id_conversion() -> None:
    """Tool-call ids are validated and mapped to fixed Mistral-compatible ids."""
    assert _is_valid_mistral_tool_call_id("ssAbar4Dr")
    for bad_id in ("abc123", "call_JIIjI55tTipFFzpcP8re3BpM"):
        assert not _is_valid_mistral_tool_call_id(bad_id)

    # Already-valid ids pass through unchanged; others map to a fixed id.
    expected_by_input = {
        "ssAbar4Dr": "ssAbar4Dr",
        "abc123": "pL5rEGzxe",
        "call_JIIjI55tTipFFzpcP8re3BpM": "8kxAQvoED",
    }
    for source_id, mapped_id in expected_by_input.items():
        assert _convert_tool_call_id_to_mistral_compatible(source_id) == mapped_id
        assert _is_valid_mistral_tool_call_id(mapped_id)
|
|
|
|
|
|
def test_extra_kwargs() -> None:
    """Unrecognised constructor kwargs are collected into ``model_kwargs``."""
    # An unknown ``foo`` lands in model_kwargs; a known field keeps its own slot.
    chat_model = ChatMistralAI(model="my-model", foo=3, max_tokens=10)  # type: ignore[call-arg]
    assert chat_model.max_tokens == 10
    assert chat_model.model_kwargs == {"foo": 3}

    # Explicit model_kwargs are merged with the stray keyword arguments.
    chat_model = ChatMistralAI(model="my-model", foo=3, model_kwargs={"bar": 2})  # type: ignore[call-arg]
    assert chat_model.model_kwargs == {"foo": 3, "bar": 2}

    # Supplying the same key both ways is rejected.
    with pytest.raises(ValueError):
        ChatMistralAI(model="my-model", foo=3, model_kwargs={"foo": 2})  # type: ignore[call-arg]
|
|
|
|
|
|
def test_retry_with_failure_then_success() -> None:
    """Test that a transient connection error is retried until a call succeeds.

    The first POST raises ``httpx.RequestError`` and the second returns a
    canned successful completion. With ``max_retries=3`` the invocation
    should succeed after exactly two attempts.
    """
    # Create a real ChatMistralAI instance.
    chat = ChatMistralAI(max_retries=3)

    # Only the HTTP client is patched, so the real retry mechanism is
    # exercised. Count how many times the client is hit.
    call_count = 0

    def mock_post(*args: Any, **kwargs: Any) -> MagicMock:
        nonlocal call_count
        call_count += 1

        if call_count == 1:
            raise httpx.RequestError("Connection error", request=MagicMock())

        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": "Hello!",
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 1,
                "completion_tokens": 1,
                "total_tokens": 2,
            },
        }
        return mock_response

    # NOTE(review): the retry policy presumably backs off between attempts
    # via tenacity's default sleep, which calls ``time.sleep``. Stubbing that
    # out keeps this unit test from spending real wall-clock seconds waiting.
    # If the backoff goes through another path, the patch is simply a no-op.
    with patch.object(chat.client, "post", side_effect=mock_post), patch(
        "time.sleep"
    ):
        result = chat.invoke("Hello")

    assert result.content == "Hello!"
    assert call_count == 2, f"Expected 2 calls, but got {call_count}"
|