mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 07:35:18 +00:00
[Core] Add support for inferring Annotated types (#23284)
in bind_tools() / convert_to_openai_function
This commit is contained in:
parent
9ac302cb97
commit
efb4c12abe
@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from types import FunctionType, MethodType
|
from types import FunctionType, MethodType
|
||||||
from typing import (
|
from typing import (
|
||||||
@ -19,7 +20,7 @@ from typing import (
|
|||||||
cast,
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import Annotated, TypedDict, get_args, get_origin
|
||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
@ -33,7 +34,7 @@ from langchain_core.utils.json_schema import dereference_refs
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
PYTHON_TO_JSON_TYPES = {
|
PYTHON_TO_JSON_TYPES = {
|
||||||
"str": "string",
|
"str": "string",
|
||||||
"int": "integer",
|
"int": "integer",
|
||||||
@ -160,6 +161,10 @@ def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
|
|||||||
return description, arg_descriptions
|
return description, arg_descriptions
|
||||||
|
|
||||||
|
|
||||||
|
def _is_annotated_type(typ: Type[Any]) -> bool:
|
||||||
|
return get_origin(typ) is Annotated
|
||||||
|
|
||||||
|
|
||||||
def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict:
|
def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict:
|
||||||
"""Get JsonSchema describing a Python functions arguments.
|
"""Get JsonSchema describing a Python functions arguments.
|
||||||
|
|
||||||
@ -171,10 +176,27 @@ def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -
|
|||||||
for arg, arg_type in annotations.items():
|
for arg, arg_type in annotations.items():
|
||||||
if arg == "return":
|
if arg == "return":
|
||||||
continue
|
continue
|
||||||
if isinstance(arg_type, type) and issubclass(arg_type, BaseModel):
|
|
||||||
# Mypy error:
|
if _is_annotated_type(arg_type):
|
||||||
# "type" has no attribute "schema"
|
annotated_args = get_args(arg_type)
|
||||||
properties[arg] = arg_type.schema() # type: ignore[attr-defined]
|
arg_type = annotated_args[0]
|
||||||
|
if len(annotated_args) > 1:
|
||||||
|
for annotation in annotated_args[1:]:
|
||||||
|
if isinstance(annotation, str):
|
||||||
|
arg_descriptions[arg] = annotation
|
||||||
|
break
|
||||||
|
if (
|
||||||
|
isinstance(arg_type, type)
|
||||||
|
and hasattr(arg_type, "model_json_schema")
|
||||||
|
and callable(arg_type.model_json_schema)
|
||||||
|
):
|
||||||
|
properties[arg] = arg_type.model_json_schema()
|
||||||
|
elif (
|
||||||
|
isinstance(arg_type, type)
|
||||||
|
and hasattr(arg_type, "schema")
|
||||||
|
and callable(arg_type.schema)
|
||||||
|
):
|
||||||
|
properties[arg] = arg_type.schema()
|
||||||
elif (
|
elif (
|
||||||
hasattr(arg_type, "__name__")
|
hasattr(arg_type, "__name__")
|
||||||
and getattr(arg_type, "__name__") in PYTHON_TO_JSON_TYPES
|
and getattr(arg_type, "__name__") in PYTHON_TO_JSON_TYPES
|
||||||
@ -185,13 +207,20 @@ def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -
|
|||||||
and getattr(arg_type, "__dict__").get("__origin__", None) == Literal
|
and getattr(arg_type, "__dict__").get("__origin__", None) == Literal
|
||||||
):
|
):
|
||||||
properties[arg] = {
|
properties[arg] = {
|
||||||
"enum": list(arg_type.__args__), # type: ignore
|
"enum": list(arg_type.__args__),
|
||||||
"type": PYTHON_TO_JSON_TYPES[arg_type.__args__[0].__class__.__name__], # type: ignore
|
"type": PYTHON_TO_JSON_TYPES[arg_type.__args__[0].__class__.__name__],
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Argument {arg} of type {arg_type} from function {function.__name__} "
|
||||||
|
"could not be not be converted to a JSON schema."
|
||||||
|
)
|
||||||
|
|
||||||
if arg in arg_descriptions:
|
if arg in arg_descriptions:
|
||||||
if arg not in properties:
|
if arg not in properties:
|
||||||
properties[arg] = {}
|
properties[arg] = {}
|
||||||
properties[arg]["description"] = arg_descriptions[arg]
|
properties[arg]["description"] = arg_descriptions[arg]
|
||||||
|
|
||||||
return properties
|
return properties
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
|
# mypy: disable-error-code="annotation-unchecked"
|
||||||
from typing import Any, Callable, Dict, List, Literal, Optional, Type
|
from typing import Any, Callable, Dict, List, Literal, Optional, Type
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from pydantic import BaseModel as BaseModelV2Maybe # pydantic: ignore
|
||||||
|
from pydantic import Field as FieldV2Maybe # pydantic: ignore
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
@ -22,6 +26,18 @@ def pydantic() -> Type[BaseModel]:
|
|||||||
return dummy_function
|
return dummy_function
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def annotated_function() -> Callable:
|
||||||
|
def dummy_function(
|
||||||
|
arg1: Annotated[int, "foo"],
|
||||||
|
arg2: Annotated[Literal["bar", "baz"], "one of 'bar', 'baz'"],
|
||||||
|
) -> None:
|
||||||
|
"""dummy function"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
return dummy_function
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def function() -> Callable:
|
def function() -> Callable:
|
||||||
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
|
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
|
||||||
@ -53,6 +69,30 @@ def dummy_tool() -> BaseTool:
|
|||||||
return DummyFunction()
|
return DummyFunction()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def dummy_pydantic() -> Type[BaseModel]:
|
||||||
|
class dummy_function(BaseModel):
|
||||||
|
"""dummy function"""
|
||||||
|
|
||||||
|
arg1: int = Field(..., description="foo")
|
||||||
|
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
|
||||||
|
|
||||||
|
return dummy_function
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def dummy_pydantic_v2() -> Type[BaseModelV2Maybe]:
|
||||||
|
class dummy_function(BaseModelV2Maybe):
|
||||||
|
"""dummy function"""
|
||||||
|
|
||||||
|
arg1: int = FieldV2Maybe(..., description="foo")
|
||||||
|
arg2: Literal["bar", "baz"] = FieldV2Maybe(
|
||||||
|
..., description="one of 'bar', 'baz'"
|
||||||
|
)
|
||||||
|
|
||||||
|
return dummy_function
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def json_schema() -> Dict:
|
def json_schema() -> Dict:
|
||||||
return {
|
return {
|
||||||
@ -99,6 +139,8 @@ def test_convert_to_openai_function(
|
|||||||
function: Callable,
|
function: Callable,
|
||||||
dummy_tool: BaseTool,
|
dummy_tool: BaseTool,
|
||||||
json_schema: Dict,
|
json_schema: Dict,
|
||||||
|
annotated_function: Callable,
|
||||||
|
dummy_pydantic: Type[BaseModel],
|
||||||
) -> None:
|
) -> None:
|
||||||
expected = {
|
expected = {
|
||||||
"name": "dummy_function",
|
"name": "dummy_function",
|
||||||
@ -125,11 +167,69 @@ def test_convert_to_openai_function(
|
|||||||
expected,
|
expected,
|
||||||
Dummy.dummy_function,
|
Dummy.dummy_function,
|
||||||
DummyWithClassMethod.dummy_function,
|
DummyWithClassMethod.dummy_function,
|
||||||
|
annotated_function,
|
||||||
|
dummy_pydantic,
|
||||||
):
|
):
|
||||||
actual = convert_to_openai_function(fn) # type: ignore
|
actual = convert_to_openai_function(fn) # type: ignore
|
||||||
assert actual == expected
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_convert_to_openai_function_nested() -> None:
|
||||||
|
class Nested(BaseModel):
|
||||||
|
nested_arg1: int = Field(..., description="foo")
|
||||||
|
nested_arg2: Literal["bar", "baz"] = Field(
|
||||||
|
..., description="one of 'bar', 'baz'"
|
||||||
|
)
|
||||||
|
|
||||||
|
class NestedV2(BaseModelV2Maybe):
|
||||||
|
nested_v2_arg1: int = FieldV2Maybe(..., description="foo")
|
||||||
|
nested_v2_arg2: Literal["bar", "baz"] = FieldV2Maybe(
|
||||||
|
..., description="one of 'bar', 'baz'"
|
||||||
|
)
|
||||||
|
|
||||||
|
def my_function(arg1: Nested, arg2: NestedV2) -> None:
|
||||||
|
"""dummy function"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
"name": "my_function",
|
||||||
|
"description": "dummy function",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"arg1": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"nested_arg1": {"type": "integer", "description": "foo"},
|
||||||
|
"nested_arg2": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["bar", "baz"],
|
||||||
|
"description": "one of 'bar', 'baz'",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["nested_arg1", "nested_arg2"],
|
||||||
|
},
|
||||||
|
"arg2": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"nested_v2_arg1": {"type": "integer", "description": "foo"},
|
||||||
|
"nested_v2_arg2": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["bar", "baz"],
|
||||||
|
"description": "one of 'bar', 'baz'",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["nested_v2_arg1", "nested_v2_arg2"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["arg1", "arg2"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
actual = convert_to_openai_function(my_function)
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail(reason="Pydantic converts Optional[str] to str in .schema()")
|
@pytest.mark.xfail(reason="Pydantic converts Optional[str] to str in .schema()")
|
||||||
def test_function_optional_param() -> None:
|
def test_function_optional_param() -> None:
|
||||||
@tool
|
@tool
|
||||||
|
Loading…
Reference in New Issue
Block a user