mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-05 11:12:47 +00:00
test: add reproduction script for pydantic v2 nested schema bug
This commit is contained in:
parent
efdfa00d10
commit
0cbd5deaef
141
reproduce_pydanticv2_test.py
Normal file
141
reproduce_pydanticv2_test.py
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
import re
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
from typing import Literal, Optional, Tuple, Union, Annotated
|
||||||
|
from pydantic import BaseModel, Field, PositiveInt, ValidationInfo, field_validator, ConfigDict
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
# Ensure you have your OPENAI_API_KEY set as an environment variable
|
||||||
|
if not os.getenv("OPENAI_API_KEY"):
|
||||||
|
raise ValueError("OPENAI_API_KEY environment variable not set.")
|
||||||
|
|
||||||
|
# Dummy placeholder since this isn't a real LangGraph state injection
|
||||||
|
def InjectedState(d: dict):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
# --- Pydantic Models from the GitHub Issue ---
|
||||||
|
|
||||||
|
time_fmt = "%Y-%m-%d %H:%M:%S"
|
||||||
|
time_pattern = r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}$"
|
||||||
|
|
||||||
|
# Forward-declare nested models for Pydantic
|
||||||
|
class DataSoilDashboardQueryPayloadQueryParam:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class DataSoilDashboardQueryPayloadTimeShift(BaseModel):
|
||||||
|
shiftInterval: list[PositiveInt] = Field(description="Each element in the array represents a time offset relative to the query timestamp for individual time comparison analysis. If time comparison analysis dose not described, keep it **VOID**.",max_length=2,default=[])
|
||||||
|
timeUnit: Literal["DAY"] = Field(default="DAY",description="The unit of specific comparison time offset. This is the description about each value of unit: Unit **DAY** represents one day.")
|
||||||
|
|
||||||
|
class DataSoilDashboardQueryPayloadQueryParamWhereFilter(BaseModel):
|
||||||
|
field: str = Field(description="The dimension **CODE** in the selected dimension list that requires enums filtering or pattern filtering.")
|
||||||
|
operator: Literal["IN", "NI", "LIKE", "NOT_LIKE"] = Field(description="Operators for enums filtering or pattern filtering.")
|
||||||
|
value: list[str] = Field(description="If for enums filtering, every element represents th practical enums of the dimension. Otherwise for pattern filtering, only **one** element is required and it represents a wildcard pattern.",min_length=1)
|
||||||
|
|
||||||
|
@field_validator("field")
|
||||||
|
def field_block(cls, v: str, info: ValidationInfo) -> str:
|
||||||
|
if v == "dt":
|
||||||
|
raise ValueError("Instruction: The time filtering should be described in 'time' field, not in the 'filters' field.")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("value")
|
||||||
|
def value_block(cls, v: Optional[list[str]], info: ValidationInfo) -> Optional[list[str]]:
|
||||||
|
if info.data.get("operator") in {"LIKE", "NOT_LIKE"} and v and len(v) > 1:
|
||||||
|
raise ValueError("Instruction: For pattern filtering, the size of 'value' in 'where' must be **ONE**.")
|
||||||
|
return v
|
||||||
|
|
||||||
|
class DataSoilDashboardQueryPayloadQueryParamWhere(BaseModel):
|
||||||
|
time: list[Union[str, int]] = Field(description=f"The target time range...", min_length=2, max_length=2)
|
||||||
|
filters: list[DataSoilDashboardQueryPayloadQueryParamWhereFilter] = Field(description="Enums filtering or pattern filtering condition...")
|
||||||
|
relation: Literal["AND"] = Field(description="Boolean relationships between filters...")
|
||||||
|
|
||||||
|
@field_validator("time")
|
||||||
|
def time_format_block(cls, v: list[Union[int, str]], info: ValidationInfo) -> list[Union[int, str]]:
|
||||||
|
if isinstance(v[0], str) and not re.search(time_pattern, v[0]):
|
||||||
|
raise ValueError(f"Instruction: the start time of time range must be formatted as **{time_fmt}**")
|
||||||
|
if isinstance(v[1], str) and not re.search(time_pattern, v[1]):
|
||||||
|
raise ValueError(f"Instruction: the end time of time range must be formatted as **{time_fmt}**")
|
||||||
|
return v
|
||||||
|
|
||||||
|
class DataSoilDashboardQueryPayloadQueryParamOrderBy(BaseModel):
|
||||||
|
field: str = Field(description="The metric **CODE** in the selected metric list that requires metric sorting.")
|
||||||
|
direction: Literal["ASC", "DESC"] = Field(description="Sorting direction for specified metric.")
|
||||||
|
shift: int = Field(default=0)
|
||||||
|
limit: int = Field(description="The number of rows to return...", default=50)
|
||||||
|
|
||||||
|
class DataSoilDashboardQueryPayloadQueryParamGroupBy(BaseModel):
|
||||||
|
field: str = Field(description="The dimension **CODE** in the selected dimension list for dimension grouping analysis.")
|
||||||
|
extendFields: list[str] = Field(default=[])
|
||||||
|
orderBy: Optional[DataSoilDashboardQueryPayloadQueryParamOrderBy] = Field(description="Sorting config for query results...", default=None)
|
||||||
|
|
||||||
|
class DataSoilDashboardQueryPayloadQueryParam(BaseModel):
|
||||||
|
queryType: Literal["DETAIL_TABLE"] = Field(description="This is the description about queryType...")
|
||||||
|
interval: Literal["BY_ONE_MINUTE", "BY_FIVE_MINUTE", "BY_HOUR", "BY_DAY", "BY_WEEK", "BY_MONTH", "SUM"] = Field(description="The time granularity for time-based grouping analysis.")
|
||||||
|
resultField: list[str] = Field(default=[])
|
||||||
|
where: DataSoilDashboardQueryPayloadQueryParamWhere = Field(description="Filtering condition for dimensions.")
|
||||||
|
groupBy: list[DataSoilDashboardQueryPayloadQueryParamGroupBy] = Field(description="A list of dimensions grouping analysis info...")
|
||||||
|
orderBy: DataSoilDashboardQueryPayloadQueryParamOrderBy = Field(description="Sorting config for query results...")
|
||||||
|
heavyQuery: bool = Field(default=False)
|
||||||
|
|
||||||
|
@field_validator("groupBy")
|
||||||
|
def groupBy_block(cls, v: list[DataSoilDashboardQueryPayloadQueryParamGroupBy], info: ValidationInfo) -> list[DataSoilDashboardQueryPayloadQueryParamGroupBy]:
|
||||||
|
if "dt" in {e.field for e in v}:
|
||||||
|
if info.data.get("interval") == "SUM":
|
||||||
|
raise ValueError("Instruction: the interval can not be **SUM** when **time-based grouping is required**.")
|
||||||
|
else:
|
||||||
|
if info.data.get("interval") != "SUM":
|
||||||
|
raise ValueError("Instruction: the interval must be **SUM** when **time-based grouping is not required**.")
|
||||||
|
return v
|
||||||
|
|
||||||
|
class DataSoilDashboardQueryPayload(BaseModel):
|
||||||
|
model_config = ConfigDict(frozen=False)
|
||||||
|
apiCode: str = Field(default="")
|
||||||
|
requestId: str = Field(default="")
|
||||||
|
applicationCode: str = Field(default="")
|
||||||
|
applicationToken: str = Field(default="")
|
||||||
|
debug: bool = Field(default=False)
|
||||||
|
timeShift: DataSoilDashboardQueryPayloadTimeShift = Field(description="Time comparison analysis config.", default_factory=DataSoilDashboardQueryPayloadTimeShift)
|
||||||
|
dynamicQueryParam: DataSoilDashboardQueryPayloadQueryParam
|
||||||
|
forceFlush: bool = Field(default=False)
|
||||||
|
|
||||||
|
# Resolve forward references
|
||||||
|
DataSoilDashboardQueryPayload.model_rebuild()
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def query_datasoil_data_tool(payload: DataSoilDashboardQueryPayload) -> str:
|
||||||
|
"""Queries the DataSoil database with a complex payload."""
|
||||||
|
print("--- Tool successfully called with validated payload ---")
|
||||||
|
# In a real scenario, you'd process the payload here.
|
||||||
|
# For reproduction, we just need to see that it gets called correctly.
|
||||||
|
return "Tool call successful."
|
||||||
|
|
||||||
|
# Use a model that supports tool calling, like gpt-4o
|
||||||
|
llm = ChatOpenAI(model="gpt-4o", temperature=0)
|
||||||
|
|
||||||
|
# Bind the tool to the LLM
|
||||||
|
llm_with_tools = llm.bind_tools([query_datasoil_data_tool])
|
||||||
|
|
||||||
|
# --- NEW: Inspect the schema LangChain generates BEFORE the LLM call ---
|
||||||
|
tool_schemas = llm_with_tools.kwargs.get("tools", [])
|
||||||
|
print("\n--- Generated Tool Schema (for LLM) ---")
|
||||||
|
print(json.dumps(tool_schemas, indent=2))
|
||||||
|
# --- End of new section ---
|
||||||
|
|
||||||
|
# Example invocation
|
||||||
|
prompt = "Get the detail table for sales data from 2025-07-01 00:00:00 to 2025-07-08 00:00:00, grouped by city, and ordered by total revenue descending."
|
||||||
|
|
||||||
|
print(f"\n--- Invoking LLM with prompt: '{prompt}' ---")
|
||||||
|
|
||||||
|
ai_msg = llm_with_tools.invoke(prompt)
|
||||||
|
|
||||||
|
print("\n--- LLM Response ---")
|
||||||
|
print(ai_msg)
|
||||||
|
|
||||||
|
if isinstance(ai_msg, AIMessage) and ai_msg.tool_calls:
|
||||||
|
print("\n--- Generated Tool Call Arguments ---")
|
||||||
|
# In a real case, you'd see the arguments the LLM generated.
|
||||||
|
# The bug is that these args are often malformed due to an incorrect schema.
|
||||||
|
print(ai_msg.tool_calls[0]['args'])
|
||||||
|
else:
|
||||||
|
print("\n--- No tool call was generated ---")
|
Loading…
Reference in New Issue
Block a user