Mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-09 13:00:34 +00:00)
partners/ollama: fix tool calling with nested schemas (#28225)
## Description

This PR addresses the following:

**Fixes Issue #25343:**
- Adds logic to parse shallowly nested JSON-encoded strings in tool call arguments, enabling proper parsing of responses like those of Llama 3.1 and 3.2 with nested schemas.

**Adds Integration Test for Fix:**
- Adds an Ollama-specific integration test to ensure the issue is resolved and to prevent future regressions.

**Fixes Failing Integration Tests:**
- Fixes integration tests that were failing even prior to these changes, caused by the `llama3-groq-tool-use` model. Previously, the tests `test_structured_output_async` and `test_structured_output_optional_param` failed because the model did not issue a tool call in its response. Resolved by switching to `llama3.1`.

## Issue

Fixes #25343.

## Dependencies

No dependencies.

---

Done in collaboration with @ishaan-upadhyay @mirajismail @ZackSteine.
This commit is contained in: parent bb83abd037, commit 607c60a594.
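For context before the diff, here is a minimal sketch of the failure mode from #25343 that this change addresses. The tool name and payload shapes are illustrative, not taken verbatim from the issue:

```python
import json

# Expected shape: "arguments" is a plain (possibly nested) JSON object.
expected = {
    "function": {
        "name": "extract_people",  # hypothetical tool name
        "arguments": {"people": [{"name": "Alan Smith"}]},
    }
}

# What Llama 3.1/3.2 sometimes emit for nested schemas: the nested value
# arrives as a JSON-encoded *string* inside the arguments dict.
observed = {
    "function": {
        "name": "extract_people",
        "arguments": {"people": '[{"name": "Alan Smith"}]'},
    }
}

# Downstream schema validation expects a list here, so the string variant
# breaks unless it is decoded first, which is what this PR adds.
people = observed["function"]["arguments"]["people"]
assert isinstance(people, str)
assert json.loads(people) == expected["function"]["arguments"]["people"]
```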
In `langchain_ollama/chat_models.py`, import `json` and the new exception type:

```diff
@@ -1,5 +1,6 @@
 """Ollama chat models."""

+import json
 from typing import (
     Any,
     AsyncIterator,
@@ -21,6 +22,7 @@ from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
 from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
+from langchain_core.exceptions import OutputParserException
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
 from langchain_core.messages import (
```
The new parsing helpers, also in `chat_models.py`:

```diff
@@ -60,6 +62,72 @@ def _get_usage_metadata_from_generation_info(
     return None


+def _parse_json_string(
+    json_string: str, raw_tool_call: dict[str, Any], skip: bool
+) -> Any:
+    """Attempt to parse a JSON string for tool calling.
+
+    Args:
+        json_string: JSON string to parse.
+        skip: Whether to ignore parsing errors and return the value anyway.
+        raw_tool_call: Raw tool call to include in error message.
+
+    Returns:
+        The parsed JSON string.
+
+    Raises:
+        OutputParserException: If the JSON string is invalid and skip=False.
+    """
+    try:
+        return json.loads(json_string)
+    except json.JSONDecodeError as e:
+        if skip:
+            return json_string
+        msg = (
+            f"Function {raw_tool_call['function']['name']} arguments:\n\n"
+            f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. "
+            f"Received JSONDecodeError {e}"
+        )
+        raise OutputParserException(msg) from e
+    except TypeError as e:
+        if skip:
+            return json_string
+        msg = (
+            f"Function {raw_tool_call['function']['name']} arguments:\n\n"
+            f"{raw_tool_call['function']['arguments']}\n\nare not a string or a "
+            f"dictionary. Received TypeError {e}"
+        )
+        raise OutputParserException(msg) from e
+
+
+def _parse_arguments_from_tool_call(
+    raw_tool_call: dict[str, Any],
+) -> Optional[dict[str, Any]]:
+    """Parse arguments by trying to parse any shallowly nested string-encoded JSON.
+
+    Band-aid fix for issue in Ollama with inconsistent tool call argument structure.
+    Should be removed/changed if fixed upstream.
+    See https://github.com/ollama/ollama/issues/6155
+    """
+    if "function" not in raw_tool_call:
+        return None
+    arguments = raw_tool_call["function"]["arguments"]
+    parsed_arguments = {}
+    if isinstance(arguments, dict):
+        for key, value in arguments.items():
+            if isinstance(value, str):
+                parsed_arguments[key] = _parse_json_string(
+                    value, skip=True, raw_tool_call=raw_tool_call
+                )
+            else:
+                parsed_arguments[key] = value
+    else:
+        parsed_arguments = _parse_json_string(
+            arguments, skip=False, raw_tool_call=raw_tool_call
+        )
+    return parsed_arguments
+
+
 def _get_tool_calls_from_response(
     response: Mapping[str, Any],
 ) -> List[ToolCall]:
```
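To make the helpers' behavior concrete, a small illustrative sketch follows. Importing a private helper like this is for demonstration only, and the payloads are hypothetical:

```python
# _parse_arguments_from_tool_call is a private helper added to
# langchain_ollama.chat_models by this PR; imported here only to demo it.
from langchain_ollama.chat_models import _parse_arguments_from_tool_call

# A JSON-encoded string nested one level deep gets decoded in place.
raw_tool_call = {
    "function": {
        "name": "extract_people",  # hypothetical tool name
        "arguments": {"people": '[{"name": "Alan Smith", "hair_color": "blond"}]'},
    }
}
print(_parse_arguments_from_tool_call(raw_tool_call))
# {'people': [{'name': 'Alan Smith', 'hair_color': 'blond'}]}

# A plain string argument that is not valid JSON is passed through untouched,
# because per-key parsing runs with skip=True.
raw_tool_call["function"]["arguments"] = {"query": "weather in SF"}
print(_parse_arguments_from_tool_call(raw_tool_call))
# {'query': 'weather in SF'}
```

Note the asymmetry in the helper: per-key parsing uses `skip=True` so legitimate string arguments that merely fail to decode are left alone, while a top-level string payload uses `skip=False` because it must decode to a dict for the tool call to be usable at all.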
Back in `chat_models.py`, the call site now routes tool call arguments through the new helper:

```diff
@@ -72,7 +140,7 @@ def _get_tool_calls_from_response(
             tool_call(
                 id=str(uuid4()),
                 name=tc["function"]["name"],
-                args=tc["function"]["arguments"],
+                args=_parse_arguments_from_tool_call(tc) or {},
             )
         )
     return tool_calls
```
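One detail worth noting in the call site: the `or {}` fallback keeps `args` a dict even when `_parse_arguments_from_tool_call` returns `None`, which happens when the raw tool call carries no `"function"` key.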
The new Ollama-specific integration test:

```diff
@@ -0,0 +1,41 @@
+"""Ollama specific chat model integration tests"""
+
+from typing import List, Optional
+
+import pytest
+from pydantic import BaseModel, Field
+
+from langchain_ollama import ChatOllama
+
+
+@pytest.mark.parametrize(("model"), [("llama3.1")])
+def test_structured_output_deeply_nested(model: str) -> None:
+    """Test to verify structured output with nested objects."""
+    llm = ChatOllama(model=model, temperature=0)
+
+    class Person(BaseModel):
+        """Information about a person."""
+
+        name: Optional[str] = Field(default=None, description="The name of the person")
+        hair_color: Optional[str] = Field(
+            default=None, description="The color of the person's hair if known"
+        )
+        height_in_meters: Optional[str] = Field(
+            default=None, description="Height measured in meters"
+        )
+
+    class Data(BaseModel):
+        """Extracted data about people."""
+
+        people: List[Person]
+
+    chat = llm.with_structured_output(Data)  # type: ignore[arg-type]
+    text = (
+        "Alan Smith is 6 feet tall and has blond hair. "
+        "Alan Poe is 3 feet tall and has grey hair."
+    )
+    result = chat.invoke(text)
+    assert isinstance(result, Data)
+
+    for chunk in chat.stream(text):
+        assert isinstance(chunk, Data)
```
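This test exercises the fix because `with_structured_output` is implemented on top of tool calling for `ChatOllama` (the failing tests mentioned above broke precisely because the model issued no tool call): the nested `Data`/`Person` schema becomes a tool whose arguments Llama 3.1 may emit with string-encoded nested values, exactly the shape `_parse_arguments_from_tool_call` now normalizes.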
Finally, the standard chat model integration tests switch models and clarify the module docstring:

```diff
@@ -1,4 +1,4 @@
-"""Test chat model integration."""
+"""Test chat model integration using standard integration tests."""

 from typing import Type

@@ -16,7 +16,7 @@ class TestChatOllama(ChatModelIntegrationTests):

     @property
     def chat_model_params(self) -> dict:
-        return {"model": "llama3-groq-tool-use"}
+        return {"model": "llama3.1"}

     @property
     def supports_image_inputs(self) -> bool:
```