mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
langchain-robocorp: Fix parsing of Union types (such as Optional). (#22277)
This commit is contained in:
parent
af1f723ada
commit
fc5909ad6f
@ -1,5 +1,6 @@
|
|||||||
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
from typing import Any, Dict, List, Set, Tuple, Union
|
||||||
|
|
||||||
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
from langchain_core.pydantic_v1 import BaseModel, Field, create_model
|
||||||
from langchain_core.utils.json_schema import dereference_refs
|
from langchain_core.utils.json_schema import dereference_refs
|
||||||
@ -93,26 +94,45 @@ def get_schema(endpoint_spec: dict) -> dict:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_field(schema: dict, required: bool) -> Tuple[Any, Any]:
|
def create_field(
|
||||||
|
schema: dict, required: bool, created_model_names: Set[str]
|
||||||
|
) -> Tuple[Any, Any]:
|
||||||
"""
|
"""
|
||||||
Creates a Pydantic field based on the schema definition.
|
Creates a Pydantic field based on the schema definition.
|
||||||
"""
|
"""
|
||||||
|
if "anyOf" in schema:
|
||||||
|
field_types = [
|
||||||
|
create_field(sub_schema, required, created_model_names)[0]
|
||||||
|
for sub_schema in schema["anyOf"]
|
||||||
|
]
|
||||||
|
if len(field_types) == 1:
|
||||||
|
field_type = field_types[0] # Simplified handling
|
||||||
|
else:
|
||||||
|
field_type = Union[tuple(field_types)]
|
||||||
|
else:
|
||||||
field_type = type_mapping.get(schema.get("type", "string"), str)
|
field_type = type_mapping.get(schema.get("type", "string"), str)
|
||||||
|
|
||||||
description = schema.get("description", "")
|
description = schema.get("description", "")
|
||||||
|
|
||||||
# Handle nested objects
|
# Handle nested objects
|
||||||
if schema["type"] == "object":
|
if schema.get("type") == "object":
|
||||||
nested_fields = {
|
nested_fields = {
|
||||||
k: create_field(v, k in schema.get("required", []))
|
k: create_field(v, k in schema.get("required", []), created_model_names)
|
||||||
for k, v in schema.get("properties", {}).items()
|
for k, v in schema.get("properties", {}).items()
|
||||||
}
|
}
|
||||||
model_name = schema.get("title", "NestedModel")
|
model_name = schema.get("title", f"NestedModel{time.time()}")
|
||||||
|
if model_name in created_model_names:
|
||||||
|
# needs to be unique
|
||||||
|
model_name = model_name + str(time.time())
|
||||||
nested_model = create_model(model_name, **nested_fields) # type: ignore
|
nested_model = create_model(model_name, **nested_fields) # type: ignore
|
||||||
|
created_model_names.add(model_name)
|
||||||
return nested_model, Field(... if required else None, description=description)
|
return nested_model, Field(... if required else None, description=description)
|
||||||
|
|
||||||
# Handle arrays
|
# Handle arrays
|
||||||
elif schema["type"] == "array":
|
elif schema.get("type") == "array":
|
||||||
item_type, _ = create_field(schema["items"], required=True)
|
item_type, _ = create_field(
|
||||||
|
schema["items"], required=True, created_model_names=created_model_names
|
||||||
|
)
|
||||||
return List[item_type], Field( # type: ignore
|
return List[item_type], Field( # type: ignore
|
||||||
... if required else None, description=description
|
... if required else None, description=description
|
||||||
)
|
)
|
||||||
@ -128,9 +148,10 @@ def get_param_fields(endpoint_spec: dict) -> dict:
|
|||||||
required_fields = schema.get("required", [])
|
required_fields = schema.get("required", [])
|
||||||
|
|
||||||
fields = {}
|
fields = {}
|
||||||
|
created_model_names: Set[str] = set()
|
||||||
for key, value in properties.items():
|
for key, value in properties.items():
|
||||||
is_required = key in required_fields
|
is_required = key in required_fields
|
||||||
field_info = create_field(value, is_required)
|
field_info = create_field(value, is_required, created_model_names)
|
||||||
fields[key] = field_info
|
fields[key] = field_info
|
||||||
|
|
||||||
return fields
|
return fields
|
||||||
|
1891
libs/partners/robocorp/tests/unit_tests/_openapi3.fixture.json
Normal file
1891
libs/partners/robocorp/tests/unit_tests/_openapi3.fixture.json
Normal file
File diff suppressed because it is too large
Load Diff
@ -3,7 +3,10 @@ import json
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from langchain_core.utils.function_calling import convert_to_openai_function
|
from langchain_core.utils.function_calling import (
|
||||||
|
convert_to_openai_function,
|
||||||
|
convert_to_openai_tool,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain_robocorp.toolkits import ActionServerToolkit
|
from langchain_robocorp.toolkits import ActionServerToolkit
|
||||||
|
|
||||||
@ -118,3 +121,66 @@ Strictly adhere to the schema."""
|
|||||||
],
|
],
|
||||||
}
|
}
|
||||||
assert params["properties"]["rows_to_add"] == expected
|
assert params["properties"]["rows_to_add"] == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_tools_with_complex_inputs() -> None:
|
||||||
|
toolkit_instance = ActionServerToolkit(
|
||||||
|
url="http://example.com", api_key="dummy_key"
|
||||||
|
)
|
||||||
|
|
||||||
|
fixture_path = Path(__file__).with_name("_openapi3.fixture.json")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"langchain_robocorp.toolkits.requests.get"
|
||||||
|
) as mocked_get, fixture_path.open("r") as f:
|
||||||
|
data = json.load(f) # Using json.load directly on the file object
|
||||||
|
mocked_response = MagicMock()
|
||||||
|
mocked_response.json.return_value = data
|
||||||
|
mocked_response.status_code = 200
|
||||||
|
mocked_response.headers = {"Content-Type": "application/json"}
|
||||||
|
mocked_get.return_value = mocked_response
|
||||||
|
|
||||||
|
# Execute
|
||||||
|
tools = toolkit_instance.get_tools()
|
||||||
|
assert len(tools) == 4
|
||||||
|
|
||||||
|
tool = tools[0]
|
||||||
|
assert tool.name == "create_event"
|
||||||
|
assert tool.description == "Creates a new event in the specified calendar."
|
||||||
|
|
||||||
|
all_tools_as_openai_tools = [convert_to_openai_tool(t) for t in tools]
|
||||||
|
openai_tool_spec = all_tools_as_openai_tools[0]["function"]
|
||||||
|
|
||||||
|
assert isinstance(
|
||||||
|
openai_tool_spec, dict
|
||||||
|
), "openai_func_spec should be a dictionary."
|
||||||
|
assert set(openai_tool_spec.keys()) == {
|
||||||
|
"description",
|
||||||
|
"name",
|
||||||
|
"parameters",
|
||||||
|
}, "Top-level keys mismatch."
|
||||||
|
|
||||||
|
assert openai_tool_spec["description"] == tool.description
|
||||||
|
assert openai_tool_spec["name"] == tool.name
|
||||||
|
|
||||||
|
assert isinstance(
|
||||||
|
openai_tool_spec["parameters"], dict
|
||||||
|
), "Parameters should be a dictionary."
|
||||||
|
|
||||||
|
params = openai_tool_spec["parameters"]
|
||||||
|
assert set(params.keys()) == {
|
||||||
|
"type",
|
||||||
|
"properties",
|
||||||
|
"required",
|
||||||
|
}, "Parameters keys mismatch."
|
||||||
|
assert params["type"] == "object", "`type` in parameters should be 'object'."
|
||||||
|
assert isinstance(
|
||||||
|
params["properties"], dict
|
||||||
|
), "`properties` should be a dictionary."
|
||||||
|
assert isinstance(params["required"], list), "`required` should be a list."
|
||||||
|
|
||||||
|
assert set(params["required"]) == {
|
||||||
|
"event",
|
||||||
|
}, "Required fields mismatch."
|
||||||
|
|
||||||
|
assert set(params["properties"].keys()) == {"calendar_id", "event"}
|
||||||
|
Loading…
Reference in New Issue
Block a user