diff --git a/libs/partners/ollama/tests/unit_tests/test_chat_models_v1.py b/libs/partners/ollama/tests/unit_tests/test_chat_models_v1.py index 7a6233b2dc7..23a7c72d459 100644 --- a/libs/partners/ollama/tests/unit_tests/test_chat_models_v1.py +++ b/libs/partners/ollama/tests/unit_tests/test_chat_models_v1.py @@ -1,5 +1,11 @@ """Unit tests for ChatOllamaV1.""" +import logging +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +from langchain_core.exceptions import OutputParserException from langchain_core.messages.content_blocks import ( create_image_block, create_text_block, @@ -15,7 +21,10 @@ from langchain_ollama._compat import ( _convert_from_v1_to_ollama_format, _convert_to_v1_from_ollama_format, ) -from langchain_ollama.chat_models_v1 import ChatOllamaV1 +from langchain_ollama.chat_models_v1 import ( + ChatOllamaV1, + _parse_json_string, +) MODEL_NAME = "llama3.1" @@ -195,3 +204,199 @@ class TestChatOllamaV1(ChatModelV1UnitTests): # Should return a bound model assert bound_llm is not None + + +@patch("langchain_ollama.chat_models.validate_model") +def test_validate_model_on_init(mock_validate_model: Any) -> None: + """Test that the model is validated on initialization when requested.""" + # Test that validate_model is called when validate_model_on_init=True + ChatOllamaV1(model=MODEL_NAME, validate_model_on_init=True) + mock_validate_model.assert_called_once() + mock_validate_model.reset_mock() + + # Test that validate_model is NOT called when validate_model_on_init=False + ChatOllamaV1(model=MODEL_NAME, validate_model_on_init=False) + mock_validate_model.assert_not_called() + + # Test that validate_model is NOT called by default + ChatOllamaV1(model=MODEL_NAME) + mock_validate_model.assert_not_called() + + +# Define a dummy raw_tool_call for the function signature +dummy_raw_tool_call = { + "function": {"name": "test_func", "arguments": ""}, +} + + +@pytest.mark.parametrize( + "input_string, expected_output", + [ + # Case 1: Standard double-quoted JSON + ('{"key": "value", "number": 123}', {"key": "value", "number": 123}), + # Case 2: Single-quoted string (the original bug) + ("{'key': 'value', 'number': 123}", {"key": "value", "number": 123}), + # Case 3: String with an internal apostrophe + ('{"text": "It\'s a great test!"}', {"text": "It's a great test!"}), + # Case 4: Mixed quotes that ast can handle + ("{'text': \"It's a great test!\"}", {"text": "It's a great test!"}), + ], +) +def test_parse_json_string_success_cases( + input_string: str, expected_output: Any +) -> None: + """Tests that `_parse_json_string` correctly parses valid and fixable strings.""" + raw_tool_call = {"function": {"name": "test_func", "arguments": input_string}} + result = _parse_json_string(input_string, raw_tool_call=raw_tool_call, skip=False) + assert result == expected_output + + +def test_parse_json_string_failure_case_raises_exception() -> None: + """Tests that `_parse_json_string` raises an exception for malformed strings.""" + malformed_string = "{'key': 'value',,}" + raw_tool_call = {"function": {"name": "test_func", "arguments": malformed_string}} + with pytest.raises(OutputParserException): + _parse_json_string( + malformed_string, + raw_tool_call=raw_tool_call, + skip=False, + ) + + +def test_parse_json_string_skip_returns_input_on_failure() -> None: + """Tests that `skip=True` returns the original string on parse failure.""" + malformed_string = "{'not': valid,,,}" + raw_tool_call = {"function": {"name": "test_func", "arguments": malformed_string}} + result = _parse_json_string( + malformed_string, + raw_tool_call=raw_tool_call, + skip=True, + ) + assert result == malformed_string + + +def test_load_response_with_empty_content_is_skipped( + caplog: pytest.LogCaptureFixture, +) -> None: + """Test that load responses with empty content log a warning and are skipped.""" + load_only_response = [ + { + "model": "test-model", + "created_at": "2025-01-01T00:00:00.000000000Z", + "done": True, + "done_reason": "load", + "message": {"role": "assistant", "content": ""}, + } + ] + + with patch("langchain_ollama.chat_models.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.chat.return_value = load_only_response + + llm = ChatOllamaV1(model="test-model") + + with ( + caplog.at_level(logging.WARNING), + pytest.raises(ValueError, match="No data received from Ollama stream"), + ): + llm.invoke([HumanMessageV1("Hello")]) + + assert "Ollama returned empty response with done_reason='load'" in caplog.text + + +def test_load_response_with_whitespace_content_is_skipped( + caplog: pytest.LogCaptureFixture, +) -> None: + """Test load responses w/ only whitespace content log a warning and are skipped.""" + load_whitespace_response = [ + { + "model": "test-model", + "created_at": "2025-01-01T00:00:00.000000000Z", + "done": True, + "done_reason": "load", + "message": {"role": "assistant", "content": " \n \t "}, + } + ] + + with patch("langchain_ollama.chat_models.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.chat.return_value = load_whitespace_response + + llm = ChatOllamaV1(model="test-model") + + with ( + caplog.at_level(logging.WARNING), + pytest.raises(ValueError, match="No data received from Ollama stream"), + ): + llm.invoke([HumanMessageV1("Hello")]) + assert "Ollama returned empty response with done_reason='load'" in caplog.text + + +def test_load_followed_by_content_response( + caplog: pytest.LogCaptureFixture, +) -> None: + """Test load responses log a warning and are skipped when followed by content.""" + load_then_content_response = [ + { + "model": "test-model", + "created_at": "2025-01-01T00:00:00.000000000Z", + "done": True, + "done_reason": "load", + "message": {"role": "assistant", "content": ""}, + }, + { + "model": "test-model", + "created_at": "2025-01-01T00:00:01.000000000Z", + "done": True, + "done_reason": "stop", + "message": { + "role": "assistant", + "content": "Hello! How can I help you today?", + }, + }, + ] + + with patch("langchain_ollama.chat_models.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.chat.return_value = load_then_content_response + + llm = ChatOllamaV1(model="test-model") + + with caplog.at_level(logging.WARNING): + result = llm.invoke([HumanMessageV1("Hello")]) + + assert "Ollama returned empty response with done_reason='load'" in caplog.text + assert result.content == "Hello! How can I help you today?" + assert result.response_metadata.get("done_reason") == "stop" + + +def test_load_response_with_actual_content_is_not_skipped( + caplog: pytest.LogCaptureFixture, +) -> None: + """Test load responses with actual content are NOT skipped and log no warning.""" + load_with_content_response = [ + { + "model": "test-model", + "created_at": "2025-01-01T00:00:00.000000000Z", + "done": True, + "done_reason": "load", + "message": {"role": "assistant", "content": "This is actual content"}, + } + ] + + with patch("langchain_ollama.chat_models.Client") as mock_client_class: + mock_client = MagicMock() + mock_client_class.return_value = mock_client + mock_client.chat.return_value = load_with_content_response + + llm = ChatOllamaV1(model="test-model") + + with caplog.at_level(logging.WARNING): + result = llm.invoke([HumanMessageV1("Hello")]) + + assert result.content == "This is actual content" + assert result.response_metadata.get("done_reason") == "load" + assert not caplog.text