partners/ollama: fix tool calling with nested schemas (#28225)

## Description

This PR addresses the following:

**Fixes Issue #25343:**
- Adds logic to parse shallowly nested JSON-encoded strings in tool call
arguments, so that responses with nested schemas (such as those returned by
Llama 3.1 and 3.2) are parsed correctly; see the sketch below.
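A minimal sketch of the failure mode this addresses (the `get_weather` tool and its arguments are hypothetical, not taken from the issue):

```python
# Hypothetical raw tool call from Ollama: with a nested schema, a nested argument
# can come back as a JSON-encoded string instead of an object.
raw_tool_call = {
    "function": {
        "name": "get_weather",
        "arguments": {
            "location": '{"city": "Toronto", "country": "CA"}',  # string, not dict
            "unit": "celsius",
        },
    }
}

# With this PR, the shallowly nested string is decoded, so the tool call args become:
# {"location": {"city": "Toronto", "country": "CA"}, "unit": "celsius"}
```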
 
**Adds Integration Test for Fix:**
- Adds an Ollama-specific integration test to verify the fix and guard against
future regressions; a rough sketch follows below.
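A rough sketch of what such a test can look like (the tool, schema, and prompt below are illustrative, not the exact test added in this PR):

```python
from pydantic import BaseModel, Field

from langchain_core.tools import tool
from langchain_ollama import ChatOllama


class Address(BaseModel):
    """Nested schema used as a tool argument."""

    street: str = Field(description="Street name and number")
    city: str = Field(description="City name")


@tool
def record_address(address: Address) -> str:
    """Record a mailing address."""
    return f"Recorded an address in {address.city}"


def test_tool_call_with_nested_schema() -> None:
    llm = ChatOllama(model="llama3.1").bind_tools([record_address])
    response = llm.invoke("Record the address 123 Main St, Springfield.")
    assert response.tool_calls
    # The nested argument should be parsed into a dict, not left as a JSON string.
    assert isinstance(response.tool_calls[0]["args"]["address"], dict)
```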

**Fixes Failing Integration Tests:**
- Fixes integration tests that were failing (even before these changes) because
of the `llama3-groq-tool-use` model. The tests `test_structured_output_async`
and `test_structured_output_optional_param` failed because the model did not
issue a tool call in its response; resolved by switching to `llama3.1`.

## Issue
Fixes #25343.

## Dependencies
No dependencies.

____

Done in collaboration with @ishaan-upadhyay @mirajismail @ZackSteine.
Commit 607c60a594 (parent bb83abd037), authored by TheDannyG on 2024-11-27 10:32:02 -05:00 and committed via GitHub.
3 changed files with 112 additions and 3 deletions.

Diff of `langchain_ollama/chat_models.py` (the Ollama chat model module):

```diff
@@ -1,5 +1,6 @@
 """Ollama chat models."""
 
+import json
 from typing import (
     Any,
     AsyncIterator,
@@ -21,6 +22,7 @@ from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
 from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
+from langchain_core.exceptions import OutputParserException
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
 from langchain_core.messages import (
@@ -60,6 +62,72 @@ def _get_usage_metadata_from_generation_info(
     return None
+
+
+def _parse_json_string(
+    json_string: str, raw_tool_call: dict[str, Any], skip: bool
+) -> Any:
+    """Attempt to parse a JSON string for tool calling.
+
+    Args:
+        json_string: JSON string to parse.
+        skip: Whether to ignore parsing errors and return the value anyways.
+        raw_tool_call: Raw tool call to include in error message.
+
+    Returns:
+        The parsed JSON string.
+
+    Raises:
+        OutputParserException: If the JSON string is invalid and skip=False.
+    """
+    try:
+        return json.loads(json_string)
+    except json.JSONDecodeError as e:
+        if skip:
+            return json_string
+        msg = (
+            f"Function {raw_tool_call['function']['name']} arguments:\n\n"
+            f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. "
+            f"Received JSONDecodeError {e}"
+        )
+        raise OutputParserException(msg) from e
+    except TypeError as e:
+        if skip:
+            return json_string
+        msg = (
+            f"Function {raw_tool_call['function']['name']} arguments:\n\n"
+            f"{raw_tool_call['function']['arguments']}\n\nare not a string or a "
+            f"dictionary. Received TypeError {e}"
+        )
+        raise OutputParserException(msg) from e
+
+
+def _parse_arguments_from_tool_call(
+    raw_tool_call: dict[str, Any],
+) -> Optional[dict[str, Any]]:
+    """Parse arguments by trying to parse any shallowly nested string-encoded JSON.
+
+    Band-aid fix for issue in Ollama with inconsistent tool call argument structure.
+    Should be removed/changed if fixed upstream.
+    See https://github.com/ollama/ollama/issues/6155
+    """
+    if "function" not in raw_tool_call:
+        return None
+    arguments = raw_tool_call["function"]["arguments"]
+    parsed_arguments = {}
+    if isinstance(arguments, dict):
+        for key, value in arguments.items():
+            if isinstance(value, str):
+                parsed_arguments[key] = _parse_json_string(
+                    value, skip=True, raw_tool_call=raw_tool_call
+                )
+            else:
+                parsed_arguments[key] = value
+    else:
+        parsed_arguments = _parse_json_string(
+            arguments, skip=False, raw_tool_call=raw_tool_call
+        )
+    return parsed_arguments
 
 
 def _get_tool_calls_from_response(
     response: Mapping[str, Any],
 ) -> List[ToolCall]:
@@ -72,7 +140,7 @@ def _get_tool_calls_from_response(
                 tool_call(
                     id=str(uuid4()),
                     name=tc["function"]["name"],
-                    args=tc["function"]["arguments"],
+                    args=_parse_arguments_from_tool_call(tc) or {},
                 )
             )
     return tool_calls
```
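
To summarize the behavior of the new helpers, a short sketch (assuming the private functions are importable from `langchain_ollama.chat_models` once this change lands; the tool name `f` and payloads are made up):

```python
from langchain_ollama.chat_models import _parse_arguments_from_tool_call

# A string-encoded nested value inside an arguments dict is decoded in place.
_parse_arguments_from_tool_call(
    {"function": {"name": "f", "arguments": {"payload": '{"a": 1}'}}}
)
# -> {"payload": {"a": 1}}

# A plain string value that is not valid JSON is left as-is (skip=True).
_parse_arguments_from_tool_call(
    {"function": {"name": "f", "arguments": {"unit": "celsius"}}}
)
# -> {"unit": "celsius"}

# If the arguments field itself is a JSON string, it is parsed strictly
# (skip=False), so malformed JSON raises OutputParserException.
_parse_arguments_from_tool_call(
    {"function": {"name": "f", "arguments": '{"a": 1}'}}
)
# -> {"a": 1}
```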