diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index 28a1ab73513..b6c902872d4 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -5,6 +5,7 @@ from __future__ import annotations import collections import inspect import logging +import types import typing import uuid from typing import ( @@ -575,6 +576,10 @@ def _parse_google_docstring( def _py_38_safe_origin(origin: Type) -> Type: + origin_union_type_map: Dict[Type, Any] = ( + {types.UnionType: Union} if hasattr(types, "UnionType") else {} + ) + origin_map: Dict[Type, Any] = { dict: Dict, list: List, @@ -584,5 +589,6 @@ def _py_38_safe_origin(origin: Type) -> Type: collections.abc.Mapping: typing.Mapping, collections.abc.Sequence: typing.Sequence, collections.abc.MutableMapping: typing.MutableMapping, + **origin_union_type_map, } return cast(Type, origin_map.get(origin, origin)) diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index daa981d3143..7c68cd24a2b 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -1,4 +1,5 @@ # mypy: disable-error-code="annotation-unchecked" +import sys from typing import ( Any, Callable, @@ -702,3 +703,18 @@ def test__convert_typed_dict_to_openai_function_fail(typed_dict: Type) -> None: with pytest.raises(TypeError): _convert_typed_dict_to_openai_function(Tool) + + +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="Requires python version >= 3.10 to run." +) +def test_convert_union_type_py_39() -> None: + @tool + def magic_function(input: int | float) -> str: + """Compute a magic function.""" + pass + + result = convert_to_openai_function(magic_function) + assert result["parameters"]["properties"]["input"] == { + "anyOf": [{"type": "integer"}, {"type": "number"}] + }