mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-03 02:06:33 +00:00
Merge 300299d0d0
into 3a487bf720
This commit is contained in:
commit
88924015ea
@ -687,9 +687,7 @@ class ChildTool(BaseTool):
|
|||||||
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
|
f"args_schema must be a Pydantic BaseModel, got {self.args_schema}"
|
||||||
)
|
)
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
return {
|
return {k: getattr(result, k) for k, v in result_dict.items()}
|
||||||
k: getattr(result, k) for k, v in result_dict.items() if k in tool_input
|
|
||||||
}
|
|
||||||
return tool_input
|
return tool_input
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
|
@ -21,7 +21,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel, Field, ValidationError
|
from pydantic import BaseModel, Field, ValidationError, model_validator
|
||||||
from pydantic.v1 import BaseModel as BaseModelV1
|
from pydantic.v1 import BaseModel as BaseModelV1
|
||||||
from pydantic.v1 import ValidationError as ValidationErrorV1
|
from pydantic.v1 import ValidationError as ValidationErrorV1
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
@ -634,7 +634,7 @@ def test_named_tool_decorator_return_direct() -> None:
|
|||||||
"""Test functionality when arguments and return direct are provided as input."""
|
"""Test functionality when arguments and return direct are provided as input."""
|
||||||
|
|
||||||
@tool("search", return_direct=True)
|
@tool("search", return_direct=True)
|
||||||
def search_api(query: str, *args: Any) -> str:
|
def search_api(query: str) -> str:
|
||||||
"""Search the API for the query."""
|
"""Search the API for the query."""
|
||||||
return "API result"
|
return "API result"
|
||||||
|
|
||||||
@ -2547,6 +2547,38 @@ def test_tool_decorator_description() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_args_schema_with_pydantic_validator() -> None:
|
||||||
|
class NestedArgsSchema(BaseModel):
|
||||||
|
y: int
|
||||||
|
|
||||||
|
class ArgsSchema(BaseModel):
|
||||||
|
x: NestedArgsSchema
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
def wrap_in_x(cls, data: Any) -> Any: # noqa: N805
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return {"x": data}
|
||||||
|
|
||||||
|
if "x" not in data:
|
||||||
|
return {"x": data}
|
||||||
|
return data
|
||||||
|
|
||||||
|
@tool(args_schema=ArgsSchema)
|
||||||
|
def foo(**args: Any) -> ArgsSchema:
|
||||||
|
"""Bar."""
|
||||||
|
return ArgsSchema.model_validate(args)
|
||||||
|
|
||||||
|
# Test case where validator is identity function
|
||||||
|
valid_inputs = {"x": {"y": 5}}
|
||||||
|
assert foo.invoke(valid_inputs) == ArgsSchema.model_validate(valid_inputs)
|
||||||
|
|
||||||
|
# Test case where validator wraps input in "x"
|
||||||
|
invalid_inputs = {"y": 5}
|
||||||
|
assert foo.invoke(invalid_inputs) == ArgsSchema.model_validate(
|
||||||
|
{"x": invalid_inputs}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_title_property_preserved() -> None:
|
def test_title_property_preserved() -> None:
|
||||||
"""Test that the title property is preserved when generating schema.
|
"""Test that the title property is preserved when generating schema.
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user