diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index d7059fded47..970f0f932ad 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -61,38 +61,38 @@ class ToolDescription(TypedDict): """The function description.""" -def _rm_titles(kv: dict, prev_key: str = "") -> dict: - """Recursively removes "title" fields from a JSON schema dictionary. +def _rm_titles(kv: dict) -> dict: + """Recursively removes all "title" fields from a JSON schema dictionary. - Remove "title" fields from the input JSON schema dictionary, - except when a "title" appears within a property definition under "properties". - - Args: - kv (dict): The input JSON schema as a dictionary. - prev_key (str): The key from the parent dictionary, used to identify context. - - Returns: - dict: A new dictionary with appropriate "title" fields removed. + This is used to remove extraneous Pydantic schema titles. It is intelligent + enough to preserve fields that are legitimately named "title" within an + object's properties. """ - new_kv = {} - for k, v in kv.items(): - if k == "title": - # If the value is a nested dict and part of a property under "properties", - # preserve the title but continue recursion - if isinstance(v, dict) and prev_key == "properties": - new_kv[k] = _rm_titles(v, k) - else: - # Otherwise, remove this "title" key - continue - elif isinstance(v, dict): - # Recurse into nested dictionaries - new_kv[k] = _rm_titles(v, k) - else: - # Leave non-dict values untouched - new_kv[k] = v + def inner(obj: Any, *, in_properties: bool = False) -> Any: + if isinstance(obj, dict): + if in_properties: + # We are inside a 'properties' block. Keys here are valid + # field names (e.g., "title") and should be kept. We + # recurse on the values, resetting the flag. + return {k: inner(v, in_properties=False) for k, v in obj.items()} - return new_kv + # We are at a schema level. The 'title' key is metadata and should be + # removed. + out = {} + for k, v in obj.items(): + if k == "title": + continue + # Recurse, setting the flag only if the key is 'properties'. + out[k] = inner(v, in_properties=(k == "properties")) + return out + if isinstance(obj, list): + # Recurse on items in a list. + return [inner(item, in_properties=in_properties) for item in obj] + # Return non-dict, non-list values as is. + return obj + + return inner(kv) def _convert_json_schema_to_openai_function( @@ -255,6 +255,65 @@ def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescript _MAX_TYPED_DICT_RECURSION = 25 +def _parse_google_docstring( + docstring: Optional[str], + args: list[str], + *, + error_on_invalid_docstring: bool = False, +) -> tuple[str, dict]: + """Parse the function and argument descriptions from the docstring of a function. + + Assumes the function docstring follows Google Python style guide. + """ + if docstring: + docstring_blocks = docstring.split("\n\n") + if error_on_invalid_docstring: + filtered_annotations = { + arg for arg in args if arg not in {"run_manager", "callbacks", "return"} + } + if filtered_annotations and ( + len(docstring_blocks) < 2 + or not any(block.startswith("Args:") for block in docstring_blocks[1:]) + ): + msg = "Found invalid Google-Style docstring." 
+ raise ValueError(msg) + descriptors = [] + args_block = None + past_descriptors = False + for block in docstring_blocks: + if block.startswith("Args:"): + args_block = block + break + if block.startswith(("Returns:", "Example:")): + # Don't break in case Args come after + past_descriptors = True + elif not past_descriptors: + descriptors.append(block) + else: + continue + description = " ".join(descriptors) + else: + if error_on_invalid_docstring: + msg = "Found invalid Google-Style docstring." + raise ValueError(msg) + description = "" + args_block = None + arg_descriptions = {} + if args_block: + arg = None + for line in args_block.split("\n")[1:]: + if ":" in line: + arg, desc = line.split(":", maxsplit=1) + arg = arg.strip() + arg_name, _, annotations_ = arg.partition(" ") + if annotations_.startswith("(") and annotations_.endswith(")"): + arg = arg_name + arg_descriptions[arg] = desc.strip() + elif arg: + arg_descriptions[arg] += " " + line.strip() + return description, arg_descriptions + + def _convert_any_typed_dicts_to_pydantic( type_: type, *, @@ -282,18 +341,28 @@ def _convert_any_typed_dicts_to_pydantic( new_arg_type = _convert_any_typed_dicts_to_pydantic( annotated_args[0], depth=depth + 1, visited=visited ) - field_kwargs = dict(zip(("default", "description"), annotated_args[1:])) + field_kwargs = {} + metadata = annotated_args[1:] + if len(metadata) == 1 and isinstance(metadata[0], str): + # Case: Annotated[int, "a description"] + field_kwargs["description"] = metadata[0] + elif len(metadata) > 0: + # Case: Annotated[int, default_val, "a description"] + field_kwargs["default"] = metadata[0] + if len(metadata) > 1 and isinstance(metadata[1], str): + field_kwargs["description"] = metadata[1] + if (field_desc := field_kwargs.get("description")) and not isinstance( field_desc, str ): msg = ( - f"Invalid annotation for field {arg}. Third argument to " - f"Annotated must be a string description, received value of " - f"type {type(field_desc)}." + f"Invalid annotation for field {arg}. " + "Description must be a string." ) raise ValueError(msg) if arg_desc := arg_descriptions.get(arg): field_kwargs["description"] = arg_desc + fields[arg] = (new_arg_type, Field_v1(**field_kwargs)) else: new_arg_type = _convert_any_typed_dicts_to_pydantic( @@ -317,6 +386,25 @@ def _convert_any_typed_dicts_to_pydantic( return type_ +def _py_38_safe_origin(origin: type) -> type: + origin_union_type_map: dict[type, Any] = ( + {types.UnionType: Union} if hasattr(types, "UnionType") else {} + ) + + origin_map: dict[type, Any] = { + dict: dict, + list: list, + tuple: tuple, + set: set, + collections.abc.Iterable: typing.Iterable, + collections.abc.Mapping: typing.Mapping, + collections.abc.Sequence: typing.Sequence, + collections.abc.MutableMapping: typing.MutableMapping, + **origin_union_type_map, + } + return cast("type", origin_map.get(origin, origin)) + + def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription: """Format tool into the OpenAI function API. @@ -386,6 +474,30 @@ def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription: return {"type": "function", "function": function} +def _recursive_set_additional_properties_false( + schema: dict[str, Any], +) -> dict[str, Any]: + if isinstance(schema, dict): + # Check if 'required' is a key at the current level or if the schema is empty, + # in which case additionalProperties still needs to be specified. 
+ if "required" in schema or ( + "properties" in schema and not schema["properties"] + ): + schema["additionalProperties"] = False + + # Recursively check 'properties' and 'items' if they exist + if "anyOf" in schema: + for sub_schema in schema["anyOf"]: + _recursive_set_additional_properties_false(sub_schema) + if "properties" in schema: + for sub_schema in schema["properties"].values(): + _recursive_set_additional_properties_false(sub_schema) + if "items" in schema: + _recursive_set_additional_properties_false(schema["items"]) + + return schema + + def convert_to_openai_function( function: Union[dict[str, Any], type, Callable, BaseTool], *, @@ -717,105 +829,3 @@ def tool_example_to_messages( if ai_response: messages.append(AIMessage(content=ai_response)) return messages - - -def _parse_google_docstring( - docstring: Optional[str], - args: list[str], - *, - error_on_invalid_docstring: bool = False, -) -> tuple[str, dict]: - """Parse the function and argument descriptions from the docstring of a function. - - Assumes the function docstring follows Google Python style guide. - """ - if docstring: - docstring_blocks = docstring.split("\n\n") - if error_on_invalid_docstring: - filtered_annotations = { - arg for arg in args if arg not in {"run_manager", "callbacks", "return"} - } - if filtered_annotations and ( - len(docstring_blocks) < 2 - or not any(block.startswith("Args:") for block in docstring_blocks[1:]) - ): - msg = "Found invalid Google-Style docstring." - raise ValueError(msg) - descriptors = [] - args_block = None - past_descriptors = False - for block in docstring_blocks: - if block.startswith("Args:"): - args_block = block - break - if block.startswith(("Returns:", "Example:")): - # Don't break in case Args come after - past_descriptors = True - elif not past_descriptors: - descriptors.append(block) - else: - continue - description = " ".join(descriptors) - else: - if error_on_invalid_docstring: - msg = "Found invalid Google-Style docstring." - raise ValueError(msg) - description = "" - args_block = None - arg_descriptions = {} - if args_block: - arg = None - for line in args_block.split("\n")[1:]: - if ":" in line: - arg, desc = line.split(":", maxsplit=1) - arg = arg.strip() - arg_name, _, annotations_ = arg.partition(" ") - if annotations_.startswith("(") and annotations_.endswith(")"): - arg = arg_name - arg_descriptions[arg] = desc.strip() - elif arg: - arg_descriptions[arg] += " " + line.strip() - return description, arg_descriptions - - -def _py_38_safe_origin(origin: type) -> type: - origin_union_type_map: dict[type, Any] = ( - {types.UnionType: Union} if hasattr(types, "UnionType") else {} - ) - - origin_map: dict[type, Any] = { - dict: dict, - list: list, - tuple: tuple, - set: set, - collections.abc.Iterable: typing.Iterable, - collections.abc.Mapping: typing.Mapping, - collections.abc.Sequence: typing.Sequence, - collections.abc.MutableMapping: typing.MutableMapping, - **origin_union_type_map, - } - return cast("type", origin_map.get(origin, origin)) - - -def _recursive_set_additional_properties_false( - schema: dict[str, Any], -) -> dict[str, Any]: - if isinstance(schema, dict): - # Check if 'required' is a key at the current level or if the schema is empty, - # in which case additionalProperties still needs to be specified. 
- if "required" in schema or ( - "properties" in schema and not schema["properties"] - ): - schema["additionalProperties"] = False - - # Recursively check 'properties' and 'items' if they exist - if "anyOf" in schema: - for sub_schema in schema["anyOf"]: - _recursive_set_additional_properties_false(sub_schema) - if "properties" in schema: - for sub_schema in schema["properties"].values(): - _recursive_set_additional_properties_false(sub_schema) - if "items" in schema: - _recursive_set_additional_properties_false(schema["items"]) - - return schema diff --git a/libs/core/tests/unit_tests/utils/test_function_calling.py b/libs/core/tests/unit_tests/utils/test_function_calling.py index f75ae304937..5ab9abd43e9 100644 --- a/libs/core/tests/unit_tests/utils/test_function_calling.py +++ b/libs/core/tests/unit_tests/utils/test_function_calling.py @@ -35,6 +35,17 @@ from langchain_core.utils.function_calling import ( ) +def remove_titles(obj: dict) -> None: + if isinstance(obj, dict): + obj.pop("title", None) + for v in obj.values(): + remove_titles(v) + elif isinstance(obj, list): + for v in obj: + remove_titles(v) + return obj + + @pytest.fixture def pydantic() -> type[BaseModel]: class dummy_function(BaseModel): # noqa: N801 @@ -365,9 +376,9 @@ def test_convert_to_openai_function( dummy_extensions_typed_dict_docstring, ): actual = convert_to_openai_function(fn) + remove_titles(actual) assert actual == expected - # Test runnables actual = convert_to_openai_function(runnable.as_tool(description="Dummy function.")) parameters = { "type": "object", @@ -384,7 +395,6 @@ def test_convert_to_openai_function( runnable_expected["parameters"] = parameters assert actual == runnable_expected - # Test simple Tool def my_function(_: str) -> str: return "" @@ -398,11 +408,12 @@ def test_convert_to_openai_function( "name": "dummy_function", "description": "test description", "parameters": { - "properties": {"__arg1": {"title": "__arg1", "type": "string"}}, + "properties": {"__arg1": {"type": "string"}}, "required": ["__arg1"], "type": "object", }, } + remove_titles(actual) assert actual == expected @@ -454,6 +465,7 @@ def test_convert_to_openai_function_nested() -> None: } actual = convert_to_openai_function(my_function) + remove_titles(actual) assert actual == expected @@ -494,6 +506,7 @@ def test_convert_to_openai_function_nested_strict() -> None: } actual = convert_to_openai_function(my_function, strict=True) + remove_titles(actual) assert actual == expected @@ -518,23 +531,20 @@ def test_convert_to_openai_function_strict_union_of_objects_arg_type() -> None: "my_arg": { "anyOf": [ { - "properties": {"foo": {"title": "Foo", "type": "string"}}, + "properties": {"foo": {"type": "string"}}, "required": ["foo"], - "title": "NestedA", "type": "object", "additionalProperties": False, }, { - "properties": {"bar": {"title": "Bar", "type": "integer"}}, + "properties": {"bar": {"type": "integer"}}, "required": ["bar"], - "title": "NestedB", "type": "object", "additionalProperties": False, }, { - "properties": {"baz": {"title": "Baz", "type": "boolean"}}, + "properties": {"baz": {"type": "boolean"}}, "required": ["baz"], - "title": "NestedC", "type": "object", "additionalProperties": False, }, @@ -549,6 +559,7 @@ def test_convert_to_openai_function_strict_union_of_objects_arg_type() -> None: } actual = convert_to_openai_function(my_function, strict=True) + remove_titles(actual) assert actual == expected @@ -556,7 +567,6 @@ json_schema_no_description_no_params = { "title": "dummy_function", } - json_schema_no_description = 
{ "title": "dummy_function", "type": "object", @@ -571,7 +581,6 @@ json_schema_no_description = { "required": ["arg1", "arg2"], } - anthropic_tool_no_description = { "name": "dummy_function", "input_schema": { @@ -588,7 +597,6 @@ anthropic_tool_no_description = { }, } - bedrock_converse_tool_no_description = { "toolSpec": { "name": "dummy_function", @@ -609,7 +617,6 @@ bedrock_converse_tool_no_description = { } } - openai_function_no_description = { "name": "dummy_function", "parameters": { @@ -626,7 +633,6 @@ openai_function_no_description = { }, } - openai_function_no_description_no_params = { "name": "dummy_function", } @@ -658,6 +664,7 @@ def test_convert_to_openai_function_no_description(func: dict) -> None: }, } actual = convert_to_openai_function(func) + remove_titles(actual) assert actual == expected @@ -772,7 +779,6 @@ def test_tool_outputs() -> None: ] assert messages[2].content == "Output1" - # Test final AI response messages = tool_example_to_messages( input="This is an example", tool_calls=[ @@ -880,12 +886,10 @@ def test__convert_typed_dict_to_openai_function( "items": [ {"type": "array", "items": {}}, { - "title": "SubTool", "description": "Subtool docstring.", "type": "object", "properties": { "args": { - "title": "Args", "description": "this does bar", "default": {}, "type": "object", @@ -916,12 +920,10 @@ def test__convert_typed_dict_to_openai_function( "maxItems": 1, "items": [ { - "title": "SubTool", "description": "Subtool docstring.", "type": "object", "properties": { "args": { - "title": "Args", "description": "this does bar", "default": {}, "type": "object", @@ -1034,6 +1036,7 @@ def test__convert_typed_dict_to_openai_function( }, } actual = _convert_typed_dict_to_openai_function(Tool) + remove_titles(actual) assert actual == expected @@ -1042,7 +1045,6 @@ def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None: class Tool(typed_dict): # type: ignore[misc] arg1: typing.MutableSet # Pydantic 2 supports this, but pydantic v1 does not. - # Error should be raised since we're using v1 code path here with pytest.raises(TypeError): _convert_typed_dict_to_openai_function(Tool) diff --git a/reproduce_pydanticv2_test.py b/reproduce_pydanticv2_test.py new file mode 100644 index 00000000000..8826da65545 --- /dev/null +++ b/reproduce_pydanticv2_test.py @@ -0,0 +1,141 @@ +import re +import os +import json +from typing import Literal, Optional, Tuple, Union, Annotated +from pydantic import BaseModel, Field, PositiveInt, ValidationInfo, field_validator, ConfigDict +from langchain_core.tools import tool +from langchain_core.messages import HumanMessage, AIMessage +from langchain_openai import ChatOpenAI + +# Ensure you have your OPENAI_API_KEY set as an environment variable +if not os.getenv("OPENAI_API_KEY"): + raise ValueError("OPENAI_API_KEY environment variable not set.") + +# Dummy placeholder since this isn't a real LangGraph state injection +def InjectedState(d: dict): + return {} + +# --- Pydantic Models from the GitHub Issue --- + +time_fmt = "%Y-%m-%d %H:%M:%S" +time_pattern = r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}$" + +# Forward-declare nested models for Pydantic +class DataSoilDashboardQueryPayloadQueryParam: + pass + +class DataSoilDashboardQueryPayloadTimeShift(BaseModel): + shiftInterval: list[PositiveInt] = Field(description="Each element in the array represents a time offset relative to the query timestamp for individual time comparison analysis. 
If time comparison analysis is not described, keep it **VOID**.",max_length=2,default=[])
+    timeUnit: Literal["DAY"] = Field(default="DAY",description="The unit of specific comparison time offset. This is the description about each value of unit: Unit **DAY** represents one day.")
+
+class DataSoilDashboardQueryPayloadQueryParamWhereFilter(BaseModel):
+    field: str = Field(description="The dimension **CODE** in the selected dimension list that requires enums filtering or pattern filtering.")
+    operator: Literal["IN", "NI", "LIKE", "NOT_LIKE"] = Field(description="Operators for enums filtering or pattern filtering.")
+    value: list[str] = Field(description="If for enums filtering, every element represents the practical enums of the dimension. Otherwise for pattern filtering, only **one** element is required and it represents a wildcard pattern.",min_length=1)
+
+    @field_validator("field")
+    def field_block(cls, v: str, info: ValidationInfo) -> str:
+        if v == "dt":
+            raise ValueError("Instruction: The time filtering should be described in 'time' field, not in the 'filters' field.")
+        return v
+
+    @field_validator("value")
+    def value_block(cls, v: Optional[list[str]], info: ValidationInfo) -> Optional[list[str]]:
+        if info.data.get("operator") in {"LIKE", "NOT_LIKE"} and v and len(v) > 1:
+            raise ValueError("Instruction: For pattern filtering, the size of 'value' in 'where' must be **ONE**.")
+        return v
+
+class DataSoilDashboardQueryPayloadQueryParamWhere(BaseModel):
+    time: list[Union[str, int]] = Field(description=f"The target time range...", min_length=2, max_length=2)
+    filters: list[DataSoilDashboardQueryPayloadQueryParamWhereFilter] = Field(description="Enums filtering or pattern filtering condition...")
+    relation: Literal["AND"] = Field(description="Boolean relationships between filters...")
+
+    @field_validator("time")
+    def time_format_block(cls, v: list[Union[int, str]], info: ValidationInfo) -> list[Union[int, str]]:
+        if isinstance(v[0], str) and not re.search(time_pattern, v[0]):
+            raise ValueError(f"Instruction: the start time of time range must be formatted as **{time_fmt}**")
+        if isinstance(v[1], str) and not re.search(time_pattern, v[1]):
+            raise ValueError(f"Instruction: the end time of time range must be formatted as **{time_fmt}**")
+        return v
+
+class DataSoilDashboardQueryPayloadQueryParamOrderBy(BaseModel):
+    field: str = Field(description="The metric **CODE** in the selected metric list that requires metric sorting.")
+    direction: Literal["ASC", "DESC"] = Field(description="Sorting direction for specified metric.")
+    shift: int = Field(default=0)
+    limit: int = Field(description="The number of rows to return...", default=50)
+
+class DataSoilDashboardQueryPayloadQueryParamGroupBy(BaseModel):
+    field: str = Field(description="The dimension **CODE** in the selected dimension list for dimension grouping analysis.")
+    extendFields: list[str] = Field(default=[])
+    orderBy: Optional[DataSoilDashboardQueryPayloadQueryParamOrderBy] = Field(description="Sorting config for query results...", default=None)
+
+class DataSoilDashboardQueryPayloadQueryParam(BaseModel):
+    queryType: Literal["DETAIL_TABLE"] = Field(description="This is the description about queryType...")
+    interval: Literal["BY_ONE_MINUTE", "BY_FIVE_MINUTE", "BY_HOUR", "BY_DAY", "BY_WEEK", "BY_MONTH", "SUM"] = Field(description="The time granularity for time-based grouping analysis.")
+    resultField: list[str] = Field(default=[])
+    where: DataSoilDashboardQueryPayloadQueryParamWhere = Field(description="Filtering condition for dimensions.")
+    groupBy: list[DataSoilDashboardQueryPayloadQueryParamGroupBy] = Field(description="A list of dimensions grouping analysis info...")
+    orderBy: DataSoilDashboardQueryPayloadQueryParamOrderBy = Field(description="Sorting config for query results...")
+    heavyQuery: bool = Field(default=False)
+
+    @field_validator("groupBy")
+    def groupBy_block(cls, v: list[DataSoilDashboardQueryPayloadQueryParamGroupBy], info: ValidationInfo) -> list[DataSoilDashboardQueryPayloadQueryParamGroupBy]:
+        if "dt" in {e.field for e in v}:
+            if info.data.get("interval") == "SUM":
+                raise ValueError("Instruction: the interval can not be **SUM** when **time-based grouping is required**.")
+        else:
+            if info.data.get("interval") != "SUM":
+                raise ValueError("Instruction: the interval must be **SUM** when **time-based grouping is not required**.")
+        return v
+
+class DataSoilDashboardQueryPayload(BaseModel):
+    model_config = ConfigDict(frozen=False)
+    apiCode: str = Field(default="")
+    requestId: str = Field(default="")
+    applicationCode: str = Field(default="")
+    applicationToken: str = Field(default="")
+    debug: bool = Field(default=False)
+    timeShift: DataSoilDashboardQueryPayloadTimeShift = Field(description="Time comparison analysis config.", default_factory=DataSoilDashboardQueryPayloadTimeShift)
+    dynamicQueryParam: DataSoilDashboardQueryPayloadQueryParam
+    forceFlush: bool = Field(default=False)
+
+# Resolve forward references
+DataSoilDashboardQueryPayload.model_rebuild()
+
+@tool
+def query_datasoil_data_tool(payload: DataSoilDashboardQueryPayload) -> str:
+    """Queries the DataSoil database with a complex payload."""
+    print("--- Tool successfully called with validated payload ---")
+    # In a real scenario, you'd process the payload here.
+    # For reproduction, we just need to see that it gets called correctly.
+    return "Tool call successful."
+
+# Use a model that supports tool calling, like gpt-4o
+llm = ChatOpenAI(model="gpt-4o", temperature=0)
+
+# Bind the tool to the LLM
+llm_with_tools = llm.bind_tools([query_datasoil_data_tool])
+
+# --- NEW: Inspect the schema LangChain generates BEFORE the LLM call ---
+tool_schemas = llm_with_tools.kwargs.get("tools", [])
+print("\n--- Generated Tool Schema (for LLM) ---")
+print(json.dumps(tool_schemas, indent=2))
+# --- End of new section ---
+
+# Example invocation
+prompt = "Get the detail table for sales data from 2025-07-01 00:00:00 to 2025-07-08 00:00:00, grouped by city, and ordered by total revenue descending."
+
+print(f"\n--- Invoking LLM with prompt: '{prompt}' ---")
+
+ai_msg = llm_with_tools.invoke(prompt)
+
+print("\n--- LLM Response ---")
+print(ai_msg)
+
+if isinstance(ai_msg, AIMessage) and ai_msg.tool_calls:
+    print("\n--- Generated Tool Call Arguments ---")
+    # In a real case, you'd see the arguments the LLM generated.
+    # The bug is that these args are often malformed due to an incorrect schema.
+    print(ai_msg.tool_calls[0]['args'])
+else:
+    print("\n--- No tool call was generated ---")