mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
## Problem
When using `ChatOllama` with `create_react_agent`, agents would
sometimes terminate prematurely with empty responses when Ollama
returned `done_reason: 'load'` responses with no content. This caused
agents to return empty `AIMessage` objects instead of actual generated
text.
```python
from langchain_ollama import ChatOllama
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import HumanMessage
llm = ChatOllama(model='qwen2.5:7b', temperature=0)
agent = create_react_agent(model=llm, tools=[])
result = agent.invoke(HumanMessage('Hello'), {"configurable": {"thread_id": "1"}})
# Before fix: AIMessage(content='', response_metadata={'done_reason': 'load'})
# Expected: AIMessage with actual generated content
```
## Root Cause
The `_iterate_over_stream` and `_aiterate_over_stream` methods treated
any response with `done: True` as final, regardless of `done_reason`.
When Ollama returns `done_reason: 'load'` with empty content, it
indicates the model was loaded but no actual generation occurred - this
should not be considered a complete response.
## Solution
Modified the streaming logic to skip a response only when all of the following hold:
- `done: True`
- `done_reason: 'load'`
- Content is empty or contains only whitespace
This ensures agents only receive actual generated content while
preserving backward compatibility for load responses that do contain
content.
## Changes
- **`_iterate_over_stream`**: Skip empty load responses instead of
yielding them
- **`_aiterate_over_stream`**: Apply same fix to async streaming
- **Tests**: Added comprehensive test cases covering all edge cases
## Testing
All scenarios now work correctly:
- ✅ Empty load responses are skipped (fixes original issue)
- ✅ Load responses with actual content are preserved (backward
compatibility)
- ✅ Normal stop responses work unchanged
- ✅ Streaming behavior preserved
- ✅ `create_react_agent` integration fixed
Fixes #31482.
<!-- START COPILOT CODING AGENT TIPS -->
---
💡 You can make Copilot smarter by setting up custom instructions,
customizing its development environment and configuring Model Context
Protocol (MCP) servers. Learn more [Copilot coding agent
tips](https://gh.io/copilot-coding-agent-tips) in the docs.
---------
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: mdrxy <61371264+mdrxy@users.noreply.github.com>
Co-authored-by: Mason Daugherty <github@mdrxy.com>
271 lines
9.3 KiB
Python
271 lines
9.3 KiB
Python
"""Test chat model integration."""
|
|
|
|
import json
|
|
import logging
|
|
from collections.abc import Generator
|
|
from contextlib import contextmanager
|
|
from typing import Any
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from httpx import Client, Request, Response
|
|
from langchain_core.exceptions import OutputParserException
|
|
from langchain_core.messages import ChatMessage, HumanMessage
|
|
from langchain_tests.unit_tests import ChatModelUnitTests
|
|
|
|
from langchain_ollama.chat_models import (
|
|
ChatOllama,
|
|
_parse_arguments_from_tool_call,
|
|
_parse_json_string,
|
|
)
|
|
|
|
# Model name shared by the tests in this module that build ``ChatOllama``
# through this constant rather than a literal model string.
MODEL_NAME = "llama3.1"
|
|
|
|
|
|
class TestChatOllama(ChatModelUnitTests):
    """Bind the shared ``ChatModelUnitTests`` suite to ``ChatOllama``."""

    @property
    def chat_model_class(self) -> type[ChatOllama]:
        """Return the chat model class under test."""
        return ChatOllama

    @property
    def chat_model_params(self) -> dict:
        """Return the keyword arguments used to construct the model under test."""
        return {"model": MODEL_NAME}
|
|
|
|
|
|
def test__parse_arguments_from_tool_call() -> None:
    """Check that a digit-only string argument comes back as a ``str``.

    The raw payload carries ``arg_1`` as a JSON string made only of digits;
    the parsed arguments must exist and must keep that value typed as ``str``.
    """
    raw_response = '{"model":"sample-model","message":{"role":"assistant","content":"","tool_calls":[{"function":{"name":"get_profile_details","arguments":{"arg_1":"12345678901234567890123456"}}}]},"done":false}'  # noqa: E501
    payload = json.loads(raw_response)
    first_tool_call = payload["message"]["tool_calls"][0]

    parsed_arguments = _parse_arguments_from_tool_call(first_tool_call)

    assert parsed_arguments is not None
    assert isinstance(parsed_arguments["arg_1"], str)
|
|
|
|
|
|
@contextmanager
def _mock_httpx_client_stream(
    *args: Any, **kwargs: Any
) -> Generator[Response, Any, Any]:
    """Yield a single canned HTTP 200 response, ignoring every argument.

    Used in this module as a replacement for ``Client.stream`` so that no
    real HTTP request is made.
    """
    fake_request = Request(method="POST", url="http://whocares:11434")
    canned_response = Response(
        status_code=200,
        content='{"message": {"role": "assistant", "content": "The meaning ..."}}',
        request=fake_request,
    )
    yield canned_response
|
|
|
|
|
|
def test_arbitrary_roles_accepted_in_chatmessages(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Invoke ``ChatOllama`` with ``ChatMessage`` objects using arbitrary roles.

    ``Client.stream`` is replaced with a canned response; the test has no
    explicit assertion and passes as long as ``invoke`` does not raise.
    """
    monkeypatch.setattr(Client, "stream", _mock_httpx_client_stream)

    llm = ChatOllama(model=MODEL_NAME, verbose=True, format=None)

    # (role, content) pairs, in the order they are sent to the model.
    role_content_pairs = [
        ("somerandomrole", "I'm ok with you adding any role message now!"),
        ("control", "thinking"),
        ("user", "What is the meaning of life?"),
    ]
    messages = [
        ChatMessage(role=role, content=content)
        for role, content in role_content_pairs
    ]

    llm.invoke(messages)
|
|
|
|
|
|
@patch("langchain_ollama.chat_models.validate_model")
def test_validate_model_on_init(mock_validate_model: Any) -> None:
    """Verify ``validate_model`` runs at construction only when opted in."""
    # Opting in must trigger exactly one validation call.
    ChatOllama(model=MODEL_NAME, validate_model_on_init=True)
    mock_validate_model.assert_called_once()

    # Clear the recorded call so the remaining checks start from zero.
    mock_validate_model.reset_mock()

    # Neither an explicit opt-out nor the default may trigger validation.
    for init_kwargs in ({"validate_model_on_init": False}, {}):
        ChatOllama(model=MODEL_NAME, **init_kwargs)
        mock_validate_model.assert_not_called()
|
|
|
|
|
|
# Define a dummy raw_tool_call for the function signature: a "function" entry
# holding a tool name and an (empty) arguments string.
# NOTE(review): no test visible in this file references this name; each parsing
# test builds its own ``raw_tool_call`` -- confirm whether it is still needed.
dummy_raw_tool_call = {
    "function": {"name": "test_func", "arguments": ""},
}
|
|
|
|
|
|
# --- Regression tests for tool-call argument parsing (see #30910) ---
|
|
|
|
|
|
@pytest.mark.parametrize(
    "input_string, expected_output",
    [
        # Plain, valid JSON with double quotes.
        ('{"key": "value", "number": 123}', {"key": "value", "number": 123}),
        # Python-style single quotes (the originally reported failure).
        ("{'key': 'value', 'number': 123}", {"key": "value", "number": 123}),
        # Double-quoted value that contains an apostrophe.
        ('{"text": "It\'s a great test!"}', {"text": "It's a great test!"}),
        # Single-quoted key with a double-quoted value containing an apostrophe.
        ("{'text': \"It's a great test!\"}", {"text": "It's a great test!"}),
    ],
)
def test_parse_json_string_success_cases(
    input_string: str, expected_output: Any
) -> None:
    """Check that ``_parse_json_string`` returns the expected dict per case."""
    tool_call = {"function": {"name": "test_func", "arguments": input_string}}

    parsed = _parse_json_string(input_string, raw_tool_call=tool_call, skip=False)

    assert parsed == expected_output
|
|
|
|
|
|
def test_parse_json_string_failure_case_raises_exception() -> None:
    """Check that malformed input raises ``OutputParserException`` when ``skip=False``."""
    malformed_string = "{'key': 'value',,}"
    tool_call = {"function": {"name": "test_func", "arguments": malformed_string}}

    with pytest.raises(OutputParserException):
        _parse_json_string(malformed_string, raw_tool_call=tool_call, skip=False)
|
|
|
|
|
|
def test_parse_json_string_skip_returns_input_on_failure() -> None:
    """Check that ``skip=True`` hands back the unparsed input instead of raising."""
    malformed_string = "{'not': valid,,,}"
    tool_call = {"function": {"name": "test_func", "arguments": malformed_string}}

    returned = _parse_json_string(malformed_string, raw_tool_call=tool_call, skip=True)

    assert returned == malformed_string
|
|
|
|
|
|
def test_load_response_with_empty_content_is_skipped(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Check that a lone empty ``done_reason='load'`` chunk is skipped with a warning.

    The mocked client returns only that chunk, so ``invoke`` is expected to
    raise ``ValueError`` ("No data received from Ollama stream") and the
    skip warning must appear in the captured log.
    """
    empty_load_chunk = {
        "model": "test-model",
        "created_at": "2025-01-01T00:00:00.000000000Z",
        "done": True,
        "done_reason": "load",
        "message": {"role": "assistant", "content": ""},
    }

    with patch("langchain_ollama.chat_models.Client") as client_cls:
        fake_client = MagicMock()
        fake_client.chat.return_value = [empty_load_chunk]
        client_cls.return_value = fake_client

        llm = ChatOllama(model="test-model")

        with caplog.at_level(logging.WARNING):
            with pytest.raises(
                ValueError, match="No data received from Ollama stream"
            ):
                llm.invoke([HumanMessage("Hello")])

        assert "Ollama returned empty response with done_reason='load'" in caplog.text
|
|
|
|
|
|
def test_load_response_with_whitespace_content_is_skipped(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Check that a whitespace-only ``done_reason='load'`` chunk is skipped too.

    With nothing else in the mocked stream, ``invoke`` is expected to raise
    ``ValueError`` ("No data received from Ollama stream") and the skip
    warning must appear in the captured log.
    """
    whitespace_load_chunk = {
        "model": "test-model",
        "created_at": "2025-01-01T00:00:00.000000000Z",
        "done": True,
        "done_reason": "load",
        "message": {"role": "assistant", "content": " \n \t "},
    }

    with patch("langchain_ollama.chat_models.Client") as client_cls:
        fake_client = MagicMock()
        fake_client.chat.return_value = [whitespace_load_chunk]
        client_cls.return_value = fake_client

        llm = ChatOllama(model="test-model")

        with caplog.at_level(logging.WARNING):
            with pytest.raises(
                ValueError, match="No data received from Ollama stream"
            ):
                llm.invoke([HumanMessage("Hello")])

        assert "Ollama returned empty response with done_reason='load'" in caplog.text
|
|
|
|
|
|
def test_load_followed_by_content_response(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Check that an empty load chunk is dropped and the later content is returned.

    The mocked stream holds an empty ``done_reason='load'`` chunk followed by
    a ``done_reason='stop'`` chunk with text. The skip warning must be logged,
    and the result must carry the second chunk's content and ``done_reason``.
    """
    empty_load_chunk = {
        "model": "test-model",
        "created_at": "2025-01-01T00:00:00.000000000Z",
        "done": True,
        "done_reason": "load",
        "message": {"role": "assistant", "content": ""},
    }
    content_chunk = {
        "model": "test-model",
        "created_at": "2025-01-01T00:00:01.000000000Z",
        "done": True,
        "done_reason": "stop",
        "message": {
            "role": "assistant",
            "content": "Hello! How can I help you today?",
        },
    }

    with patch("langchain_ollama.chat_models.Client") as client_cls:
        fake_client = MagicMock()
        fake_client.chat.return_value = [empty_load_chunk, content_chunk]
        client_cls.return_value = fake_client

        llm = ChatOllama(model="test-model")

        with caplog.at_level(logging.WARNING):
            result = llm.invoke([HumanMessage("Hello")])

        assert "Ollama returned empty response with done_reason='load'" in caplog.text
        assert result.content == "Hello! How can I help you today?"
        assert result.response_metadata.get("done_reason") == "stop"
|
|
|
|
|
|
def test_load_response_with_actual_content_is_not_skipped(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Check that a ``done_reason='load'`` chunk carrying text is kept.

    The result must expose the chunk's content and ``done_reason``, and
    nothing may be logged at WARNING level or above.
    """
    load_chunk_with_text = {
        "model": "test-model",
        "created_at": "2025-01-01T00:00:00.000000000Z",
        "done": True,
        "done_reason": "load",
        "message": {"role": "assistant", "content": "This is actual content"},
    }

    with patch("langchain_ollama.chat_models.Client") as client_cls:
        fake_client = MagicMock()
        fake_client.chat.return_value = [load_chunk_with_text]
        client_cls.return_value = fake_client

        llm = ChatOllama(model="test-model")

        with caplog.at_level(logging.WARNING):
            result = llm.invoke([HumanMessage("Hello")])

        assert result.content == "This is actual content"
        assert result.response_metadata.get("done_reason") == "load"
        assert not caplog.text
|