mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-08 18:19:21 +00:00
Compare commits
1 Commits
langchain-
...
sr/improve
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b3d26b0990 |
@@ -797,8 +797,27 @@ def create_agent( # noqa: PLR0915
|
||||
provider_strategy_binding = ProviderStrategyBinding.from_schema_spec(
|
||||
effective_response_format.schema_spec
|
||||
)
|
||||
structured_response = provider_strategy_binding.parse(output)
|
||||
return {"messages": [output], "structured_response": structured_response}
|
||||
|
||||
# Retry logic for provider strategy
|
||||
last_exception: Exception | None = None
|
||||
for attempt in range(effective_response_format.max_retries + 1):
|
||||
try:
|
||||
structured_response = provider_strategy_binding.parse(output)
|
||||
return {"messages": [output], "structured_response": structured_response}
|
||||
except Exception as exc:
|
||||
last_exception = exc
|
||||
if attempt >= effective_response_format.max_retries:
|
||||
break
|
||||
|
||||
# Handle failure after retries exhausted
|
||||
if effective_response_format.on_failure == "raise":
|
||||
raise last_exception # type: ignore[misc]
|
||||
elif callable(effective_response_format.on_failure):
|
||||
msg = effective_response_format.on_failure(last_exception) # type: ignore[arg-type]
|
||||
raise StructuredOutputValidationError(
|
||||
effective_response_format.schema_spec.name, Exception(msg)
|
||||
) from last_exception
|
||||
|
||||
return {"messages": [output]}
|
||||
|
||||
# Handle structured output with tool strategy
|
||||
@@ -834,37 +853,47 @@ def create_agent( # noqa: PLR0915
|
||||
]
|
||||
return {"messages": [output, *tool_messages]}
|
||||
|
||||
# Handle single structured output
|
||||
# Handle single structured output with retry
|
||||
tool_call = structured_tool_calls[0]
|
||||
try:
|
||||
structured_tool_binding = structured_output_tools[tool_call["name"]]
|
||||
structured_response = structured_tool_binding.parse(tool_call["args"])
|
||||
structured_tool_binding = structured_output_tools[tool_call["name"]]
|
||||
|
||||
tool_message_content = (
|
||||
effective_response_format.tool_message_content
|
||||
if effective_response_format.tool_message_content
|
||||
else f"Returning structured response: {structured_response}"
|
||||
)
|
||||
# Retry logic for tool strategy
|
||||
last_exception: Exception | None = None
|
||||
for attempt in range(effective_response_format.max_retries + 1):
|
||||
try:
|
||||
structured_response = structured_tool_binding.parse(tool_call["args"])
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
output,
|
||||
ToolMessage(
|
||||
content=tool_message_content,
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_call["name"],
|
||||
),
|
||||
],
|
||||
"structured_response": structured_response,
|
||||
}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
exception = StructuredOutputValidationError(tool_call["name"], exc)
|
||||
should_retry, error_message = _handle_structured_output_error(
|
||||
exception, effective_response_format
|
||||
)
|
||||
if not should_retry:
|
||||
raise exception
|
||||
tool_message_content = (
|
||||
effective_response_format.tool_message_content
|
||||
if effective_response_format.tool_message_content
|
||||
else f"Returning structured response: {structured_response}"
|
||||
)
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
output,
|
||||
ToolMessage(
|
||||
content=tool_message_content,
|
||||
tool_call_id=tool_call["id"],
|
||||
name=tool_call["name"],
|
||||
),
|
||||
],
|
||||
"structured_response": structured_response,
|
||||
}
|
||||
except Exception as exc: # noqa: BLE001
|
||||
last_exception = exc
|
||||
if attempt >= effective_response_format.max_retries:
|
||||
break
|
||||
|
||||
# Handle failure after retries exhausted
|
||||
exception = StructuredOutputValidationError(tool_call["name"], last_exception) # type: ignore[arg-type]
|
||||
|
||||
# Check if we should send error to model for self-correction
|
||||
should_retry, error_message = _handle_structured_output_error(
|
||||
exception, effective_response_format
|
||||
)
|
||||
if should_retry:
|
||||
# Model self-correction mode
|
||||
return {
|
||||
"messages": [
|
||||
output,
|
||||
@@ -876,6 +905,17 @@ def create_agent( # noqa: PLR0915
|
||||
],
|
||||
}
|
||||
|
||||
# Apply on_failure behavior
|
||||
if effective_response_format.on_failure == "raise":
|
||||
raise exception
|
||||
elif callable(effective_response_format.on_failure):
|
||||
msg = effective_response_format.on_failure(last_exception) # type: ignore[arg-type]
|
||||
raise StructuredOutputValidationError(
|
||||
tool_call["name"], Exception(msg)
|
||||
) from last_exception
|
||||
else:
|
||||
raise exception
|
||||
|
||||
return {"messages": [output]}
|
||||
|
||||
def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
|
||||
|
||||
@@ -199,6 +199,16 @@ class ToolStrategy(Generic[SchemaT]):
|
||||
- `False`: No retry, let exceptions propagate
|
||||
"""
|
||||
|
||||
max_retries: int
|
||||
"""Maximum number of retry attempts for parsing failures (default: 0)."""
|
||||
|
||||
on_failure: Literal["raise"] | Callable[[Exception], str]
|
||||
"""Behavior when retries are exhausted.
|
||||
|
||||
- `"raise"`: Raise the exception (default)
|
||||
- `Callable[[Exception], str]`: Custom error message function
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
schema: type[SchemaT],
|
||||
@@ -209,15 +219,23 @@ class ToolStrategy(Generic[SchemaT]):
|
||||
| type[Exception]
|
||||
| tuple[type[Exception], ...]
|
||||
| Callable[[Exception], str] = True,
|
||||
max_retries: int = 0,
|
||||
on_failure: Literal["raise"] | Callable[[Exception], str] = "raise",
|
||||
) -> None:
|
||||
"""Initialize `ToolStrategy`.
|
||||
|
||||
Initialize `ToolStrategy` with schemas, tool message content, and error handling
|
||||
strategy.
|
||||
Args:
|
||||
schema: The schema for structured output.
|
||||
tool_message_content: Custom content for tool response messages.
|
||||
handle_errors: Error handling strategy for customizing error messages.
|
||||
max_retries: Maximum number of retry attempts (0 = no retry).
|
||||
on_failure: Behavior when retries exhausted (`"raise"` or custom function).
|
||||
"""
|
||||
self.schema = schema
|
||||
self.tool_message_content = tool_message_content
|
||||
self.handle_errors = handle_errors
|
||||
self.max_retries = max_retries
|
||||
self.on_failure = on_failure
|
||||
|
||||
def _iter_variants(schema: Any) -> Iterable[Any]:
|
||||
"""Yield leaf variants from Union and JSON Schema oneOf."""
|
||||
@@ -246,13 +264,34 @@ class ProviderStrategy(Generic[SchemaT]):
|
||||
schema_spec: _SchemaSpec[SchemaT]
|
||||
"""Schema spec for native mode."""
|
||||
|
||||
max_retries: int
|
||||
"""Maximum number of retry attempts for parsing failures (default: 0)."""
|
||||
|
||||
on_failure: Literal["raise"] | Callable[[Exception], str]
|
||||
"""Behavior when retries are exhausted.
|
||||
|
||||
- `"raise"`: Raise the exception (default)
|
||||
- `Callable[[Exception], str]`: Custom error message function
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
schema: type[SchemaT],
|
||||
*,
|
||||
max_retries: int = 0,
|
||||
on_failure: Literal["raise"] | Callable[[Exception], str] = "raise",
|
||||
) -> None:
|
||||
"""Initialize ProviderStrategy with schema."""
|
||||
"""Initialize ProviderStrategy with schema.
|
||||
|
||||
Args:
|
||||
schema: The schema for structured output.
|
||||
max_retries: Maximum number of retry attempts (0 = no retry).
|
||||
on_failure: Behavior when retries exhausted (`"raise"` or custom function).
|
||||
"""
|
||||
self.schema = schema
|
||||
self.schema_spec = _SchemaSpec(schema)
|
||||
self.max_retries = max_retries
|
||||
self.on_failure = on_failure
|
||||
|
||||
def to_model_kwargs(self) -> dict[str, Any]:
|
||||
"""Convert to kwargs to bind to a model to force structured output."""
|
||||
|
||||
@@ -0,0 +1,230 @@
|
||||
"""Unit tests for structured output retry functionality."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from langchain.agents.structured_output import (
|
||||
OutputToolBinding,
|
||||
ProviderStrategy,
|
||||
ProviderStrategyBinding,
|
||||
StructuredOutputValidationError,
|
||||
ToolStrategy,
|
||||
_SchemaSpec,
|
||||
)
|
||||
|
||||
|
||||
class ResponseSchema(BaseModel):
|
||||
"""Test schema for structured output."""
|
||||
|
||||
value: int
|
||||
|
||||
|
||||
class TestProviderStrategyRetry:
|
||||
"""Test retry functionality for ProviderStrategy."""
|
||||
|
||||
def test_provider_strategy_default_no_retry(self):
|
||||
"""Test that ProviderStrategy defaults to no retry."""
|
||||
strategy = ProviderStrategy(ResponseSchema)
|
||||
assert strategy.max_retries == 0
|
||||
assert strategy.on_failure == "raise"
|
||||
|
||||
def test_provider_strategy_with_max_retries(self):
|
||||
"""Test ProviderStrategy with max_retries configured."""
|
||||
strategy = ProviderStrategy(ResponseSchema, max_retries=3)
|
||||
assert strategy.max_retries == 3
|
||||
|
||||
def test_provider_strategy_with_custom_on_failure(self):
|
||||
"""Test ProviderStrategy with custom on_failure function."""
|
||||
|
||||
def custom_failure(exc: Exception) -> str:
|
||||
return f"Custom error: {exc}"
|
||||
|
||||
strategy = ProviderStrategy(ResponseSchema, on_failure=custom_failure)
|
||||
assert callable(strategy.on_failure)
|
||||
|
||||
def test_provider_binding_retry_succeeds_on_second_attempt(self):
|
||||
"""Test that retry succeeds when second attempt works."""
|
||||
binding = ProviderStrategyBinding.from_schema_spec(_SchemaSpec(ResponseSchema))
|
||||
|
||||
# Mock parse to fail once then succeed
|
||||
original_parse = binding.parse
|
||||
call_count = 0
|
||||
|
||||
def mock_parse(msg):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise ValueError("First attempt fails")
|
||||
return ResponseSchema(value=42)
|
||||
|
||||
binding.parse = mock_parse # type: ignore[method-assign]
|
||||
|
||||
# Simulate retry logic (would be in factory.py)
|
||||
max_retries = 2
|
||||
last_exception = None
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
result = binding.parse(Mock())
|
||||
assert result.value == 42
|
||||
assert call_count == 2 # Failed once, succeeded on retry
|
||||
break
|
||||
except Exception as exc:
|
||||
last_exception = exc
|
||||
if attempt >= max_retries:
|
||||
raise
|
||||
|
||||
def test_provider_binding_retry_exhausted(self):
|
||||
"""Test that retries are exhausted and exception is raised."""
|
||||
binding = ProviderStrategyBinding.from_schema_spec(_SchemaSpec(ResponseSchema))
|
||||
|
||||
# Mock parse to always fail
|
||||
def mock_parse(msg):
|
||||
raise ValueError("Always fails")
|
||||
|
||||
binding.parse = mock_parse # type: ignore[method-assign]
|
||||
|
||||
# Simulate retry logic
|
||||
max_retries = 2
|
||||
last_exception = None
|
||||
attempts = 0
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
binding.parse(Mock())
|
||||
except Exception as exc:
|
||||
attempts += 1
|
||||
last_exception = exc
|
||||
if attempt >= max_retries:
|
||||
break
|
||||
|
||||
assert attempts == 3 # Initial + 2 retries
|
||||
assert last_exception is not None
|
||||
assert "Always fails" in str(last_exception)
|
||||
|
||||
|
||||
class TestToolStrategyRetry:
|
||||
"""Test retry functionality for ToolStrategy."""
|
||||
|
||||
def test_tool_strategy_default_no_retry(self):
|
||||
"""Test that ToolStrategy defaults to no retry."""
|
||||
strategy = ToolStrategy(ResponseSchema)
|
||||
assert strategy.max_retries == 0
|
||||
assert strategy.on_failure == "raise"
|
||||
|
||||
def test_tool_strategy_with_max_retries(self):
|
||||
"""Test ToolStrategy with max_retries configured."""
|
||||
strategy = ToolStrategy(ResponseSchema, max_retries=3)
|
||||
assert strategy.max_retries == 3
|
||||
|
||||
def test_tool_strategy_retains_handle_errors(self):
|
||||
"""Test that ToolStrategy retains handle_errors for backward compatibility."""
|
||||
strategy = ToolStrategy(ResponseSchema, handle_errors=False, max_retries=2)
|
||||
assert strategy.handle_errors is False
|
||||
assert strategy.max_retries == 2
|
||||
|
||||
def test_tool_strategy_custom_error_message(self):
|
||||
"""Test ToolStrategy with custom error message function."""
|
||||
|
||||
def custom_error(exc: Exception) -> str:
|
||||
return f"Validation failed: {exc}"
|
||||
|
||||
strategy = ToolStrategy(ResponseSchema, handle_errors=custom_error)
|
||||
assert callable(strategy.handle_errors)
|
||||
|
||||
def test_tool_binding_retry_succeeds_on_third_attempt(self):
|
||||
"""Test that retry succeeds when third attempt works."""
|
||||
binding = OutputToolBinding.from_schema_spec(_SchemaSpec(ResponseSchema))
|
||||
|
||||
# Mock parse to fail twice then succeed
|
||||
call_count = 0
|
||||
|
||||
def mock_parse(args):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
raise ValidationError.from_exception_data(
|
||||
"ResponseSchema",
|
||||
[{"type": "missing", "loc": ("value",), "msg": "Field required"}],
|
||||
)
|
||||
return ResponseSchema(value=100)
|
||||
|
||||
binding.parse = mock_parse # type: ignore[method-assign]
|
||||
|
||||
# Simulate retry logic
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
result = binding.parse({"value": 100})
|
||||
assert result.value == 100
|
||||
assert call_count == 3 # Failed twice, succeeded on third
|
||||
break
|
||||
except Exception:
|
||||
if attempt >= max_retries:
|
||||
raise
|
||||
|
||||
def test_tool_binding_retry_with_on_failure_callable(self):
|
||||
"""Test custom on_failure function is applied correctly."""
|
||||
|
||||
def custom_failure(exc: Exception) -> str:
|
||||
return f"Custom message: {type(exc).__name__}"
|
||||
|
||||
binding = OutputToolBinding.from_schema_spec(_SchemaSpec(ResponseSchema))
|
||||
|
||||
def mock_parse(args):
|
||||
raise ValidationError.from_exception_data(
|
||||
"ResponseSchema",
|
||||
[{"type": "int_type", "loc": ("value",), "msg": "Input should be an integer"}],
|
||||
)
|
||||
|
||||
binding.parse = mock_parse # type: ignore[method-assign]
|
||||
|
||||
# Simulate retry with custom failure
|
||||
max_retries = 1
|
||||
last_exception = None
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
binding.parse({"value": "not_an_int"})
|
||||
except Exception as exc:
|
||||
last_exception = exc
|
||||
if attempt >= max_retries:
|
||||
break
|
||||
|
||||
# Apply custom failure handler
|
||||
error_msg = custom_failure(last_exception) # type: ignore[arg-type]
|
||||
assert "Custom message: ValidationError" in error_msg
|
||||
|
||||
|
||||
class TestRetryBehavior:
|
||||
"""Test overall retry behavior."""
|
||||
|
||||
def test_max_retries_zero_means_no_retry(self):
|
||||
"""Test that max_retries=0 means single attempt only."""
|
||||
strategy = ToolStrategy(ResponseSchema, max_retries=0)
|
||||
|
||||
# With max_retries=0, should only attempt once
|
||||
attempts = 0
|
||||
for attempt in range(strategy.max_retries + 1):
|
||||
attempts += 1
|
||||
|
||||
assert attempts == 1
|
||||
|
||||
def test_max_retries_three_means_four_attempts(self):
|
||||
"""Test that max_retries=3 means 4 total attempts."""
|
||||
strategy = ProviderStrategy(ResponseSchema, max_retries=3)
|
||||
|
||||
# With max_retries=3, should attempt 4 times total
|
||||
attempts = 0
|
||||
for attempt in range(strategy.max_retries + 1):
|
||||
attempts += 1
|
||||
|
||||
assert attempts == 4
|
||||
|
||||
def test_on_failure_raise_is_default(self):
|
||||
"""Test that on_failure defaults to 'raise'."""
|
||||
tool_strategy = ToolStrategy(ResponseSchema)
|
||||
provider_strategy = ProviderStrategy(ResponseSchema)
|
||||
|
||||
assert tool_strategy.on_failure == "raise"
|
||||
assert provider_strategy.on_failure == "raise"
|
||||
2
libs/langchain_v1/uv.lock
generated
2
libs/langchain_v1/uv.lock
generated
@@ -1743,7 +1743,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "1.0.0"
|
||||
version = "1.0.1"
|
||||
source = { editable = "../core" }
|
||||
dependencies = [
|
||||
{ name = "jsonpatch" },
|
||||
|
||||
Reference in New Issue
Block a user