mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-26 22:05:29 +00:00
core[patch]: Fix injected args in tool signature (#25991)
- Fix injected args in tool signature - Fix another unit test that was using the wrong namespace import in pydantic
This commit is contained in:
@@ -273,6 +273,17 @@ def _create_subset_model_v2(
|
|||||||
fields[field_name] = (field.annotation, field_info)
|
fields[field_name] = (field.annotation, field_info)
|
||||||
rtn = create_model(name, **fields) # type: ignore
|
rtn = create_model(name, **fields) # type: ignore
|
||||||
|
|
||||||
|
# TODO(0.3): Determine if there is a more "pydantic" way to preserve annotations.
|
||||||
|
# This is done to preserve __annotations__ when working with pydantic 2.x
|
||||||
|
# and using the Annotated type with TypedDict.
|
||||||
|
# Comment out the following line, to trigger the relevant test case.
|
||||||
|
selected_annotations = [
|
||||||
|
(name, annotation)
|
||||||
|
for name, annotation in model.__annotations__.items()
|
||||||
|
if name in field_names
|
||||||
|
]
|
||||||
|
|
||||||
|
rtn.__annotations__ = dict(selected_annotations)
|
||||||
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
|
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
|
||||||
return rtn
|
return rtn
|
||||||
|
|
||||||
|
@@ -1756,13 +1756,17 @@ def test__get_all_basemodel_annotations_v2(use_v1_namespace: bool) -> None:
|
|||||||
A = TypeVar("A")
|
A = TypeVar("A")
|
||||||
|
|
||||||
if use_v1_namespace:
|
if use_v1_namespace:
|
||||||
|
from pydantic.v1 import BaseModel as BM1
|
||||||
|
|
||||||
class ModelA(BaseModel, Generic[A], extra="allow"):
|
class ModelA(BM1, Generic[A], extra="allow"):
|
||||||
a: A
|
a: A
|
||||||
else:
|
else:
|
||||||
|
from pydantic import BaseModel as BM2
|
||||||
|
from pydantic import ConfigDict
|
||||||
|
|
||||||
class ModelA(BaseModelProper, Generic[A], extra="allow"): # type: ignore[no-redef]
|
class ModelA(BM2, Generic[A], extra="allow"): # type: ignore[no-redef]
|
||||||
a: A
|
a: A
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
class ModelB(ModelA[str]):
|
class ModelB(ModelA[str]):
|
||||||
b: Annotated[ModelA[Dict[str, Any]], "foo"]
|
b: Annotated[ModelA[Dict[str, Any]], "foo"]
|
||||||
@@ -1871,6 +1875,26 @@ def test__get_all_basemodel_annotations_v1() -> None:
|
|||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_annotations_preserved() -> None:
|
||||||
|
"""Test that annotations are preserved when creating a tool."""
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def my_tool(val: int, other_val: Annotated[dict, "my annotation"]) -> str:
|
||||||
|
"""Tool docstring."""
|
||||||
|
return "foo"
|
||||||
|
|
||||||
|
schema = my_tool.get_input_schema() # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
func = my_tool.func # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
expected_type_hints = {
|
||||||
|
name: hint
|
||||||
|
for name, hint in func.__annotations__.items()
|
||||||
|
if name in inspect.signature(func).parameters
|
||||||
|
}
|
||||||
|
assert schema.__annotations__ == expected_type_hints
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
|
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Testing pydantic v2.")
|
||||||
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
def test_tool_args_schema_pydantic_v2_with_metadata() -> None:
|
||||||
from pydantic import BaseModel as BaseModelV2
|
from pydantic import BaseModel as BaseModelV2
|
||||||
|
Reference in New Issue
Block a user