mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
146 lines
4.8 KiB
Python
146 lines
4.8 KiB
Python
"""Test ChatMistral chat model."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import time
|
|
from typing import Any
|
|
|
|
import pytest
|
|
from httpx import ReadTimeout
|
|
from langchain_core.messages import AIMessageChunk, BaseMessageChunk
|
|
from pydantic import BaseModel
|
|
from typing_extensions import TypedDict
|
|
|
|
from langchain_mistralai.chat_models import ChatMistralAI
|
|
|
|
|
|
async def test_astream() -> None:
    """Stream a reply from ChatMistralAI and validate the aggregated chunk.

    Every streamed piece must be an `AIMessageChunk` with string content.
    Exactly one piece is expected to carry usage metadata and exactly one is
    expected to carry response metadata; the summed chunk must then expose
    consistent token counts and a non-empty model name.
    """
    model = ChatMistralAI()

    aggregate: BaseMessageChunk | None = None
    usage_chunk_count = 0
    metadata_chunk_count = 0
    async for piece in model.astream("Hello"):
        assert isinstance(piece, AIMessageChunk)
        assert isinstance(piece.content, str)

        # Fold each piece into the running aggregate via chunk addition.
        if aggregate is None:
            aggregate = piece
        else:
            aggregate = aggregate + piece

        if piece.usage_metadata is not None:
            usage_chunk_count += 1
        if piece.response_metadata:
            metadata_chunk_count += 1

    # Both counters must be exactly one, otherwise aggregation would add or
    # append the values more than once.
    if not (usage_chunk_count == 1 and metadata_chunk_count == 1):
        msg = (
            "Expected exactly one chunk with token counts or response_metadata. "
            "AIMessageChunk aggregation adds / appends counts and metadata. Check that "
            "this is behaving properly."
        )
        raise AssertionError(msg)

    assert isinstance(aggregate, AIMessageChunk)

    usage = aggregate.usage_metadata
    assert usage is not None
    assert usage["input_tokens"] > 0
    assert usage["output_tokens"] > 0
    assert usage["input_tokens"] + usage["output_tokens"] == usage["total_tokens"]

    model_name = aggregate.response_metadata["model_name"]
    assert isinstance(model_name, str)
    assert model_name
|
|
|
|
|
|
# Pydantic schema used as a structured-output target by the tests in this file.
# No class docstring is added here: pydantic would copy it into the
# `description` of `Book.model_json_schema()`, altering the schema the tests send.
class Book(BaseModel):
    # Title of the book.
    name: str
    # Names of the book's authors.
    authors: list[str]
|
|
|
|
|
|
# TypedDict counterpart of `Book`, used as a structured-output target by the
# tests in this file. Documented with comments rather than a docstring because
# a docstring may be picked up as the schema description during conversion
# (NOTE(review): confirm against the structured-output conversion helper).
class BookDict(TypedDict):
    # Title of the book.
    name: str
    # Names of the book's authors.
    authors: list[str]
|
|
|
|
|
|
def _check_parsed_result(result: Any, schema: Any) -> None:
    """Assert that `result` has the shape implied by `schema`.

    When `schema` is the `Book` model, `result` must be a `Book` instance.
    For any other schema, iterating `result` must yield only the keys
    "name" and "authors"; missing keys are tolerated.
    """
    if schema == Book:
        assert isinstance(result, Book)
        return

    permitted_keys = ["name", "authors"]
    for key in result:
        assert key in permitted_keys
|
|
|
|
|
|
@pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()])
def test_structured_output_json_schema(schema: Any) -> None:
    """Check `json_schema` structured output through `invoke` and `stream`.

    Runs once per schema flavour: the pydantic model, the TypedDict, and the
    JSON-schema dict generated from the pydantic model.
    """
    model = ChatMistralAI(model="ministral-8b-latest")  # type: ignore[call-arg]
    extractor = model.with_structured_output(schema, method="json_schema")

    system_turn = {"role": "system", "content": "Extract the book's information."}
    user_turn = {
        "role": "user",
        "content": "I recently read 'To Kill a Mockingbird' by Harper Lee.",
    }
    conversation = [system_turn, user_turn]

    # Single-shot call: the whole parsed result arrives at once.
    parsed = extractor.invoke(conversation)
    _check_parsed_result(parsed, schema)

    # Streaming call: every emitted chunk is validated on its own.
    for partial in extractor.stream(conversation):
        _check_parsed_result(partial, schema)
|
|
|
|
|
|
@pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()])
async def test_structured_output_json_schema_async(schema: Any) -> None:
    """Check `json_schema` structured output through `ainvoke` and `astream`.

    Runs once per schema flavour: the pydantic model, the TypedDict, and the
    JSON-schema dict generated from the pydantic model.
    """
    model = ChatMistralAI(model="ministral-8b-latest")  # type: ignore[call-arg]
    extractor = model.with_structured_output(schema, method="json_schema")

    system_turn = {"role": "system", "content": "Extract the book's information."}
    user_turn = {
        "role": "user",
        "content": "I recently read 'To Kill a Mockingbird' by Harper Lee.",
    }
    conversation = [system_turn, user_turn]

    # Single-shot call: the whole parsed result arrives at once.
    parsed = await extractor.ainvoke(conversation)
    _check_parsed_result(parsed, schema)

    # Streaming call: every emitted chunk is validated on its own.
    async for partial in extractor.astream(conversation):
        _check_parsed_result(partial, schema)
|
|
|
|
|
|
def test_retry_parameters(caplog: pytest.LogCaptureFixture) -> None:
    """Test that retry parameters are honored in ChatMistralAI.

    The model is built with a 1-second timeout and `max_retries=3`, then
    invoked once. Two outcomes are accepted:

    * the request succeeds, in which case the reply must be a non-empty
      string that mentions "cat";
    * the request raises `ReadTimeout`, in which case at least 3 seconds must
      have elapsed and the test is reported as skipped.

    Any other exception is logged and re-raised.

    Args:
        caplog: pytest log-capture fixture (requested but not inspected here).
    """
    # Create a model with intentionally short timeout and multiple retries
    mistral = ChatMistralAI(
        timeout=1,  # Very short timeout to trigger timeouts
        max_retries=3,  # Should retry 3 times
    )

    # Simple test input that should take longer than 1 second to process
    test_input = "Write a 2 sentence story about a cat"

    logger = logging.getLogger(__name__)

    # Use the monotonic clock for the duration measurement: unlike
    # `time.time()`, it cannot jump when the system clock is adjusted, so the
    # elapsed-time assertion below stays reliable.
    t0 = time.monotonic()

    # Only the network call is guarded, so failures of the validation
    # assertions are not logged as "Unexpected exception".
    try:
        response = mistral.invoke(test_input)
    except ReadTimeout:
        elapsed_time = time.monotonic() - t0
        logger.info("Request timed out after %.2f seconds", elapsed_time)
        # Presumably each attempt waits out the ~1 s timeout before retrying,
        # hence the 3-second floor — confirm against the retry implementation.
        assert elapsed_time >= 3.0
        pytest.skip("Test timed out as expected with short timeout")
    except Exception:
        logger.exception("Unexpected exception")
        raise
    else:
        # If successful, validate the response
        elapsed_time = time.monotonic() - t0
        logger.info("Request succeeded in %.2f seconds", elapsed_time)
        # Check that we got a valid response
        assert response.content
        assert isinstance(response.content, str)
        assert "cat" in response.content.lower()
|