diff --git a/libs/langchain/langchain/agents/agent_toolkits/openapi/spec.py b/libs/langchain/langchain/agents/agent_toolkits/openapi/spec.py index fa26b3c5d0e..35b104c4a39 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/openapi/spec.py +++ b/libs/langchain/langchain/agents/agent_toolkits/openapi/spec.py @@ -1,56 +1,9 @@ """Quick and dirty representation for OpenAPI specs.""" from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, Union +from typing import List, Tuple - -def dereference_refs(spec_obj: dict, full_spec: dict) -> Union[dict, list]: - """Try to substitute $refs. - - The goal is to get the complete docs for each endpoint in context for now. - - In the few OpenAPI specs I studied, $refs referenced models - (or in OpenAPI terms, components) and could be nested. This code most - likely misses lots of cases. - """ - - def _retrieve_ref_path(path: str, full_spec: dict) -> dict: - components = path.split("/") - if components[0] != "#": - raise RuntimeError( - "All $refs I've seen so far are uri fragments (start with hash)." - ) - out = full_spec - for component in components[1:]: - out = out[component] - return out - - def _dereference_refs( - obj: Union[dict, list], stop: bool = False - ) -> Union[dict, list]: - if stop: - return obj - obj_out: Dict[str, Any] = {} - if isinstance(obj, dict): - for k, v in obj.items(): - if k == "$ref": - # stop=True => don't dereference recursively. - return _dereference_refs( - _retrieve_ref_path(v, full_spec), stop=True - ) - elif isinstance(v, list): - obj_out[k] = [_dereference_refs(el) for el in v] - elif isinstance(v, dict): - obj_out[k] = _dereference_refs(v) - else: - obj_out[k] = v - return obj_out - elif isinstance(obj, list): - return [_dereference_refs(el) for el in obj] - else: - return obj - - return _dereference_refs(spec_obj) +from langchain.utils.json_schema import dereference_refs @dataclass(frozen=True) @@ -90,7 +43,7 @@ def reduce_openapi_spec(spec: dict, dereference: bool = True) -> ReducedOpenAPIS # Note: probably want to do this post-retrieval, it blows up the size of the spec. if dereference: endpoints = [ - (name, description, dereference_refs(docs, spec)) + (name, description, dereference_refs(docs, full_schema=spec)) for name, description, docs in endpoints ] diff --git a/libs/langchain/langchain/chains/openai_functions/base.py b/libs/langchain/langchain/chains/openai_functions/base.py index e023c67b7f3..84089e0cccc 100644 --- a/libs/langchain/langchain/chains/openai_functions/base.py +++ b/libs/langchain/langchain/chains/openai_functions/base.py @@ -10,6 +10,7 @@ from typing import ( Tuple, Type, Union, + cast, ) from langchain.base_language import BaseLanguageModel @@ -22,6 +23,7 @@ from langchain.output_parsers.openai_functions import ( from langchain.prompts import BasePromptTemplate from langchain.pydantic_v1 import BaseModel from langchain.schema import BaseLLMOutputParser +from langchain.utils.openai_functions import convert_pydantic_to_openai_function PYTHON_TO_JSON_TYPES = { "str": "string", @@ -148,14 +150,7 @@ def convert_to_openai_function( if isinstance(function, dict): return function elif isinstance(function, type) and issubclass(function, BaseModel): - # Mypy error: - # "type" has no attribute "schema" - schema = function.schema() # type: ignore[attr-defined] - return { - "name": schema["title"], - "description": schema["description"], - "parameters": schema, - } + return cast(Dict, convert_pydantic_to_openai_function(function)) elif callable(function): return convert_python_function_to_openai_function(function) diff --git a/libs/langchain/langchain/tools/convert_to_openai.py b/libs/langchain/langchain/tools/convert_to_openai.py index e575b024d4b..3385b0d831e 100644 --- a/libs/langchain/langchain/tools/convert_to_openai.py +++ b/libs/langchain/langchain/tools/convert_to_openai.py @@ -1,41 +1,21 @@ -from typing import TypedDict - -from langchain.tools import BaseTool, StructuredTool - - -class FunctionDescription(TypedDict): - """Representation of a callable function to the OpenAI API.""" - - name: str - """The name of the function.""" - description: str - """A description of the function.""" - parameters: dict - """The parameters of the function.""" +from langchain.tools import BaseTool +from langchain.utils.openai_functions import ( + FunctionDescription, + convert_pydantic_to_openai_function, +) def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription: """Format tool into the OpenAI function API.""" - if isinstance(tool, StructuredTool): - schema_ = tool.args_schema.schema() - # Bug with required missing for structured tools. - required = schema_.get( - "required", sorted(schema_["properties"]) # Backup is a BUG WORKAROUND + if tool.args_schema: + return convert_pydantic_to_openai_function( + tool.args_schema, name=tool.name, description=tool.description ) + else: return { "name": tool.name, "description": tool.description, "parameters": { - "type": "object", - "properties": schema_["properties"], - "required": required, - }, - } - else: - if tool.args_schema: - parameters = tool.args_schema.schema() - else: - parameters = { # This is a hack to get around the fact that some tools # do not expose an args_schema, and expect an argument # which is a string. @@ -46,10 +26,5 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription: }, "required": ["__arg1"], "type": "object", - } - - return { - "name": tool.name, - "description": tool.description, - "parameters": parameters, + }, } diff --git a/libs/langchain/langchain/utils/json_schema.py b/libs/langchain/langchain/utils/json_schema.py new file mode 100644 index 00000000000..c5feab8478d --- /dev/null +++ b/libs/langchain/langchain/utils/json_schema.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import Optional, TypeVar, Union, cast + + +def _retrieve_ref(path: str, schema: dict) -> dict: + components = path.split("/") + if components[0] != "#": + raise ValueError( + "ref paths are expected to be URI fragments, meaning they should start " + "with #." + ) + out = schema + for component in components[1:]: + out = out[component] + return out + + +JSON_LIKE = TypeVar("JSON_LIKE", bound=Union[dict, list]) + + +def _dereference_refs_helper(obj: JSON_LIKE, full_schema: dict) -> JSON_LIKE: + if isinstance(obj, dict): + obj_out = {} + for k, v in obj.items(): + if k == "$ref": + ref = _retrieve_ref(v, full_schema) + obj_out[k] = _dereference_refs_helper(ref, full_schema) + elif isinstance(v, (list, dict)): + obj_out[k] = _dereference_refs_helper(v, full_schema) # type: ignore + else: + obj_out[k] = v + return cast(JSON_LIKE, obj_out) + elif isinstance(obj, list): + return cast( + JSON_LIKE, [_dereference_refs_helper(el, full_schema) for el in obj] + ) + else: + return obj + + +def dereference_refs( + schema_obj: dict, *, full_schema: Optional[dict] = None +) -> Union[dict, list]: + """Try to substitute $refs in JSON Schema.""" + + full_schema = full_schema or schema_obj + return _dereference_refs_helper(schema_obj, full_schema)