This commit is contained in:
Bagatur 2023-08-30 13:37:39 -07:00
parent 7fa82900cb
commit 240cc289e6
4 changed files with 64 additions and 93 deletions

View File

@ -1,56 +1,9 @@
"""Quick and dirty representation for OpenAPI specs."""
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union
from typing import List, Tuple
def dereference_refs(spec_obj: dict, full_spec: dict) -> Union[dict, list]:
    """Inline the objects that ``$ref`` entries in ``spec_obj`` point to.

    The intent is to give each endpoint its complete documentation in
    context. Only a single level of substitution is performed: the object a
    ``$ref`` points to is inserted exactly as found in ``full_spec``, so any
    ``$ref`` nested inside that object stays unresolved. This is a
    best-effort helper and likely misses cases found in real OpenAPI specs.

    Args:
        spec_obj: The fragment of the spec to process.
        full_spec: The complete spec that ``$ref`` paths are resolved against.

    Returns:
        A copy of ``spec_obj`` in which every mapping holding a ``$ref`` key
        has been replaced by the referenced object (other keys of that
        mapping are discarded).

    Raises:
        RuntimeError: If a ``$ref`` path does not start with ``#``.
    """

    def _lookup(pointer: str) -> dict:
        # "#/a/b" is resolved as full_spec["a"]["b"].
        root, *keys = pointer.split("/")
        if root != "#":
            raise RuntimeError(
                "All $refs I've seen so far are uri fragments (start with hash)."
            )
        node = full_spec
        for key in keys:
            node = node[key]
        return node

    def _expand(node: Any) -> Any:
        if isinstance(node, list):
            return [_expand(item) for item in node]
        if not isinstance(node, dict):
            return node
        expanded: Dict[str, Any] = {}
        # Keys are visited in order; hitting "$ref" replaces the whole mapping
        # with the referenced object, which is deliberately not expanded.
        for key, value in node.items():
            if key == "$ref":
                return _lookup(value)
            expanded[key] = _expand(value)
        return expanded

    return _expand(spec_obj)
from langchain.utils.json_schema import dereference_refs
@dataclass(frozen=True)
@ -90,7 +43,7 @@ def reduce_openapi_spec(spec: dict, dereference: bool = True) -> ReducedOpenAPIS
# Note: probably want to do this post-retrieval, it blows up the size of the spec.
if dereference:
endpoints = [
(name, description, dereference_refs(docs, spec))
(name, description, dereference_refs(docs, full_schema=spec))
for name, description, docs in endpoints
]

View File

