From 607c60a594329f5cfb08b3e5e40fc21b38d30cf3 Mon Sep 17 00:00:00 2001
From: TheDannyG
Date: Wed, 27 Nov 2024 10:32:02 -0500
Subject: [PATCH] partners/ollama: fix tool calling with nested schemas (#28225)

## Description

This PR addresses the following:

**Fixes Issue #25343:**
- Adds logic to parse shallowly nested JSON-encoded strings in tool call arguments, allowing proper parsing of responses like those of Llama 3.1 and 3.2 with nested schemas.

**Adds Integration Test for Fix:**
- Adds an Ollama-specific integration test to ensure the issue is resolved and to prevent regressions in the future.

**Fixes Failing Integration Tests:**
- Fixes integration tests that were failing (even prior to these changes) because of the `llama3-groq-tool-use` model. Previously, the tests `test_structured_output_async` and `test_structured_output_optional_param` failed because the model did not issue a tool call in its response. Resolved by switching to `llama3.1`.

## Issue

Fixes #25343.

## Dependencies

No dependencies.

____
Done in collaboration with @ishaan-upadhyay @mirajismail @ZackSteine.
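For illustration, here is a minimal sketch of the behavior this change targets. The payload below is a hypothetical example of the argument shape reported in ollama/ollama#6155 (the `Data`/`people` names are made up for the sketch), not captured model output. Before this change, the string value was passed through to `args` unparsed, which broke structured output validation downstream:

```python
from langchain_ollama.chat_models import _parse_arguments_from_tool_call

# Hypothetical tool call payload: the model returns the "people" argument as a
# JSON-encoded string instead of a parsed list of objects.
raw_tool_call = {
    "function": {
        "name": "Data",
        "arguments": {"people": '[{"name": "Alan Smith", "hair_color": "blond"}]'},
    }
}

# The helper decodes the shallowly nested string. Dict values that are not
# valid JSON are returned unchanged (skip=True), while a top-level arguments
# string that fails to parse raises OutputParserException (skip=False).
parsed = _parse_arguments_from_tool_call(raw_tool_call)
print(parsed)  # {'people': [{'name': 'Alan Smith', 'hair_color': 'blond'}]}
```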
---
 .../ollama/langchain_ollama/chat_models.py    | 70 ++++++++++++++++++-
 .../chat_models/test_chat_models.py           | 41 +++++++++++
 .../test_chat_models_standard.py}             |  4 +-
 3 files changed, 112 insertions(+), 3 deletions(-)
 create mode 100644 libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models.py
 rename libs/partners/ollama/tests/integration_tests/{test_chat_models.py => chat_models/test_chat_models_standard.py} (90%)

diff --git a/libs/partners/ollama/langchain_ollama/chat_models.py b/libs/partners/ollama/langchain_ollama/chat_models.py
index 9ec7a77a35d..4f074206c8f 100644
--- a/libs/partners/ollama/langchain_ollama/chat_models.py
+++ b/libs/partners/ollama/langchain_ollama/chat_models.py
@@ -1,5 +1,6 @@
 """Ollama chat models."""
 
+import json
 from typing import (
     Any,
     AsyncIterator,
@@ -21,6 +22,7 @@ from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
 from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
+from langchain_core.exceptions import OutputParserException
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
 from langchain_core.messages import (
@@ -60,6 +62,72 @@ def _get_usage_metadata_from_generation_info(
     return None
 
 
+def _parse_json_string(
+    json_string: str, raw_tool_call: dict[str, Any], skip: bool
+) -> Any:
+    """Attempt to parse a JSON string for tool calling.
+
+    Args:
+        json_string: JSON string to parse.
+        raw_tool_call: Raw tool call to include in error messages.
+        skip: Whether to ignore parsing errors and return the value anyway.
+
+    Returns:
+        The parsed JSON object, or the raw string if parsing fails and skip=True.
+
+    Raises:
+        OutputParserException: If the JSON string is invalid and skip=False.
+    """
+    try:
+        return json.loads(json_string)
+    except json.JSONDecodeError as e:
+        if skip:
+            return json_string
+        msg = (
+            f"Function {raw_tool_call['function']['name']} arguments:\n\n"
+            f"{raw_tool_call['function']['arguments']}\n\nare not valid JSON. "
+            f"Received JSONDecodeError {e}"
+        )
+        raise OutputParserException(msg) from e
+    except TypeError as e:
+        if skip:
+            return json_string
+        msg = (
+            f"Function {raw_tool_call['function']['name']} arguments:\n\n"
+            f"{raw_tool_call['function']['arguments']}\n\nare not a string or a "
+            f"dictionary. Received TypeError {e}"
+        )
+        raise OutputParserException(msg) from e
+
+
+def _parse_arguments_from_tool_call(
+    raw_tool_call: dict[str, Any],
+) -> Optional[dict[str, Any]]:
+    """Parse arguments by trying to parse any shallowly nested string-encoded JSON.
+
+    Band-aid fix for an issue in Ollama with inconsistent tool call argument
+    structure. Should be removed/changed if fixed upstream.
+    See https://github.com/ollama/ollama/issues/6155
+    """
+    if "function" not in raw_tool_call:
+        return None
+    arguments = raw_tool_call["function"]["arguments"]
+    parsed_arguments = {}
+    if isinstance(arguments, dict):
+        for key, value in arguments.items():
+            if isinstance(value, str):
+                parsed_arguments[key] = _parse_json_string(
+                    value, skip=True, raw_tool_call=raw_tool_call
+                )
+            else:
+                parsed_arguments[key] = value
+    else:
+        parsed_arguments = _parse_json_string(
+            arguments, skip=False, raw_tool_call=raw_tool_call
+        )
+    return parsed_arguments
+
+
 def _get_tool_calls_from_response(
     response: Mapping[str, Any],
 ) -> List[ToolCall]:
@@ -72,7 +140,7 @@ def _get_tool_calls_from_response(
             tool_call(
                 id=str(uuid4()),
                 name=tc["function"]["name"],
-                args=tc["function"]["arguments"],
+                args=_parse_arguments_from_tool_call(tc) or {},
             )
         )
     return tool_calls
diff --git a/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models.py b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models.py
new file mode 100644
index 00000000000..837df8c5164
--- /dev/null
+++ b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models.py
@@ -0,0 +1,41 @@
+"""Ollama-specific chat model integration tests."""
+
+from typing import List, Optional
+
+import pytest
+from pydantic import BaseModel, Field
+
+from langchain_ollama import ChatOllama
+
+
+@pytest.mark.parametrize("model", ["llama3.1"])
+def test_structured_output_deeply_nested(model: str) -> None:
+    """Test to verify structured output with nested objects."""
+    llm = ChatOllama(model=model, temperature=0)
+
+    class Person(BaseModel):
+        """Information about a person."""
+
+        name: Optional[str] = Field(default=None, description="The name of the person")
+        hair_color: Optional[str] = Field(
+            default=None, description="The color of the person's hair if known"
+        )
+        height_in_meters: Optional[str] = Field(
+            default=None, description="Height measured in meters"
+        )
+
+    class Data(BaseModel):
+        """Extracted data about people."""
+
+        people: List[Person]
+
+    chat = llm.with_structured_output(Data)  # type: ignore[arg-type]
+    text = (
+        "Alan Smith is 6 feet tall and has blond hair. "
+        "Alan Poe is 3 feet tall and has grey hair."
+    )
+    result = chat.invoke(text)
+    assert isinstance(result, Data)
+
+    for chunk in chat.stream(text):
+        assert isinstance(chunk, Data)
diff --git a/libs/partners/ollama/tests/integration_tests/test_chat_models.py b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py
similarity index 90%
rename from libs/partners/ollama/tests/integration_tests/test_chat_models.py
rename to libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py
index 9133106cae7..476640ddd79 100644
--- a/libs/partners/ollama/tests/integration_tests/test_chat_models.py
+++ b/libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models_standard.py
@@ -1,4 +1,4 @@
-"""Test chat model integration."""
+"""Test chat model integration using standard integration tests."""
 
 from typing import Type
 
@@ -16,7 +16,7 @@ class TestChatOllama(ChatModelIntegrationTests):
 
     @property
     def chat_model_params(self) -> dict:
-        return {"model": "llama3-groq-tool-use"}
+        return {"model": "llama3.1"}
 
     @property
     def supports_image_inputs(self) -> bool: