This commit is contained in:
Bagatur 2023-08-30 13:37:39 -07:00
parent 7fa82900cb
commit 240cc289e6
4 changed files with 64 additions and 93 deletions

View File

@ -1,56 +1,9 @@
"""Quick and dirty representation for OpenAPI specs.""" """Quick and dirty representation for OpenAPI specs."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union from typing import List, Tuple
from langchain.utils.json_schema import dereference_refs
def dereference_refs(spec_obj: dict, full_spec: dict) -> Union[dict, list]:
"""Try to substitute $refs.
The goal is to get the complete docs for each endpoint in context for now.
In the few OpenAPI specs I studied, $refs referenced models
(or in OpenAPI terms, components) and could be nested. This code most
likely misses lots of cases.
"""
def _retrieve_ref_path(path: str, full_spec: dict) -> dict:
components = path.split("/")
if components[0] != "#":
raise RuntimeError(
"All $refs I've seen so far are uri fragments (start with hash)."
)
out = full_spec
for component in components[1:]:
out = out[component]
return out
def _dereference_refs(
obj: Union[dict, list], stop: bool = False
) -> Union[dict, list]:
if stop:
return obj
obj_out: Dict[str, Any] = {}
if isinstance(obj, dict):
for k, v in obj.items():
if k == "$ref":
# stop=True => don't dereference recursively.
return _dereference_refs(
_retrieve_ref_path(v, full_spec), stop=True
)
elif isinstance(v, list):
obj_out[k] = [_dereference_refs(el) for el in v]
elif isinstance(v, dict):
obj_out[k] = _dereference_refs(v)
else:
obj_out[k] = v
return obj_out
elif isinstance(obj, list):
return [_dereference_refs(el) for el in obj]
else:
return obj
return _dereference_refs(spec_obj)
@dataclass(frozen=True) @dataclass(frozen=True)
@ -90,7 +43,7 @@ def reduce_openapi_spec(spec: dict, dereference: bool = True) -> ReducedOpenAPIS
# Note: probably want to do this post-retrieval, it blows up the size of the spec. # Note: probably want to do this post-retrieval, it blows up the size of the spec.
if dereference: if dereference:
endpoints = [ endpoints = [
(name, description, dereference_refs(docs, spec)) (name, description, dereference_refs(docs, full_schema=spec))
for name, description, docs in endpoints for name, description, docs in endpoints
] ]

View File

@ -10,6 +10,7 @@ from typing import (
Tuple, Tuple,
Type, Type,
Union, Union,
cast,
) )
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
@ -22,6 +23,7 @@ from langchain.output_parsers.openai_functions import (
from langchain.prompts import BasePromptTemplate from langchain.prompts import BasePromptTemplate
from langchain.pydantic_v1 import BaseModel from langchain.pydantic_v1 import BaseModel
from langchain.schema import BaseLLMOutputParser from langchain.schema import BaseLLMOutputParser
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
PYTHON_TO_JSON_TYPES = { PYTHON_TO_JSON_TYPES = {
"str": "string", "str": "string",
@ -148,14 +150,7 @@ def convert_to_openai_function(
if isinstance(function, dict): if isinstance(function, dict):
return function return function
elif isinstance(function, type) and issubclass(function, BaseModel): elif isinstance(function, type) and issubclass(function, BaseModel):
# Mypy error: return cast(Dict, convert_pydantic_to_openai_function(function))
# "type" has no attribute "schema"
schema = function.schema() # type: ignore[attr-defined]
return {
"name": schema["title"],
"description": schema["description"],
"parameters": schema,
}
elif callable(function): elif callable(function):
return convert_python_function_to_openai_function(function) return convert_python_function_to_openai_function(function)

View File

@ -1,41 +1,21 @@
from typing import TypedDict from langchain.tools import BaseTool
from langchain.utils.openai_functions import (
from langchain.tools import BaseTool, StructuredTool FunctionDescription,
convert_pydantic_to_openai_function,
)
class FunctionDescription(TypedDict):
"""Representation of a callable function to the OpenAI API."""
name: str
"""The name of the function."""
description: str
"""A description of the function."""
parameters: dict
"""The parameters of the function."""
def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription: def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
"""Format tool into the OpenAI function API.""" """Format tool into the OpenAI function API."""
if isinstance(tool, StructuredTool): if tool.args_schema:
schema_ = tool.args_schema.schema() return convert_pydantic_to_openai_function(
# Bug with required missing for structured tools. tool.args_schema, name=tool.name, description=tool.description
required = schema_.get(
"required", sorted(schema_["properties"]) # Backup is a BUG WORKAROUND
) )
else:
return { return {
"name": tool.name, "name": tool.name,
"description": tool.description, "description": tool.description,
"parameters": { "parameters": {
"type": "object",
"properties": schema_["properties"],
"required": required,
},
}
else:
if tool.args_schema:
parameters = tool.args_schema.schema()
else:
parameters = {
# This is a hack to get around the fact that some tools # This is a hack to get around the fact that some tools
# do not expose an args_schema, and expect an argument # do not expose an args_schema, and expect an argument
# which is a string. # which is a string.
@ -46,10 +26,5 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
}, },
"required": ["__arg1"], "required": ["__arg1"],
"type": "object", "type": "object",
} },
return {
"name": tool.name,
"description": tool.description,
"parameters": parameters,
} }

View File

@ -0,0 +1,48 @@
from __future__ import annotations
from typing import Optional, TypeVar, Union, cast
def _retrieve_ref(path: str, schema: dict) -> dict:
components = path.split("/")
if components[0] != "#":
raise ValueError(
"ref paths are expected to be URI fragments, meaning they should start "
"with #."
)
out = schema
for component in components[1:]:
out = out[component]
return out
JSON_LIKE = TypeVar("JSON_LIKE", bound=Union[dict, list])
def _dereference_refs_helper(obj: JSON_LIKE, full_schema: dict) -> JSON_LIKE:
if isinstance(obj, dict):
obj_out = {}
for k, v in obj.items():
if k == "$ref":
ref = _retrieve_ref(v, full_schema)
obj_out[k] = _dereference_refs_helper(ref, full_schema)
elif isinstance(v, (list, dict)):
obj_out[k] = _dereference_refs_helper(v, full_schema) # type: ignore
else:
obj_out[k] = v
return cast(JSON_LIKE, obj_out)
elif isinstance(obj, list):
return cast(
JSON_LIKE, [_dereference_refs_helper(el, full_schema) for el in obj]
)
else:
return obj
def dereference_refs(
schema_obj: dict, *, full_schema: Optional[dict] = None
) -> Union[dict, list]:
"""Try to substitute $refs in JSON Schema."""
full_schema = full_schema or schema_obj
return _dereference_refs_helper(schema_obj, full_schema)