@ -10,6 +10,7 @@ from typing import (
Tuple,
Type,
Union,
cast,
)
from langchain.base_language import BaseLanguageModel
@ -22,6 +23,7 @@ from langchain.output_parsers.openai_functions import (
from langchain.prompts import BasePromptTemplate
from langchain.pydantic_v1 import BaseModel
from langchain.schema import BaseLLMOutputParser
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
PYTHON_TO_JSON_TYPES = {
"str": "string",
@ -148,14 +150,7 @@ def convert_to_openai_function(
if isinstance(function, dict):
return function
elif isinstance(function, type) and issubclass(function, BaseModel):
# Mypy error:
# "type" has no attribute "schema"
schema = function.schema() # type: ignore[attr-defined]
return {
"name": schema["title"],
"description": schema["description"],
"parameters": schema,
}
return cast(Dict, convert_pydantic_to_openai_function(function))
elif callable(function):
return convert_python_function_to_openai_function(function)

View File

@ -1,41 +1,21 @@
from typing import TypedDict
from langchain.tools import BaseTool, StructuredTool
class FunctionDescription(TypedDict):
    """Typed mapping describing a callable function for the OpenAI API."""

    # The name of the function.
    name: str
    # A description of the function.
    description: str
    # The parameters of the function.
    parameters: dict
from langchain.tools import BaseTool
from langchain.utils.openai_functions import (
FunctionDescription,
convert_pydantic_to_openai_function,
)
def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
"""Format tool into the OpenAI function API."""
if isinstance(tool, StructuredTool):
schema_ = tool.args_schema.schema()
# Bug with required missing for structured tools.
required = schema_.get(
"required", sorted(schema_["properties"]) # Backup is a BUG WORKAROUND
if tool.args_schema:
return convert_pydantic_to_openai_function(
tool.args_schema, name=tool.name, description=tool.description
)
else:
return {
"name": tool.name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": schema_["properties"],
"required": required,
},
}
else:
if tool.args_schema:
parameters = tool.args_schema.schema()
else:
parameters = {
# This is a hack to get around the fact that some tools
# do not expose an args_schema, and expect an argument
# which is a string.
@ -46,10 +26,5 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
},
"required": ["__arg1"],
"type": "object",
}
return {
"name": tool.name,
"description": tool.description,
"parameters": parameters,
},
}

View File

@ -0,0 +1,48 @@
from __future__ import annotations
from typing import Optional, TypeVar, Union, cast
def _retrieve_ref(path: str, schema: dict) -> dict:
    """Return the part of ``schema`` addressed by the URI fragment ``path``.

    ``path`` looks like ``"#/definitions/Pet"``; each ``/``-separated
    component after the leading ``#`` is one lookup step.

    Raises:
        ValueError: If ``path`` does not start with ``#``.
        KeyError: If a component is missing from a mapping.
    """
    components = path.split("/")
    if components[0] != "#":
        raise ValueError(
            "ref paths are expected to be URI fragments, meaning they should start "
            "with #."
        )
    out = schema
    for component in components[1:]:
        if isinstance(out, list) and component.isdigit():
            # Arrays are addressed by position.
            out = out[int(component)]
        elif isinstance(out, dict) and component not in out:
            # Fall back to the JSON Pointer escapes (~1 -> "/", ~0 -> "~"),
            # e.g. "#/paths/~1pets/get". The literal key is tried first so
            # schemas that already resolved keep resolving the same way.
            out = out[component.replace("~1", "/").replace("~0", "~")]
        else:
            out = out[component]
    return out


JSON_LIKE = TypeVar("JSON_LIKE", bound=Union[dict, list])


def _dereference_refs_helper(
    obj: JSON_LIKE, full_schema: dict, _active_refs: tuple = ()
) -> JSON_LIKE:
    """Recursively replace ``{"$ref": ...}`` objects with their targets.

    Args:
        obj: The JSON-like value to process.
        full_schema: The schema that ``$ref`` paths are resolved against.
        _active_refs: Ref paths currently being expanded on this branch of
            the recursion; used only to stop on circular references.

    Returns:
        A new structure with references substituted. A reference that points
        back to one of its own ancestors is left in place unresolved.
    """
    if isinstance(obj, dict):
        ref_path = obj.get("$ref")
        # Only a string value is a reference; a property that merely happens
        # to be named "$ref" is handled like any other key below.
        if isinstance(ref_path, str):
            if ref_path in _active_refs:
                # Circular reference: keep the pointer instead of recursing
                # until the interpreter's recursion limit is hit.
                return cast(JSON_LIKE, dict(obj))
            target = _retrieve_ref(ref_path, full_schema)
            # The referring object is replaced by its (dereferenced) target
            # rather than keeping the target under the "$ref" key.
            return _dereference_refs_helper(
                cast(JSON_LIKE, target), full_schema, _active_refs + (ref_path,)
            )
        obj_out = {}
        for k, v in obj.items():
            if isinstance(v, (list, dict)):
                obj_out[k] = _dereference_refs_helper(  # type: ignore
                    v, full_schema, _active_refs
                )
            else:
                obj_out[k] = v
        return cast(JSON_LIKE, obj_out)
    elif isinstance(obj, list):
        return cast(
            JSON_LIKE,
            [_dereference_refs_helper(el, full_schema, _active_refs) for el in obj],
        )
    else:
        return obj


def dereference_refs(
    schema_obj: dict, *, full_schema: Optional[dict] = None
) -> Union[dict, list]:
    """Try to substitute $refs in JSON Schema.

    Args:
        schema_obj: The schema (or schema fragment) to dereference.
        full_schema: The schema that ``$ref`` paths are resolved against.
            When omitted or empty, ``schema_obj`` itself is used.

    Returns:
        A copy of ``schema_obj`` with ``$ref`` objects replaced by the
        schemas they point to; circular references are left unresolved.

    Raises:
        ValueError: If a ``$ref`` path does not start with ``#``.
    """
    full_schema = full_schema or schema_obj
    return _dereference_refs_helper(schema_obj, full_schema)