mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 05:43:55 +00:00
Fix UnionType type var replacement (#25566)
[langchain_core] Fix UnionType type var replacement - Added types.UnionType to typing.Union mapping Type replacement cause `TypeError: 'type' object is not subscriptable` if any of union type comes as function `_py_38_safe_origin` return `types.UnionType` instead of `typing.Union` ```python >>> from types import UnionType >>> from typing import Union, get_origin >>> type_ = get_origin(str | None) >>> type_ <class 'types.UnionType'> >>> UnionType[(str, None)] Traceback (most recent call last): File "<stdin>", line 1, in <module> TypeError: 'type' object is not subscriptable >>> Union[(str, None)] typing.Optional[str] ``` --------- Co-authored-by: Chester Curme <chester.curme@gmail.com>
This commit is contained in:
parent
8230ba47f3
commit
5b9290a449
@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
import collections
|
import collections
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
|
import types
|
||||||
import typing
|
import typing
|
||||||
import uuid
|
import uuid
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -575,6 +576,10 @@ def _parse_google_docstring(
|
|||||||
|
|
||||||
|
|
||||||
def _py_38_safe_origin(origin: Type) -> Type:
|
def _py_38_safe_origin(origin: Type) -> Type:
|
||||||
|
origin_union_type_map: Dict[Type, Any] = (
|
||||||
|
{types.UnionType: Union} if hasattr(types, "UnionType") else {}
|
||||||
|
)
|
||||||
|
|
||||||
origin_map: Dict[Type, Any] = {
|
origin_map: Dict[Type, Any] = {
|
||||||
dict: Dict,
|
dict: Dict,
|
||||||
list: List,
|
list: List,
|
||||||
@ -584,5 +589,6 @@ def _py_38_safe_origin(origin: Type) -> Type:
|
|||||||
collections.abc.Mapping: typing.Mapping,
|
collections.abc.Mapping: typing.Mapping,
|
||||||
collections.abc.Sequence: typing.Sequence,
|
collections.abc.Sequence: typing.Sequence,
|
||||||
collections.abc.MutableMapping: typing.MutableMapping,
|
collections.abc.MutableMapping: typing.MutableMapping,
|
||||||
|
**origin_union_type_map,
|
||||||
}
|
}
|
||||||
return cast(Type, origin_map.get(origin, origin))
|
return cast(Type, origin_map.get(origin, origin))
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
# mypy: disable-error-code="annotation-unchecked"
|
# mypy: disable-error-code="annotation-unchecked"
|
||||||
|
import sys
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
@ -702,3 +703,18 @@ def test__convert_typed_dict_to_openai_function_fail(typed_dict: Type) -> None:
|
|||||||
|
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
_convert_typed_dict_to_openai_function(Tool)
|
_convert_typed_dict_to_openai_function(Tool)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 10), reason="Requires python version >= 3.10 to run."
|
||||||
|
)
|
||||||
|
def test_convert_union_type_py_39() -> None:
|
||||||
|
@tool
|
||||||
|
def magic_function(input: int | float) -> str:
|
||||||
|
"""Compute a magic function."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
result = convert_to_openai_function(magic_function)
|
||||||
|
assert result["parameters"]["properties"]["input"] == {
|
||||||
|
"anyOf": [{"type": "integer"}, {"type": "number"}]
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user