mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-19 21:33:51 +00:00
wip
This commit is contained in:
parent
7fa82900cb
commit
240cc289e6
@ -1,56 +1,9 @@
|
|||||||
"""Quick and dirty representation for OpenAPI specs."""
|
"""Quick and dirty representation for OpenAPI specs."""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
from langchain.utils.json_schema import dereference_refs
|
||||||
def dereference_refs(spec_obj: dict, full_spec: dict) -> Union[dict, list]:
|
|
||||||
"""Try to substitute $refs.
|
|
||||||
|
|
||||||
The goal is to get the complete docs for each endpoint in context for now.
|
|
||||||
|
|
||||||
In the few OpenAPI specs I studied, $refs referenced models
|
|
||||||
(or in OpenAPI terms, components) and could be nested. This code most
|
|
||||||
likely misses lots of cases.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _retrieve_ref_path(path: str, full_spec: dict) -> dict:
|
|
||||||
components = path.split("/")
|
|
||||||
if components[0] != "#":
|
|
||||||
raise RuntimeError(
|
|
||||||
"All $refs I've seen so far are uri fragments (start with hash)."
|
|
||||||
)
|
|
||||||
out = full_spec
|
|
||||||
for component in components[1:]:
|
|
||||||
out = out[component]
|
|
||||||
return out
|
|
||||||
|
|
||||||
def _dereference_refs(
|
|
||||||
obj: Union[dict, list], stop: bool = False
|
|
||||||
) -> Union[dict, list]:
|
|
||||||
if stop:
|
|
||||||
return obj
|
|
||||||
obj_out: Dict[str, Any] = {}
|
|
||||||
if isinstance(obj, dict):
|
|
||||||
for k, v in obj.items():
|
|
||||||
if k == "$ref":
|
|
||||||
# stop=True => don't dereference recursively.
|
|
||||||
return _dereference_refs(
|
|
||||||
_retrieve_ref_path(v, full_spec), stop=True
|
|
||||||
)
|
|
||||||
elif isinstance(v, list):
|
|
||||||
obj_out[k] = [_dereference_refs(el) for el in v]
|
|
||||||
elif isinstance(v, dict):
|
|
||||||
obj_out[k] = _dereference_refs(v)
|
|
||||||
else:
|
|
||||||
obj_out[k] = v
|
|
||||||
return obj_out
|
|
||||||
elif isinstance(obj, list):
|
|
||||||
return [_dereference_refs(el) for el in obj]
|
|
||||||
else:
|
|
||||||
return obj
|
|
||||||
|
|
||||||
return _dereference_refs(spec_obj)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@ -90,7 +43,7 @@ def reduce_openapi_spec(spec: dict, dereference: bool = True) -> ReducedOpenAPIS
|
|||||||
# Note: probably want to do this post-retrieval, it blows up the size of the spec.
|
# Note: probably want to do this post-retrieval, it blows up the size of the spec.
|
||||||
if dereference:
|
if dereference:
|
||||||
endpoints = [
|
endpoints = [
|
||||||
(name, description, dereference_refs(docs, spec))
|
(name, description, dereference_refs(docs, full_schema=spec))
|
||||||
for name, description, docs in endpoints
|
for name, description, docs in endpoints
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ from typing import (
|
|||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
|
cast,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain.base_language import BaseLanguageModel
|
from langchain.base_language import BaseLanguageModel
|
||||||
@ -22,6 +23,7 @@ from langchain.output_parsers.openai_functions import (
|
|||||||
from langchain.prompts import BasePromptTemplate
|
from langchain.prompts import BasePromptTemplate
|
||||||
from langchain.pydantic_v1 import BaseModel
|
from langchain.pydantic_v1 import BaseModel
|
||||||
from langchain.schema import BaseLLMOutputParser
|
from langchain.schema import BaseLLMOutputParser
|
||||||
|
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
|
||||||
|
|
||||||
PYTHON_TO_JSON_TYPES = {
|
PYTHON_TO_JSON_TYPES = {
|
||||||
"str": "string",
|
"str": "string",
|
||||||
@ -148,14 +150,7 @@ def convert_to_openai_function(
|
|||||||
if isinstance(function, dict):
|
if isinstance(function, dict):
|
||||||
return function
|
return function
|
||||||
elif isinstance(function, type) and issubclass(function, BaseModel):
|
elif isinstance(function, type) and issubclass(function, BaseModel):
|
||||||
# Mypy error:
|
return cast(Dict, convert_pydantic_to_openai_function(function))
|
||||||
# "type" has no attribute "schema"
|
|
||||||
schema = function.schema() # type: ignore[attr-defined]
|
|
||||||
return {
|
|
||||||
"name": schema["title"],
|
|
||||||
"description": schema["description"],
|
|
||||||
"parameters": schema,
|
|
||||||
}
|
|
||||||
elif callable(function):
|
elif callable(function):
|
||||||
return convert_python_function_to_openai_function(function)
|
return convert_python_function_to_openai_function(function)
|
||||||
|
|
||||||
|
@ -1,41 +1,21 @@
|
|||||||
from typing import TypedDict
|
from langchain.tools import BaseTool
|
||||||
|
from langchain.utils.openai_functions import (
|
||||||
from langchain.tools import BaseTool, StructuredTool
|
FunctionDescription,
|
||||||
|
convert_pydantic_to_openai_function,
|
||||||
|
)
|
||||||
class FunctionDescription(TypedDict):
|
|
||||||
"""Representation of a callable function to the OpenAI API."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
"""The name of the function."""
|
|
||||||
description: str
|
|
||||||
"""A description of the function."""
|
|
||||||
parameters: dict
|
|
||||||
"""The parameters of the function."""
|
|
||||||
|
|
||||||
|
|
||||||
def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||||
"""Format tool into the OpenAI function API."""
|
"""Format tool into the OpenAI function API."""
|
||||||
if isinstance(tool, StructuredTool):
|
if tool.args_schema:
|
||||||
schema_ = tool.args_schema.schema()
|
return convert_pydantic_to_openai_function(
|
||||||
# Bug with required missing for structured tools.
|
tool.args_schema, name=tool.name, description=tool.description
|
||||||
required = schema_.get(
|
|
||||||
"required", sorted(schema_["properties"]) # Backup is a BUG WORKAROUND
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
return {
|
return {
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"description": tool.description,
|
"description": tool.description,
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "object",
|
|
||||||
"properties": schema_["properties"],
|
|
||||||
"required": required,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
if tool.args_schema:
|
|
||||||
parameters = tool.args_schema.schema()
|
|
||||||
else:
|
|
||||||
parameters = {
|
|
||||||
# This is a hack to get around the fact that some tools
|
# This is a hack to get around the fact that some tools
|
||||||
# do not expose an args_schema, and expect an argument
|
# do not expose an args_schema, and expect an argument
|
||||||
# which is a string.
|
# which is a string.
|
||||||
@ -46,10 +26,5 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
|||||||
},
|
},
|
||||||
"required": ["__arg1"],
|
"required": ["__arg1"],
|
||||||
"type": "object",
|
"type": "object",
|
||||||
}
|
},
|
||||||
|
|
||||||
return {
|
|
||||||
"name": tool.name,
|
|
||||||
"description": tool.description,
|
|
||||||
"parameters": parameters,
|
|
||||||
}
|
}
|
||||||
|
48
libs/langchain/langchain/utils/json_schema.py
Normal file
48
libs/langchain/langchain/utils/json_schema.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Optional, TypeVar, Union, cast
|
||||||
|
|
||||||
|
|
||||||
|
def _retrieve_ref(path: str, schema: dict) -> dict:
|
||||||
|
components = path.split("/")
|
||||||
|
if components[0] != "#":
|
||||||
|
raise ValueError(
|
||||||
|
"ref paths are expected to be URI fragments, meaning they should start "
|
||||||
|
"with #."
|
||||||
|
)
|
||||||
|
out = schema
|
||||||
|
for component in components[1:]:
|
||||||
|
out = out[component]
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
JSON_LIKE = TypeVar("JSON_LIKE", bound=Union[dict, list])
|
||||||
|
|
||||||
|
|
||||||
|
def _dereference_refs_helper(obj: JSON_LIKE, full_schema: dict) -> JSON_LIKE:
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
obj_out = {}
|
||||||
|
for k, v in obj.items():
|
||||||
|
if k == "$ref":
|
||||||
|
ref = _retrieve_ref(v, full_schema)
|
||||||
|
obj_out[k] = _dereference_refs_helper(ref, full_schema)
|
||||||
|
elif isinstance(v, (list, dict)):
|
||||||
|
obj_out[k] = _dereference_refs_helper(v, full_schema) # type: ignore
|
||||||
|
else:
|
||||||
|
obj_out[k] = v
|
||||||
|
return cast(JSON_LIKE, obj_out)
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
return cast(
|
||||||
|
JSON_LIKE, [_dereference_refs_helper(el, full_schema) for el in obj]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
def dereference_refs(
|
||||||
|
schema_obj: dict, *, full_schema: Optional[dict] = None
|
||||||
|
) -> Union[dict, list]:
|
||||||
|
"""Try to substitute $refs in JSON Schema."""
|
||||||
|
|
||||||
|
full_schema = full_schema or schema_obj
|
||||||
|
return _dereference_refs_helper(schema_obj, full_schema)
|
Loading…
Reference in New Issue
Block a user