continue to match v0
This commit is contained in:
parent 99687ce626
commit 5b60e9362e
@@ -1,5 +1,11 @@
"""Unit tests for ChatOllamaV1."""

import logging
from typing import Any
from unittest.mock import MagicMock, patch

import pytest
from langchain_core.exceptions import OutputParserException
from langchain_core.messages.content_blocks import (
    create_image_block,
    create_text_block,
@@ -15,7 +21,10 @@ from langchain_ollama._compat import (
    _convert_from_v1_to_ollama_format,
    _convert_to_v1_from_ollama_format,
)
from langchain_ollama.chat_models_v1 import ChatOllamaV1
from langchain_ollama.chat_models_v1 import (
    ChatOllamaV1,
    _parse_json_string,
)

MODEL_NAME = "llama3.1"

@@ -195,3 +204,199 @@ class TestChatOllamaV1(ChatModelV1UnitTests):

        # Should return a bound model
        assert bound_llm is not None


@patch("langchain_ollama.chat_models.validate_model")
def test_validate_model_on_init(mock_validate_model: Any) -> None:
    """Test that the model is validated on initialization when requested."""
    # Test that validate_model is called when validate_model_on_init=True
    ChatOllamaV1(model=MODEL_NAME, validate_model_on_init=True)
    mock_validate_model.assert_called_once()
    mock_validate_model.reset_mock()

    # Test that validate_model is NOT called when validate_model_on_init=False
    ChatOllamaV1(model=MODEL_NAME, validate_model_on_init=False)
    mock_validate_model.assert_not_called()

    # Test that validate_model is NOT called by default
    ChatOllamaV1(model=MODEL_NAME)
    mock_validate_model.assert_not_called()


# Define a dummy raw_tool_call for the function signature
dummy_raw_tool_call = {
    "function": {"name": "test_func", "arguments": ""},
}


@pytest.mark.parametrize(
    "input_string, expected_output",
    [
        # Case 1: Standard double-quoted JSON
        ('{"key": "value", "number": 123}', {"key": "value", "number": 123}),
        # Case 2: Single-quoted string (the original bug)
        ("{'key': 'value', 'number': 123}", {"key": "value", "number": 123}),
        # Case 3: String with an internal apostrophe
        ('{"text": "It\'s a great test!"}', {"text": "It's a great test!"}),
        # Case 4: Mixed quotes that ast can handle
        ("{'text': \"It's a great test!\"}", {"text": "It's a great test!"}),
    ],
)
def test_parse_json_string_success_cases(
    input_string: str, expected_output: Any
) -> None:
    """Tests that `_parse_json_string` correctly parses valid and fixable strings."""
    raw_tool_call = {"function": {"name": "test_func", "arguments": input_string}}
    result = _parse_json_string(input_string, raw_tool_call=raw_tool_call, skip=False)
    assert result == expected_output


def test_parse_json_string_failure_case_raises_exception() -> None:
    """Tests that `_parse_json_string` raises an exception for malformed strings."""
    malformed_string = "{'key': 'value',,}"
    raw_tool_call = {"function": {"name": "test_func", "arguments": malformed_string}}
    with pytest.raises(OutputParserException):
        _parse_json_string(
            malformed_string,
            raw_tool_call=raw_tool_call,
            skip=False,
        )


def test_parse_json_string_skip_returns_input_on_failure() -> None:
    """Tests that `skip=True` returns the original string on parse failure."""
    malformed_string = "{'not': valid,,,}"
    raw_tool_call = {"function": {"name": "test_func", "arguments": malformed_string}}
    result = _parse_json_string(
        malformed_string,
        raw_tool_call=raw_tool_call,
        skip=True,
    )
    assert result == malformed_string


def test_load_response_with_empty_content_is_skipped(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test that load responses with empty content log a warning and are skipped."""
    load_only_response = [
        {
            "model": "test-model",
            "created_at": "2025-01-01T00:00:00.000000000Z",
            "done": True,
            "done_reason": "load",
            "message": {"role": "assistant", "content": ""},
        }
    ]

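    # Mock the Ollama client so chat() returns only the load-only response above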
    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.chat.return_value = load_only_response

        llm = ChatOllamaV1(model="test-model")

        with (
            caplog.at_level(logging.WARNING),
            pytest.raises(ValueError, match="No data received from Ollama stream"),
        ):
            llm.invoke([HumanMessageV1("Hello")])

        assert "Ollama returned empty response with done_reason='load'" in caplog.text


def test_load_response_with_whitespace_content_is_skipped(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test load responses w/ only whitespace content log a warning and are skipped."""
    load_whitespace_response = [
        {
            "model": "test-model",
            "created_at": "2025-01-01T00:00:00.000000000Z",
            "done": True,
            "done_reason": "load",
            "message": {"role": "assistant", "content": " \n \t "},
        }
    ]

    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.chat.return_value = load_whitespace_response

        llm = ChatOllamaV1(model="test-model")

        with (
            caplog.at_level(logging.WARNING),
            pytest.raises(ValueError, match="No data received from Ollama stream"),
        ):
            llm.invoke([HumanMessageV1("Hello")])
        assert "Ollama returned empty response with done_reason='load'" in caplog.text


def test_load_followed_by_content_response(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test load responses log a warning and are skipped when followed by content."""
    load_then_content_response = [
        {
            "model": "test-model",
            "created_at": "2025-01-01T00:00:00.000000000Z",
            "done": True,
            "done_reason": "load",
            "message": {"role": "assistant", "content": ""},
        },
        {
            "model": "test-model",
            "created_at": "2025-01-01T00:00:01.000000000Z",
            "done": True,
            "done_reason": "stop",
            "message": {
                "role": "assistant",
                "content": "Hello! How can I help you today?",
            },
        },
    ]

    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.chat.return_value = load_then_content_response

        llm = ChatOllamaV1(model="test-model")

        with caplog.at_level(logging.WARNING):
            result = llm.invoke([HumanMessageV1("Hello")])

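        # The empty load chunk is skipped; the final "stop" chunk supplies the content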
assert "Ollama returned empty response with done_reason='load'" in caplog.text
|
||||
assert result.content == "Hello! How can I help you today?"
|
||||
assert result.response_metadata.get("done_reason") == "stop"
|
||||
|
||||
|
||||
def test_load_response_with_actual_content_is_not_skipped(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Test load responses with actual content are NOT skipped and log no warning."""
    load_with_content_response = [
        {
            "model": "test-model",
            "created_at": "2025-01-01T00:00:00.000000000Z",
            "done": True,
            "done_reason": "load",
            "message": {"role": "assistant", "content": "This is actual content"},
        }
    ]

    with patch("langchain_ollama.chat_models.Client") as mock_client_class:
        mock_client = MagicMock()
        mock_client_class.return_value = mock_client
        mock_client.chat.return_value = load_with_content_response

        llm = ChatOllamaV1(model="test-model")

        with caplog.at_level(logging.WARNING):
            result = llm.invoke([HumanMessageV1("Hello")])

        assert result.content == "This is actual content"
        assert result.response_metadata.get("done_reason") == "load"
        assert not caplog.text