From de56c31672fa5151043e76beca3f226ae84969b0 Mon Sep 17 00:00:00 2001 From: Ahmed Tammaa Date: Mon, 21 Apr 2025 16:06:18 +0200 Subject: [PATCH] core: Improve OutputParser error messaging when model output is truncated (max_tokens) (#30936) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses #30158 When using the output parser—either in a chain or standalone—hitting max_tokens triggers a misleading “missing variable” error instead of indicating the output was truncated. This subtle bug often surfaces with Anthropic models. --------- Co-authored-by: Chester Curme --- .../output_parsers/openai_tools.py | 19 +++++++++++++++ .../output_parsers/test_openai_tools.py | 23 ++++++++++++++++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/libs/core/langchain_core/output_parsers/openai_tools.py b/libs/core/langchain_core/output_parsers/openai_tools.py index 254d11eca1f..b4f845bdcad 100644 --- a/libs/core/langchain_core/output_parsers/openai_tools.py +++ b/libs/core/langchain_core/output_parsers/openai_tools.py @@ -2,6 +2,7 @@ import copy import json +import logging from json import JSONDecodeError from typing import Annotated, Any, Optional @@ -16,6 +17,8 @@ from langchain_core.outputs import ChatGeneration, Generation from langchain_core.utils.json import parse_partial_json from langchain_core.utils.pydantic import TypeBaseModel +logger = logging.getLogger(__name__) + def parse_tool_call( raw_tool_call: dict[str, Any], @@ -250,6 +253,14 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser): return parsed_result +# Common cause of ValidationError is truncated output due to max_tokens. +_MAX_TOKENS_ERROR = ( + "Output parser received a `max_tokens` stop reason. " + "The output is likely incomplete—please increase `max_tokens` " + "or shorten your prompt." +) + + class PydanticToolsParser(JsonOutputToolsParser): """Parse tools from OpenAI response.""" @@ -296,6 +307,14 @@ class PydanticToolsParser(JsonOutputToolsParser): except (ValidationError, ValueError): if partial: continue + has_max_tokens_stop_reason = any( + generation.message.response_metadata.get("stop_reason") + == "max_tokens" + for generation in result + if isinstance(generation, ChatGeneration) + ) + if has_max_tokens_stop_reason: + logger.exception(_MAX_TOKENS_ERROR) raise if self.first_tool_only: return pydantic_objects[0] if pydantic_objects else None diff --git a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py index c63dd693d74..992e5c48c55 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py +++ b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py @@ -2,7 +2,7 @@ from collections.abc import AsyncIterator, Iterator from typing import Any import pytest -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ValidationError from langchain_core.messages import ( AIMessage, @@ -635,3 +635,24 @@ def test_parse_with_different_pydantic_1_proper() -> None: forecast="Sunny", ) ] + + +def test_max_tokens_error(caplog: Any) -> None: + parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True) + input = AIMessage( + content="", + tool_calls=[ + { + "id": "call_OwL7f5PE", + "name": "NameCollector", + "args": {"names": ["suz", "jerm"]}, + } + ], + response_metadata={"stop_reason": "max_tokens"}, + ) + with pytest.raises(ValidationError): + _ = parser.invoke(input) + assert any( + "`max_tokens` stop reason" in msg and record.levelname == "ERROR" + for record, msg in zip(caplog.records, caplog.messages) + )