diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index f30ef85fdb5..4ac5ceba508 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -5,6 +5,7 @@ from __future__ import annotations import functools import inspect import json +import typing import warnings from abc import ABC, abstractmethod from inspect import signature @@ -80,7 +81,7 @@ class SchemaAnnotationError(TypeError): def _is_annotated_type(typ: type[Any]) -> bool: - return get_origin(typ) is Annotated + return get_origin(typ) is typing.Annotated def _get_annotation_description(arg_type: type) -> str | None: @@ -143,11 +144,7 @@ def _infer_arg_descriptions( error_on_invalid_docstring: bool = False, ) -> tuple[str, dict]: """Infer argument descriptions from a function's docstring.""" - if hasattr(inspect, "get_annotations"): - # This is for python < 3.10 - annotations = inspect.get_annotations(fn) - else: - annotations = getattr(fn, "__annotations__", {}) + annotations = typing.get_type_hints(fn, include_extras=True) if parse_docstring: description, arg_descriptions = _parse_python_function_docstring( fn, annotations, error_on_invalid_docstring=error_on_invalid_docstring diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 4819634c9e5..f724bce74fd 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -2711,3 +2711,24 @@ def test_tool_invoke_does_not_mutate_inputs() -> None: "id": "call_0_82c17db8-95df-452f-a4c2-03f809022134", "type": "tool_call", } + + +def test_tool_args_schema_with_annotated_type() -> None: + @tool + def test_tool( + query_fragments: Annotated[ + list[str], + "A list of query fragments", + ], + ) -> list[str]: + """Search the Internet and retrieve relevant result items.""" + return [] + + assert test_tool.args == { + "query_fragments": { + "description": "A list of query fragments", + "items": {"type": "string"}, + "title": "Query Fragments", + "type": "array", + } + }