From 240cc289e6b953ab8149aa6ab53a809a64c7989b Mon Sep 17 00:00:00 2001 From: Bagatur Date: Wed, 30 Aug 2023 13:37:39 -0700 Subject: [PATCH 1/3] wip --- .../agents/agent_toolkits/openapi/spec.py | 53 ++----------------- .../langchain/chains/openai_functions/base.py | 11 ++-- .../langchain/tools/convert_to_openai.py | 45 ++++------------ libs/langchain/langchain/utils/json_schema.py | 48 +++++++++++++++++ 4 files changed, 64 insertions(+), 93 deletions(-) create mode 100644 libs/langchain/langchain/utils/json_schema.py diff --git a/libs/langchain/langchain/agents/agent_toolkits/openapi/spec.py b/libs/langchain/langchain/agents/agent_toolkits/openapi/spec.py index fa26b3c5d0e..35b104c4a39 100644 --- a/libs/langchain/langchain/agents/agent_toolkits/openapi/spec.py +++ b/libs/langchain/langchain/agents/agent_toolkits/openapi/spec.py @@ -1,56 +1,9 @@ """Quick and dirty representation for OpenAPI specs.""" from dataclasses import dataclass -from typing import Any, Dict, List, Tuple, Union +from typing import List, Tuple - -def dereference_refs(spec_obj: dict, full_spec: dict) -> Union[dict, list]: - """Try to substitute $refs. - - The goal is to get the complete docs for each endpoint in context for now. - - In the few OpenAPI specs I studied, $refs referenced models - (or in OpenAPI terms, components) and could be nested. This code most - likely misses lots of cases. - """ - - def _retrieve_ref_path(path: str, full_spec: dict) -> dict: - components = path.split("/") - if components[0] != "#": - raise RuntimeError( - "All $refs I've seen so far are uri fragments (start with hash)." - ) - out = full_spec - for component in components[1:]: - out = out[component] - return out - - def _dereference_refs( - obj: Union[dict, list], stop: bool = False - ) -> Union[dict, list]: - if stop: - return obj - obj_out: Dict[str, Any] = {} - if isinstance(obj, dict): - for k, v in obj.items(): - if k == "$ref": - # stop=True => don't dereference recursively. - return _dereference_refs( - _retrieve_ref_path(v, full_spec), stop=True - ) - elif isinstance(v, list): - obj_out[k] = [_dereference_refs(el) for el in v] - elif isinstance(v, dict): - obj_out[k] = _dereference_refs(v) - else: - obj_out[k] = v - return obj_out - elif isinstance(obj, list): - return [_dereference_refs(el) for el in obj] - else: - return obj - - return _dereference_refs(spec_obj) +from langchain.utils.json_schema import dereference_refs @dataclass(frozen=True) @@ -90,7 +43,7 @@ def reduce_openapi_spec(spec: dict, dereference: bool = True) -> ReducedOpenAPIS # Note: probably want to do this post-retrieval, it blows up the size of the spec. if dereference: endpoints = [ - (name, description, dereference_refs(docs, spec)) + (name, description, dereference_refs(docs, full_schema=spec)) for name, description, docs in endpoints ] diff --git a/libs/langchain/langchain/chains/openai_functions/base.py b/libs/langchain/langchain/chains/openai_functions/base.py index e023c67b7f3..84089e0cccc 100644 --- a/libs/langchain/langchain/chains/openai_functions/base.py +++ b/libs/langchain/langchain/chains/openai_functions/base.py @@ -10,6 +10,7 @@ from typing import ( Tuple, Type, Union, + cast, ) from langchain.base_language import BaseLanguageModel @@ -22,6 +23,7 @@ from langchain.output_parsers.openai_functions import ( from langchain.prompts import BasePromptTemplate from langchain.pydantic_v1 import BaseModel from langchain.schema import BaseLLMOutputParser +from langchain.utils.openai_functions import convert_pydantic_to_openai_function PYTHON_TO_JSON_TYPES = { "str": "string", @@ -148,14 +150,7 @@ def convert_to_openai_function( if isinstance(function, dict): return function elif isinstance(function, type) and issubclass(function, BaseModel): - # Mypy error: - # "type" has no attribute "schema" - schema = function.schema() # type: ignore[attr-defined] - return { - "name": schema["title"], - "description": schema["description"], - "parameters": schema, - } + return cast(Dict, convert_pydantic_to_openai_function(function)) elif callable(function): return convert_python_function_to_openai_function(function) diff --git a/libs/langchain/langchain/tools/convert_to_openai.py b/libs/langchain/langchain/tools/convert_to_openai.py index e575b024d4b..3385b0d831e 100644 --- a/libs/langchain/langchain/tools/convert_to_openai.py +++ b/libs/langchain/langchain/tools/convert_to_openai.py @@ -1,41 +1,21 @@ -from typing import TypedDict - -from langchain.tools import BaseTool, StructuredTool - - -class FunctionDescription(TypedDict): - """Representation of a callable function to the OpenAI API.""" - - name: str - """The name of the function.""" - description: str - """A description of the function.""" - parameters: dict - """The parameters of the function.""" +from langchain.tools import BaseTool +from langchain.utils.openai_functions import ( + FunctionDescription, + convert_pydantic_to_openai_function, +) def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription: """Format tool into the OpenAI function API.""" - if isinstance(tool, StructuredTool): - schema_ = tool.args_schema.schema() - # Bug with required missing for structured tools. - required = schema_.get( - "required", sorted(schema_["properties"]) # Backup is a BUG WORKAROUND + if tool.args_schema: + return convert_pydantic_to_openai_function( + tool.args_schema, name=tool.name, description=tool.description ) + else: return { "name": tool.name, "description": tool.description, "parameters": { - "type": "object", - "properties": schema_["properties"], - "required": required, - }, - } - else: - if tool.args_schema: - parameters = tool.args_schema.schema() - else: - parameters = { # This is a hack to get around the fact that some tools # do not expose an args_schema, and expect an argument # which is a string. @@ -46,10 +26,5 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription: }, "required": ["__arg1"], "type": "object", - } - - return { - "name": tool.name, - "description": tool.description, - "parameters": parameters, + }, } diff --git a/libs/langchain/langchain/utils/json_schema.py b/libs/langchain/langchain/utils/json_schema.py new file mode 100644 index 00000000000..c5feab8478d --- /dev/null +++ b/libs/langchain/langchain/utils/json_schema.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import Optional, TypeVar, Union, cast + + +def _retrieve_ref(path: str, schema: dict) -> dict: + components = path.split("/") + if components[0] != "#": + raise ValueError( + "ref paths are expected to be URI fragments, meaning they should start " + "with #." + ) + out = schema + for component in components[1:]: + out = out[component] + return out + + +JSON_LIKE = TypeVar("JSON_LIKE", bound=Union[dict, list]) + + +def _dereference_refs_helper(obj: JSON_LIKE, full_schema: dict) -> JSON_LIKE: + if isinstance(obj, dict): + obj_out = {} + for k, v in obj.items(): + if k == "$ref": + ref = _retrieve_ref(v, full_schema) + obj_out[k] = _dereference_refs_helper(ref, full_schema) + elif isinstance(v, (list, dict)): + obj_out[k] = _dereference_refs_helper(v, full_schema) # type: ignore + else: + obj_out[k] = v + return cast(JSON_LIKE, obj_out) + elif isinstance(obj, list): + return cast( + JSON_LIKE, [_dereference_refs_helper(el, full_schema) for el in obj] + ) + else: + return obj + + +def dereference_refs( + schema_obj: dict, *, full_schema: Optional[dict] = None +) -> Union[dict, list]: + """Try to substitute $refs in JSON Schema.""" + + full_schema = full_schema or schema_obj + return _dereference_refs_helper(schema_obj, full_schema) From 1f5c579ef4de385c14306f8b0ec539e39a36e432 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Wed, 30 Aug 2023 13:37:50 -0700 Subject: [PATCH 2/3] add --- .../langchain/utils/openai_functions.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 libs/langchain/langchain/utils/openai_functions.py diff --git a/libs/langchain/langchain/utils/openai_functions.py b/libs/langchain/langchain/utils/openai_functions.py new file mode 100644 index 00000000000..48c49541dcf --- /dev/null +++ b/libs/langchain/langchain/utils/openai_functions.py @@ -0,0 +1,29 @@ +from typing import Dict, Optional, Type, TypedDict, cast + +from langchain.pydantic_v1 import BaseModel +from langchain.utils.json_schema import dereference_refs + + +class FunctionDescription(TypedDict): + """Representation of a callable function to the OpenAI API.""" + + name: str + """The name of the function.""" + description: str + """A description of the function.""" + parameters: dict + """The parameters of the function.""" + + +def convert_pydantic_to_openai_function( + model: Type[BaseModel], + *, + name: Optional[str] = None, + description: Optional[str] = None +) -> FunctionDescription: + schema = cast(Dict, dereference_refs(model.schema())) + return { + "name": name or schema["title"], + "description": description or schema["description"], + "parameters": schema, + } From e805f8e26373b24431401f02ce1a4654cb2d2078 Mon Sep 17 00:00:00 2001 From: Bagatur Date: Wed, 30 Aug 2023 15:23:02 -0700 Subject: [PATCH 3/3] add tests --- libs/langchain/langchain/utils/json_schema.py | 56 +++++-- .../langchain/utils/openai_functions.py | 5 +- .../unit_tests/utils/test_json_schema.py | 151 ++++++++++++++++++ .../unit_tests/utils/test_openai_functions.py | 79 +++++++++ 4 files changed, 273 insertions(+), 18 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/utils/test_json_schema.py create mode 100644 libs/langchain/tests/unit_tests/utils/test_openai_functions.py diff --git a/libs/langchain/langchain/utils/json_schema.py b/libs/langchain/langchain/utils/json_schema.py index c5feab8478d..9628f9e521b 100644 --- a/libs/langchain/langchain/utils/json_schema.py +++ b/libs/langchain/langchain/utils/json_schema.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Optional, TypeVar, Union, cast +from copy import deepcopy +from typing import Any, List, Optional, Sequence def _retrieve_ref(path: str, schema: dict) -> dict: @@ -13,36 +14,59 @@ def _retrieve_ref(path: str, schema: dict) -> dict: out = schema for component in components[1:]: out = out[component] - return out + return deepcopy(out) -JSON_LIKE = TypeVar("JSON_LIKE", bound=Union[dict, list]) - - -def _dereference_refs_helper(obj: JSON_LIKE, full_schema: dict) -> JSON_LIKE: +def _dereference_refs_helper( + obj: Any, full_schema: dict, skip_keys: Sequence[str] +) -> Any: if isinstance(obj, dict): obj_out = {} for k, v in obj.items(): - if k == "$ref": + if k in skip_keys: + obj_out[k] = v + elif k == "$ref": ref = _retrieve_ref(v, full_schema) - obj_out[k] = _dereference_refs_helper(ref, full_schema) + return _dereference_refs_helper(ref, full_schema, skip_keys) elif isinstance(v, (list, dict)): - obj_out[k] = _dereference_refs_helper(v, full_schema) # type: ignore + obj_out[k] = _dereference_refs_helper(v, full_schema, skip_keys) else: obj_out[k] = v - return cast(JSON_LIKE, obj_out) + return obj_out elif isinstance(obj, list): - return cast( - JSON_LIKE, [_dereference_refs_helper(el, full_schema) for el in obj] - ) + return [_dereference_refs_helper(el, full_schema, skip_keys) for el in obj] else: return obj +def _infer_skip_keys(obj: Any, full_schema: dict) -> List[str]: + keys = [] + if isinstance(obj, dict): + for k, v in obj.items(): + if k == "$ref": + ref = _retrieve_ref(v, full_schema) + keys.append(v.split("/")[1]) + keys += _infer_skip_keys(ref, full_schema) + elif isinstance(v, (list, dict)): + keys += _infer_skip_keys(v, full_schema) + elif isinstance(obj, list): + for el in obj: + keys += _infer_skip_keys(el, full_schema) + return keys + + def dereference_refs( - schema_obj: dict, *, full_schema: Optional[dict] = None -) -> Union[dict, list]: + schema_obj: dict, + *, + full_schema: Optional[dict] = None, + skip_keys: Optional[Sequence[str]] = None, +) -> dict: """Try to substitute $refs in JSON Schema.""" full_schema = full_schema or schema_obj - return _dereference_refs_helper(schema_obj, full_schema) + skip_keys = ( + skip_keys + if skip_keys is not None + else _infer_skip_keys(schema_obj, full_schema) + ) + return _dereference_refs_helper(schema_obj, full_schema, skip_keys) diff --git a/libs/langchain/langchain/utils/openai_functions.py b/libs/langchain/langchain/utils/openai_functions.py index 48c49541dcf..cfb1e76d595 100644 --- a/libs/langchain/langchain/utils/openai_functions.py +++ b/libs/langchain/langchain/utils/openai_functions.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Type, TypedDict, cast +from typing import Optional, Type, TypedDict from langchain.pydantic_v1 import BaseModel from langchain.utils.json_schema import dereference_refs @@ -21,7 +21,8 @@ def convert_pydantic_to_openai_function( name: Optional[str] = None, description: Optional[str] = None ) -> FunctionDescription: - schema = cast(Dict, dereference_refs(model.schema())) + schema = dereference_refs(model.schema()) + schema.pop("definitions", None) return { "name": name or schema["title"], "description": description or schema["description"], diff --git a/libs/langchain/tests/unit_tests/utils/test_json_schema.py b/libs/langchain/tests/unit_tests/utils/test_json_schema.py new file mode 100644 index 00000000000..233c4672729 --- /dev/null +++ b/libs/langchain/tests/unit_tests/utils/test_json_schema.py @@ -0,0 +1,151 @@ +import pytest + +from langchain.utils.json_schema import dereference_refs + + +def test_dereference_refs_no_refs() -> None: + schema = { + "type": "object", + "properties": { + "first_name": {"type": "string"}, + }, + } + actual = dereference_refs(schema) + assert actual == schema + + +def test_dereference_refs_one_ref() -> None: + schema = { + "type": "object", + "properties": { + "first_name": {"$ref": "#/$defs/name"}, + }, + "$defs": {"name": {"type": "string"}}, + } + expected = { + "type": "object", + "properties": { + "first_name": {"type": "string"}, + }, + "$defs": {"name": {"type": "string"}}, + } + actual = dereference_refs(schema) + assert actual == expected + + +def test_dereference_refs_multiple_refs() -> None: + schema = { + "type": "object", + "properties": { + "first_name": {"$ref": "#/$defs/name"}, + "other": {"$ref": "#/$defs/other"}, + }, + "$defs": { + "name": {"type": "string"}, + "other": {"type": "object", "properties": {"age": "int", "height": "int"}}, + }, + } + expected = { + "type": "object", + "properties": { + "first_name": {"type": "string"}, + "other": {"type": "object", "properties": {"age": "int", "height": "int"}}, + }, + "$defs": { + "name": {"type": "string"}, + "other": {"type": "object", "properties": {"age": "int", "height": "int"}}, + }, + } + actual = dereference_refs(schema) + assert actual == expected + + +def test_dereference_refs_nested_refs_skip() -> None: + schema = { + "type": "object", + "properties": { + "info": {"$ref": "#/$defs/info"}, + }, + "$defs": { + "name": {"type": "string"}, + "info": { + "type": "object", + "properties": {"age": "int", "name": {"$ref": "#/$defs/name"}}, + }, + }, + } + expected = { + "type": "object", + "properties": { + "info": { + "type": "object", + "properties": {"age": "int", "name": {"type": "string"}}, + }, + }, + "$defs": { + "name": {"type": "string"}, + "info": { + "type": "object", + "properties": {"age": "int", "name": {"$ref": "#/$defs/name"}}, + }, + }, + } + actual = dereference_refs(schema) + assert actual == expected + + +def test_dereference_refs_nested_refs_no_skip() -> None: + schema = { + "type": "object", + "properties": { + "info": {"$ref": "#/$defs/info"}, + }, + "$defs": { + "name": {"type": "string"}, + "info": { + "type": "object", + "properties": {"age": "int", "name": {"$ref": "#/$defs/name"}}, + }, + }, + } + expected = { + "type": "object", + "properties": { + "info": { + "type": "object", + "properties": {"age": "int", "name": {"type": "string"}}, + }, + }, + "$defs": { + "name": {"type": "string"}, + "info": { + "type": "object", + "properties": {"age": "int", "name": {"type": "string"}}, + }, + }, + } + actual = dereference_refs(schema, skip_keys=()) + assert actual == expected + + +def test_dereference_refs_missing_ref() -> None: + schema = { + "type": "object", + "properties": { + "first_name": {"$ref": "#/$defs/name"}, + }, + "$defs": {}, + } + with pytest.raises(KeyError): + dereference_refs(schema) + + +def test_dereference_refs_remote_ref() -> None: + schema = { + "type": "object", + "properties": { + "first_name": {"$ref": "https://somewhere/else/name"}, + }, + } + with pytest.raises(ValueError): + dereference_refs(schema) diff --git a/libs/langchain/tests/unit_tests/utils/test_openai_functions.py b/libs/langchain/tests/unit_tests/utils/test_openai_functions.py new file mode 100644 index 00000000000..b5a22d837b9 --- /dev/null +++ b/libs/langchain/tests/unit_tests/utils/test_openai_functions.py @@ -0,0 +1,79 @@ +from langchain.pydantic_v1 import BaseModel, Field +from langchain.utils.openai_functions import convert_pydantic_to_openai_function + + +def test_convert_pydantic_to_openai_function() -> None: + class Data(BaseModel): + """The data to return.""" + + key: str = Field(..., description="API key") + days: int = Field(default=0, description="Number of days to forecast") + + actual = convert_pydantic_to_openai_function(Data) + expected = { + "name": "Data", + "description": "The data to return.", + "parameters": { + "title": "Data", + "description": "The data to return.", + "type": "object", + "properties": { + "key": {"title": "Key", "description": "API key", "type": "string"}, + "days": { + "title": "Days", + "description": "Number of days to forecast", + "default": 0, + "type": "integer", + }, + }, + "required": ["key"], + }, + } + assert actual == expected + + +def test_convert_pydantic_to_openai_function_nested() -> None: + class Data(BaseModel): + """The data to return.""" + + key: str = Field(..., description="API key") + days: int = Field(default=0, description="Number of days to forecast") + + class Model(BaseModel): + """The model to return.""" + + data: Data + + actual = convert_pydantic_to_openai_function(Model) + expected = { + "name": "Model", + "description": "The model to return.", + "parameters": { + "title": "Model", + "description": "The model to return.", + "type": "object", + "properties": { + "data": { + "title": "Data", + "description": "The data to return.", + "type": "object", + "properties": { + "key": { + "title": "Key", + "description": "API key", + "type": "string", + }, + "days": { + "title": "Days", + "description": "Number of days to forecast", + "default": 0, + "type": "integer", + }, + }, + "required": ["key"], + } + }, + "required": ["data"], + }, + } + assert actual == expected