diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index 1352bcfafff..93375e09f34 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -266,7 +266,7 @@ def _create_subset_model_v2( fn_description: Optional[str] = None, ) -> type[pydantic.BaseModel]: """Create a pydantic model with a subset of the model fields.""" - from pydantic import create_model + from pydantic import ConfigDict, create_model from pydantic.fields import FieldInfo descriptions_ = descriptions or {} @@ -278,7 +278,10 @@ def _create_subset_model_v2( if field.metadata: field_info.metadata = field.metadata fields[field_name] = (field.annotation, field_info) - rtn = create_model(name, **fields) # type: ignore + + rtn = create_model( # type: ignore + name, **fields, __config__=ConfigDict(arbitrary_types_allowed=True) + ) # TODO(0.3): Determine if there is a more "pydantic" way to preserve annotations. # This is done to preserve __annotations__ when working with pydantic 2.x diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index a61cead53c2..3eae40ede1e 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -2090,3 +2090,18 @@ def test_structured_tool_direct_init() -> None: with pytest.raises(NotImplementedError): assert tool.invoke("hello") == "hello" + + +def test_injected_arg_with_complex_type() -> None: + """Test that an injected tool arg can be a complex type.""" + + class Foo: + def __init__(self) -> None: + self.value = "bar" + + @tool + def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: + """Tool that has an injected tool arg.""" + return foo.value + + assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" # type: ignore