commit 4888d319c4
Merge 28f1c5f3c7 into 0e287763cd
@@ -61,38 +61,38 @@ class ToolDescription(TypedDict):
     """The function description."""
 
 
-def _rm_titles(kv: dict, prev_key: str = "") -> dict:
-    """Recursively removes "title" fields from a JSON schema dictionary.
-
-    Remove "title" fields from the input JSON schema dictionary,
-    except when a "title" appears within a property definition under "properties".
-
-    Args:
-        kv (dict): The input JSON schema as a dictionary.
-        prev_key (str): The key from the parent dictionary, used to identify context.
-
-    Returns:
-        dict: A new dictionary with appropriate "title" fields removed.
-    """
-    new_kv = {}
-    for k, v in kv.items():
-        if k == "title":
-            # If the value is a nested dict and part of a property under "properties",
-            # preserve the title but continue recursion
-            if isinstance(v, dict) and prev_key == "properties":
-                new_kv[k] = _rm_titles(v, k)
-            else:
-                # Otherwise, remove this "title" key
-                continue
-        elif isinstance(v, dict):
-            # Recurse into nested dictionaries
-            new_kv[k] = _rm_titles(v, k)
-        else:
-            # Leave non-dict values untouched
-            new_kv[k] = v
-    return new_kv
+def _rm_titles(kv: dict) -> dict:
+    """Recursively removes all "title" fields from a JSON schema dictionary.
+
+    This is used to remove extraneous Pydantic schema titles. It is intelligent
+    enough to preserve fields that are legitimately named "title" within an
+    object's properties.
+    """
+
+    def inner(obj: Any, *, in_properties: bool = False) -> Any:
+        if isinstance(obj, dict):
+            if in_properties:
+                # We are inside a 'properties' block. Keys here are valid
+                # field names (e.g., "title") and should be kept. We
+                # recurse on the values, resetting the flag.
+                return {k: inner(v, in_properties=False) for k, v in obj.items()}
+            # We are at a schema level. The 'title' key is metadata and should be
+            # removed.
+            out = {}
+            for k, v in obj.items():
+                if k == "title":
+                    continue
+                # Recurse, setting the flag only if the key is 'properties'.
+                out[k] = inner(v, in_properties=(k == "properties"))
+            return out
+        if isinstance(obj, list):
+            # Recurse on items in a list.
+            return [inner(item, in_properties=in_properties) for item in obj]
+        # Return non-dict, non-list values as is.
+        return obj
+
+    return inner(kv)
 
 
 def _convert_json_schema_to_openai_function(
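For illustration, a standalone sketch (not part of the committed diff) of the new `_rm_titles` behavior: schema-level `title` metadata is dropped, while a real field that happens to be named `title` under `properties` survives.

from typing import Any


def rm_titles(kv: dict) -> dict:
    # Standalone copy of the new _rm_titles logic from the hunk above.
    def inner(obj: Any, *, in_properties: bool = False) -> Any:
        if isinstance(obj, dict):
            if in_properties:
                # Keys under "properties" are field names; keep them all.
                return {k: inner(v, in_properties=False) for k, v in obj.items()}
            # At schema level, "title" is metadata; drop it.
            return {
                k: inner(v, in_properties=(k == "properties"))
                for k, v in obj.items()
                if k != "title"
            }
        if isinstance(obj, list):
            return [inner(item, in_properties=in_properties) for item in obj]
        return obj

    return inner(kv)


schema = {
    "title": "Book",  # Pydantic metadata: removed
    "type": "object",
    "properties": {
        "title": {"title": "Title", "type": "string"},  # real field: kept
        "author": {"title": "Author", "type": "string"},
    },
}
print(rm_titles(schema))
# {'type': 'object', 'properties': {'title': {'type': 'string'},
#  'author': {'type': 'string'}}}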
@@ -255,6 +255,65 @@ def _convert_typed_dict_to_openai_function(typed_dict: type) -> FunctionDescription:
 _MAX_TYPED_DICT_RECURSION = 25
 
 
+def _parse_google_docstring(
+    docstring: Optional[str],
+    args: list[str],
+    *,
+    error_on_invalid_docstring: bool = False,
+) -> tuple[str, dict]:
+    """Parse the function and argument descriptions from the docstring of a function.
+
+    Assumes the function docstring follows Google Python style guide.
+    """
+    if docstring:
+        docstring_blocks = docstring.split("\n\n")
+        if error_on_invalid_docstring:
+            filtered_annotations = {
+                arg for arg in args if arg not in {"run_manager", "callbacks", "return"}
+            }
+            if filtered_annotations and (
+                len(docstring_blocks) < 2
+                or not any(block.startswith("Args:") for block in docstring_blocks[1:])
+            ):
+                msg = "Found invalid Google-Style docstring."
+                raise ValueError(msg)
+        descriptors = []
+        args_block = None
+        past_descriptors = False
+        for block in docstring_blocks:
+            if block.startswith("Args:"):
+                args_block = block
+                break
+            if block.startswith(("Returns:", "Example:")):
+                # Don't break in case Args come after
+                past_descriptors = True
+            elif not past_descriptors:
+                descriptors.append(block)
+            else:
+                continue
+        description = " ".join(descriptors)
+    else:
+        if error_on_invalid_docstring:
+            msg = "Found invalid Google-Style docstring."
+            raise ValueError(msg)
+        description = ""
+        args_block = None
+    arg_descriptions = {}
+    if args_block:
+        arg = None
+        for line in args_block.split("\n")[1:]:
+            if ":" in line:
+                arg, desc = line.split(":", maxsplit=1)
+                arg = arg.strip()
+                arg_name, _, annotations_ = arg.partition(" ")
+                if annotations_.startswith("(") and annotations_.endswith(")"):
+                    arg = arg_name
+                arg_descriptions[arg] = desc.strip()
+            elif arg:
+                arg_descriptions[arg] += " " + line.strip()
+    return description, arg_descriptions
+
+
 def _convert_any_typed_dicts_to_pydantic(
     type_: type,
     *,
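For illustration (not part of the committed diff): `_parse_google_docstring` is a private helper, so the import below is an assumption that may break across langchain-core releases; the expected outputs follow from the parsing logic in the hunk above.

# Hedged sketch: _parse_google_docstring is private API.
from langchain_core.utils.function_calling import _parse_google_docstring

docstring = """Add two numbers.

Args:
    a (int): The first number.
    b (int): The second number.

Returns:
    The sum.
"""
description, arg_descriptions = _parse_google_docstring(docstring, ["a", "b"])
print(description)       # "Add two numbers."
print(arg_descriptions)  # {'a': 'The first number.', 'b': 'The second number.'}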
@@ -282,18 +341,28 @@ def _convert_any_typed_dicts_to_pydantic(
             new_arg_type = _convert_any_typed_dicts_to_pydantic(
                 annotated_args[0], depth=depth + 1, visited=visited
             )
-            field_kwargs = dict(zip(("default", "description"), annotated_args[1:]))
+            field_kwargs = {}
+            metadata = annotated_args[1:]
+            if len(metadata) == 1 and isinstance(metadata[0], str):
+                # Case: Annotated[int, "a description"]
+                field_kwargs["description"] = metadata[0]
+            elif len(metadata) > 0:
+                # Case: Annotated[int, default_val, "a description"]
+                field_kwargs["default"] = metadata[0]
+                if len(metadata) > 1 and isinstance(metadata[1], str):
+                    field_kwargs["description"] = metadata[1]
+
             if (field_desc := field_kwargs.get("description")) and not isinstance(
                 field_desc, str
             ):
                 msg = (
-                    f"Invalid annotation for field {arg}. Third argument to "
-                    f"Annotated must be a string description, received value of "
-                    f"type {type(field_desc)}."
+                    f"Invalid annotation for field {arg}. "
+                    "Description must be a string."
                 )
                 raise ValueError(msg)
             if arg_desc := arg_descriptions.get(arg):
                 field_kwargs["description"] = arg_desc
             fields[arg] = (new_arg_type, Field_v1(**field_kwargs))
         else:
             new_arg_type = _convert_any_typed_dicts_to_pydantic(
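For illustration, a minimal standalone sketch (not part of the committed diff) of the new `Annotated` metadata handling, using `typing.get_args` to pull the extras the same way `annotated_args[1:]` does above:

from typing import Annotated, get_args


def parse_metadata(annotated_type: object) -> dict:
    # Sketch of the two supported shapes of positional Annotated extras:
    # (description,) or (default, description).
    metadata = get_args(annotated_type)[1:]
    field_kwargs = {}
    if len(metadata) == 1 and isinstance(metadata[0], str):
        # Case: Annotated[int, "a description"]
        field_kwargs["description"] = metadata[0]
    elif len(metadata) > 0:
        # Case: Annotated[int, default_val, "a description"]
        field_kwargs["default"] = metadata[0]
        if len(metadata) > 1 and isinstance(metadata[1], str):
            field_kwargs["description"] = metadata[1]
    return field_kwargs


print(parse_metadata(Annotated[int, "number of retries"]))
# {'description': 'number of retries'}
print(parse_metadata(Annotated[int, 3, "number of retries"]))
# {'default': 3, 'description': 'number of retries'}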
@@ -317,6 +386,25 @@ def _convert_any_typed_dicts_to_pydantic(
     return type_
 
 
+def _py_38_safe_origin(origin: type) -> type:
+    origin_union_type_map: dict[type, Any] = (
+        {types.UnionType: Union} if hasattr(types, "UnionType") else {}
+    )
+
+    origin_map: dict[type, Any] = {
+        dict: dict,
+        list: list,
+        tuple: tuple,
+        set: set,
+        collections.abc.Iterable: typing.Iterable,
+        collections.abc.Mapping: typing.Mapping,
+        collections.abc.Sequence: typing.Sequence,
+        collections.abc.MutableMapping: typing.MutableMapping,
+        **origin_union_type_map,
+    }
+    return cast("type", origin_map.get(origin, origin))
+
+
 def _format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
     """Format tool into the OpenAI function API.
 
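For illustration (not part of the committed diff), why the mapping exists: on modern Pythons `typing.get_origin` returns `collections.abc` classes, which the pydantic v1 code path cannot re-parameterize, so `_py_38_safe_origin` swaps in the `typing` aliases.

import collections.abc
import typing

# get_origin on a parameterized typing alias yields the abc class.
origin = typing.get_origin(typing.Sequence[int])
print(origin)                                # <class 'collections.abc.Sequence'>
print(origin is collections.abc.Sequence)    # True
# After the mapping above, the origin becomes typing.Sequence, which can
# be re-parameterized as typing.Sequence[int] on the v1 code path.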
@@ -386,6 +474,30 @@ def format_tool_to_openai_tool(tool: BaseTool) -> ToolDescription:
     return {"type": "function", "function": function}
 
 
+def _recursive_set_additional_properties_false(
+    schema: dict[str, Any],
+) -> dict[str, Any]:
+    if isinstance(schema, dict):
+        # Check if 'required' is a key at the current level or if the schema is empty,
+        # in which case additionalProperties still needs to be specified.
+        if "required" in schema or (
+            "properties" in schema and not schema["properties"]
+        ):
+            schema["additionalProperties"] = False
+
+        # Recursively check 'properties' and 'items' if they exist
+        if "anyOf" in schema:
+            for sub_schema in schema["anyOf"]:
+                _recursive_set_additional_properties_false(sub_schema)
+        if "properties" in schema:
+            for sub_schema in schema["properties"].values():
+                _recursive_set_additional_properties_false(sub_schema)
+        if "items" in schema:
+            _recursive_set_additional_properties_false(schema["items"])
+
+    return schema
+
+
 def convert_to_openai_function(
     function: Union[dict[str, Any], type, Callable, BaseTool],
     *,
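For illustration (not part of the committed diff): the helper mutates the schema in place and also returns it; OpenAI's strict structured-output mode requires `additionalProperties: false` on every object level. The private import below is an assumption and may change between langchain-core releases.

import json

from langchain_core.utils.function_calling import (
    _recursive_set_additional_properties_false,  # private helper
)

schema = {
    "type": "object",
    "properties": {
        "inner": {
            "type": "object",
            "properties": {"x": {"type": "integer"}},
            "required": ["x"],
        }
    },
    "required": ["inner"],
}
result = _recursive_set_additional_properties_false(schema)
print(json.dumps(result, indent=2))
# Both the outer object and "inner" now carry "additionalProperties": false.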
@@ -717,105 +829,3 @@ def tool_example_to_messages(
     if ai_response:
         messages.append(AIMessage(content=ai_response))
     return messages
-
-
-def _parse_google_docstring(
-    docstring: Optional[str],
-    args: list[str],
-    *,
-    error_on_invalid_docstring: bool = False,
-) -> tuple[str, dict]:
-    """Parse the function and argument descriptions from the docstring of a function.
-
-    Assumes the function docstring follows Google Python style guide.
-    """
-    if docstring:
-        docstring_blocks = docstring.split("\n\n")
-        if error_on_invalid_docstring:
-            filtered_annotations = {
-                arg for arg in args if arg not in {"run_manager", "callbacks", "return"}
-            }
-            if filtered_annotations and (
-                len(docstring_blocks) < 2
-                or not any(block.startswith("Args:") for block in docstring_blocks[1:])
-            ):
-                msg = "Found invalid Google-Style docstring."
-                raise ValueError(msg)
-        descriptors = []
-        args_block = None
-        past_descriptors = False
-        for block in docstring_blocks:
-            if block.startswith("Args:"):
-                args_block = block
-                break
-            if block.startswith(("Returns:", "Example:")):
-                # Don't break in case Args come after
-                past_descriptors = True
-            elif not past_descriptors:
-                descriptors.append(block)
-            else:
-                continue
-        description = " ".join(descriptors)
-    else:
-        if error_on_invalid_docstring:
-            msg = "Found invalid Google-Style docstring."
-            raise ValueError(msg)
-        description = ""
-        args_block = None
-    arg_descriptions = {}
-    if args_block:
-        arg = None
-        for line in args_block.split("\n")[1:]:
-            if ":" in line:
-                arg, desc = line.split(":", maxsplit=1)
-                arg = arg.strip()
-                arg_name, _, annotations_ = arg.partition(" ")
-                if annotations_.startswith("(") and annotations_.endswith(")"):
-                    arg = arg_name
-                arg_descriptions[arg] = desc.strip()
-            elif arg:
-                arg_descriptions[arg] += " " + line.strip()
-    return description, arg_descriptions
-
-
-def _py_38_safe_origin(origin: type) -> type:
-    origin_union_type_map: dict[type, Any] = (
-        {types.UnionType: Union} if hasattr(types, "UnionType") else {}
-    )
-
-    origin_map: dict[type, Any] = {
-        dict: dict,
-        list: list,
-        tuple: tuple,
-        set: set,
-        collections.abc.Iterable: typing.Iterable,
-        collections.abc.Mapping: typing.Mapping,
-        collections.abc.Sequence: typing.Sequence,
-        collections.abc.MutableMapping: typing.MutableMapping,
-        **origin_union_type_map,
-    }
-    return cast("type", origin_map.get(origin, origin))
-
-
-def _recursive_set_additional_properties_false(
-    schema: dict[str, Any],
-) -> dict[str, Any]:
-    if isinstance(schema, dict):
-        # Check if 'required' is a key at the current level or if the schema is empty,
-        # in which case additionalProperties still needs to be specified.
-        if "required" in schema or (
-            "properties" in schema and not schema["properties"]
-        ):
-            schema["additionalProperties"] = False
-
-        # Recursively check 'properties' and 'items' if they exist
-        if "anyOf" in schema:
-            for sub_schema in schema["anyOf"]:
-                _recursive_set_additional_properties_false(sub_schema)
-        if "properties" in schema:
-            for sub_schema in schema["properties"].values():
-                _recursive_set_additional_properties_false(sub_schema)
-        if "items" in schema:
-            _recursive_set_additional_properties_false(schema["items"])
-
-    return schema
@@ -35,6 +35,17 @@ from langchain_core.utils.function_calling import (
 )
 
 
+def remove_titles(obj: dict) -> None:
+    if isinstance(obj, dict):
+        obj.pop("title", None)
+        for v in obj.values():
+            remove_titles(v)
+    elif isinstance(obj, list):
+        for v in obj:
+            remove_titles(v)
+    return obj
+
+
 @pytest.fixture
 def pydantic() -> type[BaseModel]:
     class dummy_function(BaseModel):  # noqa: N801
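For illustration (not part of the committed diff): unlike the library's `_rm_titles`, this test helper strips every `title` key, including property names that happen to be called `title`; that is presumably acceptable here because the test fixtures do not use `title` as a field name. A sketch of its effect, assuming the `remove_titles` helper defined above:

schema = {"title": "T", "properties": {"a": {"title": "A", "type": "string"}}}
remove_titles(schema)  # mutates in place
print(schema)  # {'properties': {'a': {'type': 'string'}}}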
@@ -365,9 +376,9 @@ def test_convert_to_openai_function(
         dummy_extensions_typed_dict_docstring,
     ):
         actual = convert_to_openai_function(fn)
+        remove_titles(actual)
         assert actual == expected
-
    # Test runnables
     actual = convert_to_openai_function(runnable.as_tool(description="Dummy function."))
     parameters = {
         "type": "object",
@@ -384,7 +395,6 @@ def test_convert_to_openai_function(
     runnable_expected["parameters"] = parameters
     assert actual == runnable_expected
-
     # Test simple Tool
     def my_function(_: str) -> str:
         return ""
 
@@ -398,11 +408,12 @@ def test_convert_to_openai_function(
         "name": "dummy_function",
         "description": "test description",
         "parameters": {
-            "properties": {"__arg1": {"title": "__arg1", "type": "string"}},
+            "properties": {"__arg1": {"type": "string"}},
             "required": ["__arg1"],
             "type": "object",
         },
     }
+    remove_titles(actual)
     assert actual == expected
 
 
@@ -454,6 +465,7 @@ def test_convert_to_openai_function_nested() -> None:
     }
 
     actual = convert_to_openai_function(my_function)
+    remove_titles(actual)
     assert actual == expected
 
 
@@ -494,6 +506,7 @@ def test_convert_to_openai_function_nested_strict() -> None:
     }
 
     actual = convert_to_openai_function(my_function, strict=True)
+    remove_titles(actual)
     assert actual == expected
 
 
@@ -518,23 +531,20 @@ def test_convert_to_openai_function_strict_union_of_objects_arg_type() -> None:
             "my_arg": {
                 "anyOf": [
                     {
-                        "properties": {"foo": {"title": "Foo", "type": "string"}},
+                        "properties": {"foo": {"type": "string"}},
                         "required": ["foo"],
-                        "title": "NestedA",
                         "type": "object",
                         "additionalProperties": False,
                     },
                     {
-                        "properties": {"bar": {"title": "Bar", "type": "integer"}},
+                        "properties": {"bar": {"type": "integer"}},
                         "required": ["bar"],
-                        "title": "NestedB",
                         "type": "object",
                         "additionalProperties": False,
                     },
                     {
-                        "properties": {"baz": {"title": "Baz", "type": "boolean"}},
+                        "properties": {"baz": {"type": "boolean"}},
                         "required": ["baz"],
-                        "title": "NestedC",
                         "type": "object",
                         "additionalProperties": False,
                     },
@@ -549,6 +559,7 @@ def test_convert_to_openai_function_strict_union_of_objects_arg_type() -> None:
     }
 
     actual = convert_to_openai_function(my_function, strict=True)
+    remove_titles(actual)
     assert actual == expected
 
 
@@ -556,7 +567,6 @@ json_schema_no_description_no_params = {
     "title": "dummy_function",
 }
 
-
 json_schema_no_description = {
     "title": "dummy_function",
     "type": "object",
@@ -571,7 +581,6 @@ json_schema_no_description = {
     "required": ["arg1", "arg2"],
 }
 
-
 anthropic_tool_no_description = {
     "name": "dummy_function",
     "input_schema": {
@@ -588,7 +597,6 @@ anthropic_tool_no_description = {
     },
 }
 
-
 bedrock_converse_tool_no_description = {
     "toolSpec": {
         "name": "dummy_function",
@@ -609,7 +617,6 @@ bedrock_converse_tool_no_description = {
     }
 }
 
-
 openai_function_no_description = {
     "name": "dummy_function",
     "parameters": {
@@ -626,7 +633,6 @@ openai_function_no_description = {
     },
 }
 
-
 openai_function_no_description_no_params = {
     "name": "dummy_function",
 }
@@ -658,6 +664,7 @@ def test_convert_to_openai_function_no_description(func: dict) -> None:
         },
     }
     actual = convert_to_openai_function(func)
+    remove_titles(actual)
     assert actual == expected
 
 
@@ -772,7 +779,6 @@ def test_tool_outputs() -> None:
     ]
     assert messages[2].content == "Output1"
-
     # Test final AI response
     messages = tool_example_to_messages(
         input="This is an example",
         tool_calls=[
@@ -880,12 +886,10 @@ def test__convert_typed_dict_to_openai_function(
             "items": [
                 {"type": "array", "items": {}},
                 {
-                    "title": "SubTool",
                     "description": "Subtool docstring.",
                     "type": "object",
                     "properties": {
                         "args": {
-                            "title": "Args",
                             "description": "this does bar",
                             "default": {},
                             "type": "object",
@@ -916,12 +920,10 @@ def test__convert_typed_dict_to_openai_function(
             "maxItems": 1,
             "items": [
                 {
-                    "title": "SubTool",
                     "description": "Subtool docstring.",
                     "type": "object",
                     "properties": {
                         "args": {
-                            "title": "Args",
                             "description": "this does bar",
                             "default": {},
                             "type": "object",
@@ -1034,6 +1036,7 @@ def test__convert_typed_dict_to_openai_function(
         },
     }
     actual = _convert_typed_dict_to_openai_function(Tool)
+    remove_titles(actual)
     assert actual == expected
 
 
@@ -1042,7 +1045,6 @@ def test__convert_typed_dict_to_openai_function_fail(typed_dict: type) -> None:
     class Tool(typed_dict):  # type: ignore[misc]
         arg1: typing.MutableSet  # Pydantic 2 supports this, but pydantic v1 does not.
 
-
     # Error should be raised since we're using v1 code path here
     with pytest.raises(TypeError):
         _convert_typed_dict_to_openai_function(Tool)
141 reproduce_pydanticv2_test.py Normal file
@@ -0,0 +1,141 @@
import re
import os
import json
from typing import Literal, Optional, Tuple, Union, Annotated

from pydantic import BaseModel, Field, PositiveInt, ValidationInfo, field_validator, ConfigDict

from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, AIMessage
from langchain_openai import ChatOpenAI

# Ensure you have your OPENAI_API_KEY set as an environment variable
if not os.getenv("OPENAI_API_KEY"):
    raise ValueError("OPENAI_API_KEY environment variable not set.")

# Dummy placeholder since this isn't a real LangGraph state injection
def InjectedState(d: dict):
    return {}

# --- Pydantic Models from the GitHub Issue ---

time_fmt = "%Y-%m-%d %H:%M:%S"
time_pattern = r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}$"

# Forward-declare nested models for Pydantic
class DataSoilDashboardQueryPayloadQueryParam:
    pass

class DataSoilDashboardQueryPayloadTimeShift(BaseModel):
    shiftInterval: list[PositiveInt] = Field(description="Each element in the array represents a time offset relative to the query timestamp for individual time comparison analysis. If time comparison analysis is not described, keep it **VOID**.", max_length=2, default=[])
    timeUnit: Literal["DAY"] = Field(default="DAY", description="The unit of specific comparison time offset. This is the description about each value of unit: Unit **DAY** represents one day.")

class DataSoilDashboardQueryPayloadQueryParamWhereFilter(BaseModel):
    field: str = Field(description="The dimension **CODE** in the selected dimension list that requires enums filtering or pattern filtering.")
    operator: Literal["IN", "NI", "LIKE", "NOT_LIKE"] = Field(description="Operators for enums filtering or pattern filtering.")
    value: list[str] = Field(description="If for enums filtering, every element represents the practical enums of the dimension. Otherwise for pattern filtering, only **one** element is required and it represents a wildcard pattern.", min_length=1)

    @field_validator("field")
    def field_block(cls, v: str, info: ValidationInfo) -> str:
        if v == "dt":
            raise ValueError("Instruction: The time filtering should be described in 'time' field, not in the 'filters' field.")
        return v

    @field_validator("value")
    def value_block(cls, v: Optional[list[str]], info: ValidationInfo) -> Optional[list[str]]:
        if info.data.get("operator") in {"LIKE", "NOT_LIKE"} and v and len(v) > 1:
            raise ValueError("Instruction: For pattern filtering, the size of 'value' in 'where' must be **ONE**.")
        return v

class DataSoilDashboardQueryPayloadQueryParamWhere(BaseModel):
    time: list[Union[str, int]] = Field(description=f"The target time range...", min_length=2, max_length=2)
    filters: list[DataSoilDashboardQueryPayloadQueryParamWhereFilter] = Field(description="Enums filtering or pattern filtering condition...")
    relation: Literal["AND"] = Field(description="Boolean relationships between filters...")

    @field_validator("time")
    def time_format_block(cls, v: list[Union[int, str]], info: ValidationInfo) -> list[Union[int, str]]:
        if isinstance(v[0], str) and not re.search(time_pattern, v[0]):
            raise ValueError(f"Instruction: the start time of time range must be formatted as **{time_fmt}**")
        if isinstance(v[1], str) and not re.search(time_pattern, v[1]):
            raise ValueError(f"Instruction: the end time of time range must be formatted as **{time_fmt}**")
        return v

class DataSoilDashboardQueryPayloadQueryParamOrderBy(BaseModel):
    field: str = Field(description="The metric **CODE** in the selected metric list that requires metric sorting.")
    direction: Literal["ASC", "DESC"] = Field(description="Sorting direction for specified metric.")
    shift: int = Field(default=0)
    limit: int = Field(description="The number of rows to return...", default=50)

class DataSoilDashboardQueryPayloadQueryParamGroupBy(BaseModel):
    field: str = Field(description="The dimension **CODE** in the selected dimension list for dimension grouping analysis.")
    extendFields: list[str] = Field(default=[])
    orderBy: Optional[DataSoilDashboardQueryPayloadQueryParamOrderBy] = Field(description="Sorting config for query results...", default=None)

class DataSoilDashboardQueryPayloadQueryParam(BaseModel):
    queryType: Literal["DETAIL_TABLE"] = Field(description="This is the description about queryType...")
    interval: Literal["BY_ONE_MINUTE", "BY_FIVE_MINUTE", "BY_HOUR", "BY_DAY", "BY_WEEK", "BY_MONTH", "SUM"] = Field(description="The time granularity for time-based grouping analysis.")
    resultField: list[str] = Field(default=[])
    where: DataSoilDashboardQueryPayloadQueryParamWhere = Field(description="Filtering condition for dimensions.")
    groupBy: list[DataSoilDashboardQueryPayloadQueryParamGroupBy] = Field(description="A list of dimensions grouping analysis info...")
    orderBy: DataSoilDashboardQueryPayloadQueryParamOrderBy = Field(description="Sorting config for query results...")
    heavyQuery: bool = Field(default=False)

    @field_validator("groupBy")
    def groupBy_block(cls, v: list[DataSoilDashboardQueryPayloadQueryParamGroupBy], info: ValidationInfo) -> list[DataSoilDashboardQueryPayloadQueryParamGroupBy]:
        if "dt" in {e.field for e in v}:
            if info.data.get("interval") == "SUM":
                raise ValueError("Instruction: the interval can not be **SUM** when **time-based grouping is required**.")
        else:
            if info.data.get("interval") != "SUM":
                raise ValueError("Instruction: the interval must be **SUM** when **time-based grouping is not required**.")
        return v

class DataSoilDashboardQueryPayload(BaseModel):
    model_config = ConfigDict(frozen=False)
    apiCode: str = Field(default="")
    requestId: str = Field(default="")
    applicationCode: str = Field(default="")
    applicationToken: str = Field(default="")
    debug: bool = Field(default=False)
    timeShift: DataSoilDashboardQueryPayloadTimeShift = Field(description="Time comparison analysis config.", default_factory=DataSoilDashboardQueryPayloadTimeShift)
    dynamicQueryParam: DataSoilDashboardQueryPayloadQueryParam
    forceFlush: bool = Field(default=False)

# Resolve forward references
DataSoilDashboardQueryPayload.model_rebuild()

@tool
def query_datasoil_data_tool(payload: DataSoilDashboardQueryPayload) -> str:
    """Queries the DataSoil database with a complex payload."""
    print("--- Tool successfully called with validated payload ---")
    # In a real scenario, you'd process the payload here.
    # For reproduction, we just need to see that it gets called correctly.
    return "Tool call successful."

# Use a model that supports tool calling, like gpt-4o
llm = ChatOpenAI(model="gpt-4o", temperature=0)

# Bind the tool to the LLM
llm_with_tools = llm.bind_tools([query_datasoil_data_tool])

# --- NEW: Inspect the schema LangChain generates BEFORE the LLM call ---
tool_schemas = llm_with_tools.kwargs.get("tools", [])
print("\n--- Generated Tool Schema (for LLM) ---")
print(json.dumps(tool_schemas, indent=2))
# --- End of new section ---

# Example invocation
prompt = "Get the detail table for sales data from 2025-07-01 00:00:00 to 2025-07-08 00:00:00, grouped by city, and ordered by total revenue descending."

print(f"\n--- Invoking LLM with prompt: '{prompt}' ---")

ai_msg = llm_with_tools.invoke(prompt)

print("\n--- LLM Response ---")
print(ai_msg)

if isinstance(ai_msg, AIMessage) and ai_msg.tool_calls:
    print("\n--- Generated Tool Call Arguments ---")
    # In a real case, you'd see the arguments the LLM generated.
    # The bug is that these args are often malformed due to an incorrect schema.
    print(ai_msg.tool_calls[0]['args'])
else:
    print("\n--- No tool call was generated ---")
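For illustration (not part of the committed script): the generated schema can also be inspected without an OPENAI_API_KEY by converting the tool directly, which is useful when reproducing the schema issue offline.

from langchain_core.utils.function_calling import convert_to_openai_function

# Convert the decorated tool straight to the OpenAI function format,
# skipping the ChatOpenAI round trip entirely.
offline_schema = convert_to_openai_function(query_datasoil_data_tool)
print(json.dumps(offline_schema, indent=2))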