mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-14 23:26:34 +00:00
fix(ollama): robustly parse single-quoted JSON in tool calls (#32109)
**Description:** This PR makes argument parsing for Ollama tool calls more robust. Some LLMs—including Ollama—may return arguments as Python-style dictionaries with single quotes (e.g., `{'a': 1}`), which are not valid JSON and previously caused parsing to fail. The updated `_parse_json_string` method in `langchain_ollama.chat_models` now attempts standard JSON parsing and, if that fails, falls back to `ast.literal_eval` for safe evaluation of Python-style dictionaries. This improves interoperability with LLMs and fixes a common usability issue for tool-based agents. **Issue:** Closes #30910 **Dependencies:** None **Tests:** - Added new unit tests for double-quoted JSON, single-quoted dicts, mixed quoting, and malformed/failure cases. - All tests pass locally, including new coverage for single-quoted inputs. **Notes:** - No breaking changes. - No new dependencies introduced. - Code is formatted and linted (`ruff format`, `ruff check`). - If maintainers have suggestions for further improvements, I’m happy to revise! Thank you for maintaining LangChain! Looking forward to your feedback.
This commit is contained in:
parent
6794422b85
commit
8e4396bb32
@ -5,7 +5,7 @@ build-backend = "pdm.backend"
|
|||||||
[project]
|
[project]
|
||||||
authors = []
|
authors = []
|
||||||
license = { text = "MIT" }
|
license = { text = "MIT" }
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9, <4.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"langchain-core<1.0.0,>=0.3.66",
|
"langchain-core<1.0.0,>=0.3.66",
|
||||||
"langchain-text-splitters<1.0.0,>=0.3.8",
|
"langchain-text-splitters<1.0.0,>=0.3.8",
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ast
|
||||||
import json
|
import json
|
||||||
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
|
from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
@ -77,31 +78,43 @@ def _get_usage_metadata_from_generation_info(
|
|||||||
|
|
||||||
def _parse_json_string(
|
def _parse_json_string(
|
||||||
json_string: str,
|
json_string: str,
|
||||||
|
*,
|
||||||
raw_tool_call: dict[str, Any],
|
raw_tool_call: dict[str, Any],
|
||||||
skip: bool, # noqa: FBT001
|
skip: bool,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Attempt to parse a JSON string for tool calling.
|
"""Attempt to parse a JSON string for tool calling.
|
||||||
|
|
||||||
|
It first tries to use the standard json.loads. If that fails, it falls
|
||||||
|
back to ast.literal_eval to safely parse Python literals, which is more
|
||||||
|
robust against models using single quotes or containing apostrophes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
json_string: JSON string to parse.
|
json_string: JSON string to parse.
|
||||||
skip: Whether to ignore parsing errors and return the value anyways.
|
|
||||||
raw_tool_call: Raw tool call to include in error message.
|
raw_tool_call: Raw tool call to include in error message.
|
||||||
|
skip: Whether to ignore parsing errors and return the value anyways.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The parsed JSON string.
|
The parsed JSON string or Python literal.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
OutputParserException: If the JSON string wrong invalid and skip=False.
|
OutputParserException: If the string is invalid and skip=False.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
return json.loads(json_string)
|
return json.loads(json_string)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError:
|
||||||
|
try:
|
||||||
|
# Use ast.literal_eval to safely parse Python-style dicts
|
||||||
|
# (e.g. with single quotes)
|
||||||
|
return ast.literal_eval(json_string)
|
||||||
|
except (SyntaxError, ValueError) as e:
|
||||||
|
# If both fail, and we're not skipping, raise an informative error.
|
||||||
if skip:
|
if skip:
|
||||||
return json_string
|
return json_string
|
||||||
msg = (
|
msg = (
|
||||||
f"Function {raw_tool_call['function']['name']} arguments:\n\n"
|
f"Function {raw_tool_call['function']['name']} arguments:\n\n"
|
||||||
f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. "
|
f"{raw_tool_call['function']['arguments']}"
|
||||||
f"Received JSONDecodeError {e}"
|
"\n\nare not valid JSON or a Python literal. "
|
||||||
|
f"Received error: {e}"
|
||||||
)
|
)
|
||||||
raise OutputParserException(msg) from e
|
raise OutputParserException(msg) from e
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
|
@ -8,10 +8,15 @@ from unittest.mock import patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import Client, Request, Response
|
from httpx import Client, Request, Response
|
||||||
|
from langchain_core.exceptions import OutputParserException
|
||||||
from langchain_core.messages import ChatMessage
|
from langchain_core.messages import ChatMessage
|
||||||
from langchain_tests.unit_tests import ChatModelUnitTests
|
from langchain_tests.unit_tests import ChatModelUnitTests
|
||||||
|
|
||||||
from langchain_ollama.chat_models import ChatOllama, _parse_arguments_from_tool_call
|
from langchain_ollama.chat_models import (
|
||||||
|
ChatOllama,
|
||||||
|
_parse_arguments_from_tool_call,
|
||||||
|
_parse_json_string,
|
||||||
|
)
|
||||||
|
|
||||||
MODEL_NAME = "llama3.1"
|
MODEL_NAME = "llama3.1"
|
||||||
|
|
||||||
@ -49,13 +54,11 @@ def test_arbitrary_roles_accepted_in_chatmessages(
|
|||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
) -> None:
|
) -> None:
|
||||||
monkeypatch.setattr(Client, "stream", _mock_httpx_client_stream)
|
monkeypatch.setattr(Client, "stream", _mock_httpx_client_stream)
|
||||||
|
|
||||||
llm = ChatOllama(
|
llm = ChatOllama(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
format=None,
|
format=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
ChatMessage(
|
ChatMessage(
|
||||||
role="somerandomrole",
|
role="somerandomrole",
|
||||||
@ -64,7 +67,6 @@ def test_arbitrary_roles_accepted_in_chatmessages(
|
|||||||
ChatMessage(role="control", content="thinking"),
|
ChatMessage(role="control", content="thinking"),
|
||||||
ChatMessage(role="user", content="What is the meaning of life?"),
|
ChatMessage(role="user", content="What is the meaning of life?"),
|
||||||
]
|
]
|
||||||
|
|
||||||
llm.invoke(messages)
|
llm.invoke(messages)
|
||||||
|
|
||||||
|
|
||||||
@ -83,3 +85,58 @@ def test_validate_model_on_init(mock_validate_model: Any) -> None:
|
|||||||
# Test that validate_model is NOT called by default
|
# Test that validate_model is NOT called by default
|
||||||
ChatOllama(model=MODEL_NAME)
|
ChatOllama(model=MODEL_NAME)
|
||||||
mock_validate_model.assert_not_called()
|
mock_validate_model.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# Define a dummy raw_tool_call for the function signature
|
||||||
|
dummy_raw_tool_call = {
|
||||||
|
"function": {"name": "test_func", "arguments": ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# --- Regression tests for tool-call argument parsing (see #30910) ---
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_string, expected_output",
|
||||||
|
[
|
||||||
|
# Case 1: Standard double-quoted JSON
|
||||||
|
('{"key": "value", "number": 123}', {"key": "value", "number": 123}),
|
||||||
|
# Case 2: Single-quoted string (the original bug)
|
||||||
|
("{'key': 'value', 'number': 123}", {"key": "value", "number": 123}),
|
||||||
|
# Case 3: String with an internal apostrophe
|
||||||
|
('{"text": "It\'s a great test!"}', {"text": "It's a great test!"}),
|
||||||
|
# Case 4: Mixed quotes that ast can handle
|
||||||
|
("{'text': \"It's a great test!\"}", {"text": "It's a great test!"}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_parse_json_string_success_cases(
|
||||||
|
input_string: str, expected_output: Any
|
||||||
|
) -> None:
|
||||||
|
"""Tests that _parse_json_string correctly parses valid and fixable strings."""
|
||||||
|
raw_tool_call = {"function": {"name": "test_func", "arguments": input_string}}
|
||||||
|
result = _parse_json_string(input_string, raw_tool_call=raw_tool_call, skip=False)
|
||||||
|
assert result == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_json_string_failure_case_raises_exception() -> None:
|
||||||
|
"""Tests that _parse_json_string raises an exception for truly malformed strings."""
|
||||||
|
malformed_string = "{'key': 'value',,}"
|
||||||
|
raw_tool_call = {"function": {"name": "test_func", "arguments": malformed_string}}
|
||||||
|
with pytest.raises(OutputParserException):
|
||||||
|
_parse_json_string(
|
||||||
|
malformed_string,
|
||||||
|
raw_tool_call=raw_tool_call,
|
||||||
|
skip=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_json_string_skip_returns_input_on_failure() -> None:
|
||||||
|
"""Tests that skip=True returns the original string on parse failure."""
|
||||||
|
malformed_string = "{'not': valid,,,}"
|
||||||
|
raw_tool_call = {"function": {"name": "test_func", "arguments": malformed_string}}
|
||||||
|
result = _parse_json_string(
|
||||||
|
malformed_string,
|
||||||
|
raw_tool_call=raw_tool_call,
|
||||||
|
skip=True,
|
||||||
|
)
|
||||||
|
assert result == malformed_string
|
||||||
|
Loading…
Reference in New Issue
Block a user