diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index cf168e784b2..90212af424f 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -28,7 +28,20 @@ from abc import ABC, abstractmethod from contextvars import copy_context from functools import partial from inspect import signature -from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from typing_extensions import Annotated, get_args, get_origin from langchain_core._api import deprecated from langchain_core.callbacks import ( @@ -76,11 +89,32 @@ class SchemaAnnotationError(TypeError): """Raised when 'args_schema' is missing or has an incorrect type annotation.""" +def _is_annotated_type(typ: Type[Any]) -> bool: + return get_origin(typ) is Annotated + + +def _get_annotation_description(arg: str, arg_type: Type[Any]) -> str | None: + if _is_annotated_type(arg_type): + annotated_args = get_args(arg_type) + arg_type = annotated_args[0] + if len(annotated_args) > 1: + for annotation in annotated_args[1:]: + if isinstance(annotation, str): + return annotation + return None + + def _create_subset_model( - name: str, model: Type[BaseModel], field_names: list + name: str, + model: Type[BaseModel], + field_names: list, + *, + descriptions: Optional[dict] = None, + fn_description: Optional[str] = None, ) -> Type[BaseModel]: """Create a pydantic model with only a subset of model's fields.""" fields = {} + for field_name in field_names: field = model.__fields__[field_name] t = ( @@ -89,19 +123,89 @@ def _create_subset_model( if field.required and not field.allow_none else Optional[field.outer_type_] ) + if descriptions and field_name in descriptions: + field.field_info.description = descriptions[field_name] fields[field_name] = (t, field.field_info) + rtn = create_model(name, **fields) # type: ignore + rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "") return rtn def _get_filtered_args( inferred_model: Type[BaseModel], func: Callable, + *, + filter_args: Sequence[str], ) -> dict: """Get the arguments from a function's signature.""" schema = inferred_model.schema()["properties"] valid_keys = signature(func).parameters - return {k: schema[k] for k in valid_keys if k not in ("run_manager", "callbacks")} + return { + k: schema[k] + for i, (k, param) in enumerate(valid_keys.items()) + if k not in filter_args and (i > 0 or param.name not in ("self", "cls")) + } + + +def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]: + """Parse the function and argument descriptions from the docstring of a function. + + Assumes the function docstring follows Google Python style guide. + """ + docstring = inspect.getdoc(function) + if docstring: + docstring_blocks = docstring.split("\n\n") + descriptors = [] + args_block = None + past_descriptors = False + for block in docstring_blocks: + if block.startswith("Args:"): + args_block = block + break + elif block.startswith("Returns:") or block.startswith("Example:"): + # Don't break in case Args come after + past_descriptors = True + elif not past_descriptors: + descriptors.append(block) + else: + continue + description = " ".join(descriptors) + else: + description = "" + args_block = None + arg_descriptions = {} + if args_block: + arg = None + for line in args_block.split("\n")[1:]: + if ":" in line: + arg, desc = line.split(":", maxsplit=1) + arg_descriptions[arg.strip()] = desc.strip() + elif arg: + arg_descriptions[arg.strip()] += " " + line.strip() + return description, arg_descriptions + + +def _infer_arg_descriptions( + fn: Callable, *, parse_docstring: bool = False +) -> Tuple[str, dict]: + """Infer argument descriptions from a function's docstring.""" + if parse_docstring: + description, arg_descriptions = _parse_python_function_docstring(fn) + else: + description = inspect.getdoc(fn) or "" + arg_descriptions = {} + if hasattr(inspect, "get_annotations"): + # This is for python < 3.10 + annotations = inspect.get_annotations(fn) # type: ignore + else: + annotations = getattr(fn, "__annotations__", {}) + for arg, arg_type in annotations.items(): + if arg in arg_descriptions: + continue + if desc := _get_annotation_description(arg, arg_type): + arg_descriptions[arg] = desc + return description, arg_descriptions class _SchemaConfig: @@ -114,25 +218,40 @@ class _SchemaConfig: def create_schema_from_function( model_name: str, func: Callable, + *, + filter_args: Optional[Sequence[str]] = None, + parse_docstring: bool = False, ) -> Type[BaseModel]: """Create a pydantic schema from a function's signature. Args: model_name: Name to assign to the generated pydandic schema func: Function to generate the schema from + filter_args: Optional list of arguments to exclude from the schema + parse_docstring: Whether to parse the function's docstring for descriptions + for each argument. Returns: A pydantic model with the same arguments as the function """ # https://docs.pydantic.dev/latest/usage/validation_decorator/ validated = validate_arguments(func, config=_SchemaConfig) # type: ignore inferred_model = validated.model # type: ignore - if "run_manager" in inferred_model.__fields__: - del inferred_model.__fields__["run_manager"] - if "callbacks" in inferred_model.__fields__: - del inferred_model.__fields__["callbacks"] + filter_args = ( + filter_args if filter_args is not None else ("run_manager", "callbacks") + ) + for arg in filter_args: + if arg in inferred_model.__fields__: + del inferred_model.__fields__[arg] + description, arg_descriptions = _infer_arg_descriptions( + func, parse_docstring=parse_docstring + ) # Pydantic adds placeholder virtual fields we need to strip - valid_properties = _get_filtered_args(inferred_model, func) + valid_properties = _get_filtered_args(inferred_model, func, filter_args=filter_args) return _create_subset_model( - f"{model_name}Schema", inferred_model, list(valid_properties) + f"{model_name}Schema", + inferred_model, + list(valid_properties), + descriptions=arg_descriptions, + fn_description=description, ) diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index e8b1e32d1cd..f4bcba6e701 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -2,10 +2,8 @@ from __future__ import annotations -import inspect import logging import uuid -from types import FunctionType, MethodType from typing import ( TYPE_CHECKING, Any, @@ -14,13 +12,12 @@ from typing import ( List, Literal, Optional, - Tuple, Type, Union, cast, ) -from typing_extensions import Annotated, TypedDict, get_args, get_origin +from typing_extensions import TypedDict from langchain_core._api import deprecated from langchain_core.messages import ( @@ -123,122 +120,6 @@ def _get_python_function_name(function: Callable) -> str: return function.__name__ -def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]: - """Parse the function and argument descriptions from the docstring of a function. - - Assumes the function docstring follows Google Python style guide. - """ - docstring = inspect.getdoc(function) - if docstring: - docstring_blocks = docstring.split("\n\n") - descriptors = [] - args_block = None - past_descriptors = False - for block in docstring_blocks: - if block.startswith("Args:"): - args_block = block - break - elif block.startswith("Returns:") or block.startswith("Example:"): - # Don't break in case Args come after - past_descriptors = True - elif not past_descriptors: - descriptors.append(block) - else: - continue - description = " ".join(descriptors) - else: - description = "" - args_block = None - arg_descriptions = {} - if args_block: - arg = None - for line in args_block.split("\n")[1:]: - if ":" in line: - arg, desc = line.split(":", maxsplit=1) - arg_descriptions[arg.strip()] = desc.strip() - elif arg: - arg_descriptions[arg.strip()] += " " + line.strip() - return description, arg_descriptions - - -def _is_annotated_type(typ: Type[Any]) -> bool: - return get_origin(typ) is Annotated - - -def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict: - """Get JsonSchema describing a Python functions arguments. - - Assumes all function arguments are of primitive types (int, float, str, bool) or - are subclasses of pydantic.BaseModel. - """ - properties = {} - annotations = inspect.getfullargspec(function).annotations - for arg, arg_type in annotations.items(): - if arg == "return": - continue - - if _is_annotated_type(arg_type): - annotated_args = get_args(arg_type) - arg_type = annotated_args[0] - if len(annotated_args) > 1: - for annotation in annotated_args[1:]: - if isinstance(annotation, str): - arg_descriptions[arg] = annotation - break - if ( - isinstance(arg_type, type) - and hasattr(arg_type, "model_json_schema") - and callable(arg_type.model_json_schema) - ): - properties[arg] = arg_type.model_json_schema() - elif ( - isinstance(arg_type, type) - and hasattr(arg_type, "schema") - and callable(arg_type.schema) - ): - properties[arg] = arg_type.schema() - elif ( - hasattr(arg_type, "__name__") - and getattr(arg_type, "__name__") in PYTHON_TO_JSON_TYPES - ): - properties[arg] = {"type": PYTHON_TO_JSON_TYPES[arg_type.__name__]} - elif ( - hasattr(arg_type, "__dict__") - and getattr(arg_type, "__dict__").get("__origin__", None) == Literal - ): - properties[arg] = { - "enum": list(arg_type.__args__), - "type": PYTHON_TO_JSON_TYPES[arg_type.__args__[0].__class__.__name__], - } - else: - logger.warning( - f"Argument {arg} of type {arg_type} from function {function.__name__} " - "could not be not be converted to a JSON schema." - ) - - if arg in arg_descriptions: - if arg not in properties: - properties[arg] = {} - properties[arg]["description"] = arg_descriptions[arg] - - return properties - - -def _get_python_function_required_args(function: Callable) -> List[str]: - """Get the required arguments for a Python function.""" - spec = inspect.getfullargspec(function) - required = spec.args[: -len(spec.defaults)] if spec.defaults else spec.args - required += [k for k in spec.kwonlyargs if k not in (spec.kwonlydefaults or {})] - - is_function_type = isinstance(function, FunctionType) - is_method_type = isinstance(function, MethodType) - if required and is_function_type and required[0] == "self": - required = required[1:] - elif required and is_method_type and required[0] == "cls": - required = required[1:] - return required - - @deprecated( "0.1.16", alternative="langchain_core.utils.function_calling.convert_to_openai_function()", @@ -246,23 +127,24 @@ def _get_python_function_required_args(function: Callable) -> List[str]: ) def convert_python_function_to_openai_function( function: Callable, -) -> Dict[str, Any]: +) -> FunctionDescription: """Convert a Python function to an OpenAI function-calling API compatible dict. Assumes the Python function has type hints and a docstring with a description. If the docstring has Google Python style argument descriptions, these will be included as well. """ - description, arg_descriptions = _parse_python_function_docstring(function) - return { - "name": _get_python_function_name(function), - "description": description, - "parameters": { - "type": "object", - "properties": _get_python_function_arguments(function, arg_descriptions), - "required": _get_python_function_required_args(function), - }, - } + from langchain_core import tools + + func_name = _get_python_function_name(function) + model = tools.create_schema_from_function( + func_name, function, filter_args=(), parse_docstring=True + ) + return convert_pydantic_to_openai_function( + model, + name=func_name, + description=model.__doc__, + ) @deprecated( @@ -343,7 +225,7 @@ def convert_to_openai_function( elif isinstance(function, BaseTool): return cast(Dict, format_tool_to_openai_function(function)) elif callable(function): - return convert_python_function_to_openai_function(function) + return cast(Dict, convert_python_function_to_openai_function(function)) else: raise ValueError( f"Unsupported function\n\n{function}\n\nFunctions must be passed in" diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 5e7b05dbff3..968b2a03d46 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -1,6 +1,7 @@ """Test the base tool implementation.""" import asyncio +import inspect import json import sys import textwrap @@ -10,6 +11,7 @@ from functools import partial from typing import Any, Callable, Dict, List, Optional, Type, Union import pytest +from typing_extensions import Annotated from langchain_core.callbacks import ( AsyncCallbackManagerForToolRun, @@ -310,9 +312,10 @@ def test_structured_tool_from_function_docstring() -> None: def foo(bar: int, baz: str) -> str: """Docstring + Args: - bar: int - baz: str + bar: the bar value + baz: the baz value """ raise NotImplementedError() @@ -328,6 +331,7 @@ def test_structured_tool_from_function_docstring() -> None: "bar": {"title": "Bar", "type": "integer"}, "baz": {"title": "Baz", "type": "string"}, }, + "description": inspect.getdoc(foo), "title": "fooSchema", "type": "object", "required": ["bar", "baz"], @@ -342,6 +346,7 @@ def test_structured_tool_from_function_docstring_complex_args() -> None: def foo(bar: int, baz: List[str]) -> str: """Docstring + Args: bar: int baz: List[str] @@ -352,14 +357,23 @@ def test_structured_tool_from_function_docstring_complex_args() -> None: assert structured_tool.name == "foo" assert structured_tool.args == { "bar": {"title": "Bar", "type": "integer"}, - "baz": {"title": "Baz", "type": "array", "items": {"type": "string"}}, + "baz": { + "title": "Baz", + "type": "array", + "items": {"type": "string"}, + }, } assert structured_tool.args_schema.schema() == { "properties": { "bar": {"title": "Bar", "type": "integer"}, - "baz": {"title": "Baz", "type": "array", "items": {"type": "string"}}, + "baz": { + "title": "Baz", + "type": "array", + "items": {"type": "string"}, + }, }, + "description": inspect.getdoc(foo), "title": "fooSchema", "type": "object", "required": ["bar", "baz"], @@ -439,6 +453,7 @@ def test_structured_tool_from_function_with_run_manager() -> None: bar: int, baz: str, callbacks: Optional[CallbackManagerForToolRun] = None ) -> str: """Docstring + Args: bar: int baz: str @@ -459,6 +474,7 @@ def test_structured_tool_from_function_with_run_manager() -> None: "bar": {"title": "Bar", "type": "integer"}, "baz": {"title": "Baz", "type": "string"}, }, + "description": inspect.getdoc(foo), "title": "fooSchema", "type": "object", "required": ["bar", "baz"], @@ -675,10 +691,11 @@ def test_structured_tool_from_function() -> None: """Test that structured tools can be created from functions.""" def foo(bar: int, baz: str) -> str: - """Docstring + """Docstring thing. + Args: - bar: int - baz: str + bar: the bar value + baz: the baz value """ raise NotImplementedError() @@ -692,6 +709,7 @@ def test_structured_tool_from_function() -> None: assert structured_tool.args_schema.schema() == { "title": "fooSchema", "type": "object", + "description": inspect.getdoc(foo), "properties": { "bar": {"title": "Bar", "type": "integer"}, "baz": {"title": "Baz", "type": "string"}, @@ -916,3 +934,56 @@ def test_tool_description() -> None: foo2 = StructuredTool.from_function(foo) assert foo2.description == "The foo." + + +def test_tool_arg_descriptions() -> None: + def foo(bar: str, baz: int) -> str: + """The foo. + + Args: + bar: The bar. + baz: The baz. + """ + return bar + + foo1 = tool(foo) + args_schema = foo1.args_schema.schema() # type: ignore + assert args_schema == { + "title": "fooSchema", + "type": "object", + "description": inspect.getdoc(foo), + "properties": { + "bar": {"title": "Bar", "type": "string"}, + "baz": {"title": "Baz", "type": "integer"}, + }, + "required": ["bar", "baz"], + } + + +def test_tool_annotated_descriptions() -> None: + def foo( + bar: Annotated[str, "this is the bar"], baz: Annotated[int, "this is the baz"] + ) -> str: + """The foo. + + Returns: + The bar only. + """ + return bar + + foo1 = tool(foo) + args_schema = foo1.args_schema.schema() # type: ignore + assert args_schema == { + "title": "fooSchema", + "type": "object", + "description": inspect.getdoc(foo), + "properties": { + "bar": {"title": "Bar", "type": "string", "description": "this is the bar"}, + "baz": { + "title": "Baz", + "type": "integer", + "description": "this is the baz", + }, + }, + "required": ["bar", "baz"], + } diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index ddf19e87e9a..6702b5ad635 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -252,7 +252,7 @@ def test_function_no_params() -> None: pass func = convert_to_openai_function(nullary_function) - req = func["parameters"]["required"] + req = func["parameters"].get("required") assert not req