core[patch]: Support injected tool args that are arbitrary types (#27045)

This adds support for inject tool args that are arbitrary types when
used with pydantic 2.

We'll need to add similar logic on the v1 path, and potentially mirror
the config from the original model when we're doing the subset.
This commit is contained in:
Eugene Yurtsev 2024-10-02 12:50:58 -04:00 committed by GitHub
parent e806e9de38
commit 74bf620e97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 2 deletions

View File

@ -266,7 +266,7 @@ def _create_subset_model_v2(
fn_description: Optional[str] = None, fn_description: Optional[str] = None,
) -> type[pydantic.BaseModel]: ) -> type[pydantic.BaseModel]:
"""Create a pydantic model with a subset of the model fields.""" """Create a pydantic model with a subset of the model fields."""
from pydantic import create_model from pydantic import ConfigDict, create_model
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
descriptions_ = descriptions or {} descriptions_ = descriptions or {}
@ -278,7 +278,10 @@ def _create_subset_model_v2(
if field.metadata: if field.metadata:
field_info.metadata = field.metadata field_info.metadata = field.metadata
fields[field_name] = (field.annotation, field_info) fields[field_name] = (field.annotation, field_info)
rtn = create_model(name, **fields) # type: ignore
rtn = create_model( # type: ignore
name, **fields, __config__=ConfigDict(arbitrary_types_allowed=True)
)
# TODO(0.3): Determine if there is a more "pydantic" way to preserve annotations. # TODO(0.3): Determine if there is a more "pydantic" way to preserve annotations.
# This is done to preserve __annotations__ when working with pydantic 2.x # This is done to preserve __annotations__ when working with pydantic 2.x

View File

@ -2090,3 +2090,18 @@ def test_structured_tool_direct_init() -> None:
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
assert tool.invoke("hello") == "hello" assert tool.invoke("hello") == "hello"
def test_injected_arg_with_complex_type() -> None:
"""Test that an injected tool arg can be a complex type."""
class Foo:
def __init__(self) -> None:
self.value = "bar"
@tool
def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str:
"""Tool that has an injected tool arg."""
return foo.value
assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" # type: ignore