core: Improve OutputParser error messaging when model output is truncated (max_tokens) (#30936)
Addresses #30158. When the output parser is used, either in a chain or standalone, hitting max_tokens triggers a misleading “missing variable” error instead of indicating that the output was truncated. This subtle bug often surfaces with Anthropic models.

Co-authored-by: Chester Curme <chester.curme@gmail.com>
parent 335f089d6a
commit de56c31672
libs/core
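Before this change, a response cut off at the token limit surfaces only as a Pydantic validation failure on the truncated tool-call JSON. Below is a minimal caller-side sketch, not part of the commit, of checking the provider's stop reason before parsing; the stop_reason key mirrors Anthropic-style response metadata, and other providers may use different keys.

    from langchain_core.messages import AIMessage

    def parse_if_complete(message: AIMessage, parser):
        # Sketch: refuse to parse when the provider reports it hit max_tokens,
        # since the tool-call arguments are probably incomplete JSON.
        if message.response_metadata.get("stop_reason") == "max_tokens":
            raise ValueError("Model output truncated; increase max_tokens and retry.")
        return parser.invoke(message)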
@@ -2,6 +2,7 @@
 import copy
 import json
+import logging
 from json import JSONDecodeError
 from typing import Annotated, Any, Optional
 
@@ -16,6 +17,8 @@ from langchain_core.outputs import ChatGeneration, Generation
 from langchain_core.utils.json import parse_partial_json
 from langchain_core.utils.pydantic import TypeBaseModel
 
+logger = logging.getLogger(__name__)
+
 
 def parse_tool_call(
     raw_tool_call: dict[str, Any],
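Because the parser logs through the standard library, the new message participates in whatever logging configuration the application already has; with no handlers configured it still reaches stderr via logging's last-resort handler. A minimal explicit setup, assuming nothing beyond the standard library (the logger name follows the module path, so configuring the root logger avoids guessing it):

    import logging

    # logger.exception(...) records at ERROR and attaches the traceback,
    # so an ERROR-level root configuration is enough to surface the message.
    logging.basicConfig(
        level=logging.ERROR,
        format="%(asctime)s %(name)s %(levelname)s: %(message)s",
    )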
@@ -250,6 +253,14 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
         return parsed_result
 
 
+# Common cause of ValidationError is truncated output due to max_tokens.
+_MAX_TOKENS_ERROR = (
+    "Output parser received a `max_tokens` stop reason. "
+    "The output is likely incomplete—please increase `max_tokens` "
+    "or shorten your prompt."
+)
+
+
 class PydanticToolsParser(JsonOutputToolsParser):
     """Parse tools from OpenAI response."""
 
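Adjacent string literals concatenate, so _MAX_TOKENS_ERROR is a single message; the test added below matches on its "max_tokens stop reason" fragment:

    >>> _MAX_TOKENS_ERROR
    'Output parser received a `max_tokens` stop reason. The output is likely incomplete—please increase `max_tokens` or shorten your prompt.'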
@@ -296,6 +307,14 @@ class PydanticToolsParser(JsonOutputToolsParser):
             except (ValidationError, ValueError):
                 if partial:
                     continue
+                has_max_tokens_stop_reason = any(
+                    generation.message.response_metadata.get("stop_reason")
+                    == "max_tokens"
+                    for generation in result
+                    if isinstance(generation, ChatGeneration)
+                )
+                if has_max_tokens_stop_reason:
+                    logger.exception(_MAX_TOKENS_ERROR)
                 raise
         if self.first_tool_only:
             return pydantic_objects[0] if pydantic_objects else None
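The except block logs the truncation hint with logger.exception and then re-raises the original ValidationError, so existing error handling is unaffected; the log record is purely additive. A hedged sketch of how application code might act on the same signal, where build_model is a hypothetical factory (not a LangChain API) that constructs a chat model with a given token budget:

    from pydantic import ValidationError

    def invoke_with_retry(build_model, parser, prompt, max_tokens=256):
        # Sketch only: retry once with a larger budget when the stop reason
        # indicates truncation; re-raise unrelated parse failures immediately.
        for budget in (max_tokens, max_tokens * 4):
            message = build_model(max_tokens=budget).invoke(prompt)
            try:
                return parser.invoke(message)
            except ValidationError:
                if message.response_metadata.get("stop_reason") != "max_tokens":
                    raise
        raise RuntimeError("Output still truncated after raising max_tokens.")

The remaining hunks add a unit test for the new behavior to the output-parser test suite: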
@@ -2,7 +2,7 @@ from collections.abc import AsyncIterator, Iterator
 from typing import Any
 
 import pytest
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, ValidationError
 
 from langchain_core.messages import (
     AIMessage,
@@ -635,3 +635,24 @@ def test_parse_with_different_pydantic_1_proper() -> None:
             forecast="Sunny",
         )
     ]
+
+
+def test_max_tokens_error(caplog: Any) -> None:
+    parser = PydanticToolsParser(tools=[NameCollector], first_tool_only=True)
+    input = AIMessage(
+        content="",
+        tool_calls=[
+            {
+                "id": "call_OwL7f5PE",
+                "name": "NameCollector",
+                "args": {"names": ["suz", "jerm"]},
+            }
+        ],
+        response_metadata={"stop_reason": "max_tokens"},
+    )
+    with pytest.raises(ValidationError):
+        _ = parser.invoke(input)
+    assert any(
+        "`max_tokens` stop reason" in msg and record.levelname == "ERROR"
+        for record, msg in zip(caplog.records, caplog.messages)
+    )
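The assertion pairs caplog.records (which carry the level) with caplog.messages (the formatted text): because logger.exception logs at ERROR, the test can check the level and the stop-reason substring in one pass. A standalone sketch of the same caplog pattern, independent of LangChain:

    import logging

    import pytest

    def _log_and_raise() -> None:
        try:
            raise ValueError("truncated JSON")
        except ValueError:
            # exception() records at ERROR and attaches the traceback
            logging.getLogger(__name__).exception("hit `max_tokens` stop reason")
            raise

    def test_log_and_raise(caplog: pytest.LogCaptureFixture) -> None:
        with pytest.raises(ValueError):
            _log_and_raise()
        assert any(
            "`max_tokens` stop reason" in msg and record.levelname == "ERROR"
            for record, msg in zip(caplog.records, caplog.messages)
        )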