mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 08:33:49 +00:00
commit
d43a36c32a
@ -1,56 +1,9 @@
|
||||
"""Quick and dirty representation for OpenAPI specs."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
def dereference_refs(spec_obj: dict, full_spec: dict) -> Union[dict, list]:
    """Inline the ``$ref`` entries found in ``spec_obj``, one level deep.

    Walks ``spec_obj`` and, wherever a mapping holds a ``"$ref"`` key, swaps
    that whole mapping for the object the reference points to inside
    ``full_spec``. The substituted object is handed back as-is: it is not
    copied, and any ``$ref`` nested inside it is left unresolved.

    Args:
        spec_obj: The fragment of the spec to expand.
        full_spec: The complete spec that ``#/...`` references are looked up in.

    Returns:
        Freshly built dicts/lists mirroring ``spec_obj`` with the references
        substituted.

    Raises:
        RuntimeError: If a reference does not start with ``#``.
    """

    def _lookup(ref: str) -> dict:
        # Follow the "/"-separated segments of the reference from the spec root.
        head, *tail = ref.split("/")
        if head != "#":
            raise RuntimeError(
                "All $refs I've seen so far are uri fragments (start with hash)."
            )
        node = full_spec
        for segment in tail:
            node = node[segment]
        return node

    def _expand(node: Union[dict, list]) -> Union[dict, list]:
        if isinstance(node, list):
            return [_expand(item) for item in node]
        if not isinstance(node, dict):
            return node
        expanded: Dict[str, Any] = {}
        for key, value in node.items():
            if key == "$ref":
                # The referenced object replaces the whole mapping and is
                # deliberately not expanded any further.
                return _lookup(value)
            expanded[key] = _expand(value)
        return expanded

    return _expand(spec_obj)
|
||||
from langchain.utils.json_schema import dereference_refs
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -90,7 +43,7 @@ def reduce_openapi_spec(spec: dict, dereference: bool = True) -> ReducedOpenAPIS
|
||||
# Note: probably want to do this post-retrieval, it blows up the size of the spec.
|
||||
if dereference:
|
||||
endpoints = [
|
||||
(name, description, dereference_refs(docs, spec))
|
||||
(name, description, dereference_refs(docs, full_schema=spec))
|
||||
for name, description, docs in endpoints
|
||||
]
|
||||
|
||||
|
@ -10,6 +10,7 @@ from typing import (
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
@ -22,6 +23,7 @@ from langchain.output_parsers.openai_functions import (
|
||||
from langchain.prompts import BasePromptTemplate
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.schema import BaseLLMOutputParser
|
||||
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
|
||||
|
||||
PYTHON_TO_JSON_TYPES = {
|
||||
"str": "string",
|
||||
@ -148,14 +150,7 @@ def convert_to_openai_function(
|
||||
if isinstance(function, dict):
|
||||
return function
|
||||
elif isinstance(function, type) and issubclass(function, BaseModel):
|
||||
# Mypy error:
|
||||
# "type" has no attribute "schema"
|
||||
schema = function.schema() # type: ignore[attr-defined]
|
||||
return {
|
||||
"name": schema["title"],
|
||||
"description": schema["description"],
|
||||
"parameters": schema,
|
||||
}
|
||||
return cast(Dict, convert_pydantic_to_openai_function(function))
|
||||
elif callable(function):
|
||||
return convert_python_function_to_openai_function(function)
|
||||
|
||||
|
@ -1,41 +1,21 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from langchain.tools import BaseTool, StructuredTool
|
||||
|
||||
|
||||
class FunctionDescription(TypedDict):
    """Representation of a callable function to the OpenAI API."""

    # The name of the function.
    name: str
    # A description of the function.
    description: str
    # The parameters of the function.
    parameters: dict
|
||||
from langchain.tools import BaseTool
|
||||
from langchain.utils.openai_functions import (
|
||||
FunctionDescription,
|
||||
convert_pydantic_to_openai_function,
|
||||
)
|
||||
|
||||
|
||||
def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||
"""Format tool into the OpenAI function API."""
|
||||
if isinstance(tool, StructuredTool):
|
||||
schema_ = tool.args_schema.schema()
|
||||
# Bug with required missing for structured tools.
|
||||
required = schema_.get(
|
||||
"required", sorted(schema_["properties"]) # Backup is a BUG WORKAROUND
|
||||
if tool.args_schema:
|
||||
return convert_pydantic_to_openai_function(
|
||||
tool.args_schema, name=tool.name, description=tool.description
|
||||
)
|
||||
else:
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": schema_["properties"],
|
||||
"required": required,
|
||||
},
|
||||
}
|
||||
else:
|
||||
if tool.args_schema:
|
||||
parameters = tool.args_schema.schema()
|
||||
else:
|
||||
parameters = {
|
||||
# This is a hack to get around the fact that some tools
|
||||
# do not expose an args_schema, and expect an argument
|
||||
# which is a string.
|
||||
@ -46,10 +26,5 @@ def format_tool_to_openai_function(tool: BaseTool) -> FunctionDescription:
|
||||
},
|
||||
"required": ["__arg1"],
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
return {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": parameters,
|
||||
},
|
||||
}
|
||||
|
72
libs/langchain/langchain/utils/json_schema.py
Normal file
72
libs/langchain/langchain/utils/json_schema.py
Normal file
@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
|
||||
def _retrieve_ref(path: str, schema: dict) -> dict:
    """Return a deep copy of the object that the local reference ``path`` names.

    ``path`` is a ``$ref`` value such as ``#/$defs/name``: a ``#`` followed by
    ``/``-separated tokens that are followed one by one from the root of
    ``schema``.

    Args:
        path: The ``$ref`` string to resolve.
        schema: The document the reference is resolved against.

    Returns:
        A deep copy of the referenced object, so callers can modify the result
        without touching ``schema``.

    Raises:
        ValueError: If ``path`` does not start with ``#`` (e.g. a remote URL).
        KeyError: If a token does not exist in the mapping it is looked up in.
        IndexError: If a numeric token is out of range for a list.
    """
    components = path.split("/")
    if components[0] != "#":
        raise ValueError(
            "ref paths are expected to be URI fragments, meaning they should start "
            "with #."
        )
    out: Any = schema
    for component in components[1:]:
        if isinstance(out, list) and component.isascii() and component.isdigit():
            # JSON Pointer addresses list elements by their decimal index.
            out = out[int(component)]
            continue
        if isinstance(out, dict) and component not in out:
            # Only when the raw token is not itself a key, decode the JSON
            # Pointer escapes: "~1" stands for "/" and "~0" for "~" (in that
            # order, so "~01" decodes to "~1").
            component = component.replace("~1", "/").replace("~0", "~")
        out = out[component]
    return deepcopy(out)


def _dereference_refs_helper(
    obj: Any,
    full_schema: dict,
    skip_keys: Sequence[str],
    active_refs: frozenset = frozenset(),
) -> Any:
    """Recursively rebuild ``obj`` with its ``$ref`` entries inlined.

    Args:
        obj: The value to process; dicts and lists are walked, anything else is
            returned unchanged.
        full_schema: The document references are resolved against.
        skip_keys: Keys whose values are passed through as-is (not copied, not
            dereferenced).
        active_refs: The references currently being expanded on this branch of
            the recursion. Internal; used to stop on circular references.

    Returns:
        The rebuilt value. A dict containing a resolvable ``$ref`` is replaced
        entirely by the resolved target, so keys next to the ``$ref`` are
        dropped. A ``$ref`` that points back at a reference already being
        expanded is left in place instead of being followed again.
    """
    if isinstance(obj, dict):
        obj_out: dict = {}
        for k, v in obj.items():
            if k in skip_keys:
                obj_out[k] = v
            # A non-string "$ref" is ordinary data (e.g. a property that happens
            # to be called "$ref"), not a reference.
            elif k == "$ref" and isinstance(v, str):
                if v in active_refs:
                    # Circular reference: keep the pointer rather than
                    # recursing until the interpreter's recursion limit.
                    obj_out[k] = v
                    continue
                ref = _retrieve_ref(v, full_schema)
                return _dereference_refs_helper(
                    ref, full_schema, skip_keys, active_refs | {v}
                )
            elif isinstance(v, (list, dict)):
                obj_out[k] = _dereference_refs_helper(
                    v, full_schema, skip_keys, active_refs
                )
            else:
                obj_out[k] = v
        return obj_out
    if isinstance(obj, list):
        return [
            _dereference_refs_helper(el, full_schema, skip_keys, active_refs)
            for el in obj
        ]
    return obj


def _infer_skip_keys(
    obj: Any, full_schema: dict, active_refs: frozenset = frozenset()
) -> List[str]:
    """Collect the top-level keys that references reachable from ``obj`` point into.

    For every ``$ref`` such as ``#/$defs/name`` the first path segment
    (``$defs``) is recorded, and the referenced object is searched as well.

    Args:
        obj: The value to scan.
        full_schema: The document references are resolved against.
        active_refs: The references currently being followed on this branch of
            the recursion. Internal; used to stop on circular references.

    Returns:
        The collected keys, possibly with duplicates.

    Raises:
        ValueError: If a reference does not start with ``#``.
        KeyError: If a reference cannot be resolved in ``full_schema``.
    """
    keys: List[str] = []
    if isinstance(obj, dict):
        for k, v in obj.items():
            if k == "$ref" and isinstance(v, str):
                # Resolve first so that invalid references fail loudly here.
                ref = _retrieve_ref(v, full_schema)
                segments = v.split("/")
                if len(segments) > 1:
                    # A bare "#" names the document root and has no key to skip.
                    keys.append(segments[1])
                if v not in active_refs:
                    keys += _infer_skip_keys(ref, full_schema, active_refs | {v})
            elif isinstance(v, (list, dict)):
                keys += _infer_skip_keys(v, full_schema, active_refs)
    elif isinstance(obj, list):
        for el in obj:
            keys += _infer_skip_keys(el, full_schema, active_refs)
    return keys


def dereference_refs(
    schema_obj: dict,
    *,
    full_schema: Optional[dict] = None,
    skip_keys: Optional[Sequence[str]] = None,
) -> dict:
    """Try to substitute $refs in JSON Schema.

    Args:
        schema_obj: The schema (or schema fragment) whose references are inlined.
        full_schema: The document references are resolved against. Falls back
            to ``schema_obj`` when omitted (an empty mapping also falls back).
        skip_keys: Keys whose values are passed through without being
            dereferenced. When ``None``, the first path segment of every
            reference reachable from ``schema_obj`` is used (for example
            ``$defs``), which keeps the definitions section itself unexpanded.
            Pass ``()`` to dereference everything.

    Returns:
        The schema with local references replaced by copies of their targets.
        Circular references are left as ``$ref`` entries at the point where
        they would start repeating.

    Raises:
        ValueError: If a reference does not start with ``#``.
        KeyError: If a reference cannot be resolved.
    """
    full_schema = full_schema or schema_obj
    if skip_keys is None:
        skip_keys = _infer_skip_keys(schema_obj, full_schema)
    return _dereference_refs_helper(schema_obj, full_schema, skip_keys)
|
30
libs/langchain/langchain/utils/openai_functions.py
Normal file
30
libs/langchain/langchain/utils/openai_functions.py
Normal file
@ -0,0 +1,30 @@
|
||||
from typing import Optional, Type, TypedDict
|
||||
|
||||
from langchain.pydantic_v1 import BaseModel
|
||||
from langchain.utils.json_schema import dereference_refs
|
||||
|
||||
|
||||
class FunctionDescription(TypedDict):
    """Representation of a callable function to the OpenAI API."""

    # The name of the function.
    name: str
    # A description of the function.
    description: str
    # The parameters of the function.
    parameters: dict
|
||||
|
||||
|
||||
def convert_pydantic_to_openai_function(
    model: Type[BaseModel],
    *,
    name: Optional[str] = None,
    description: Optional[str] = None
) -> FunctionDescription:
    """Convert a pydantic model class into an OpenAI function description.

    The model's JSON schema is generated with ``model.schema()``, its ``$ref``
    entries are inlined with ``dereference_refs`` and the top-level
    ``definitions`` entry is then removed from the result.

    Args:
        model: The pydantic model class to describe.
        name: Function name to use; falls back to the schema's ``title``.
        description: Function description to use; falls back to the schema's
            ``description`` and, when the schema has none, to an empty string.

    Returns:
        A dict with the ``name``, ``description`` and ``parameters`` of the
        function, where ``parameters`` is the dereferenced schema.
    """
    schema = dereference_refs(model.schema())
    schema.pop("definitions", None)
    return {
        "name": name or schema["title"],
        # The schema may carry no "description" entry (presumably when the
        # model has no docstring); fall back to "" instead of raising KeyError.
        "description": description or schema.get("description", ""),
        "parameters": schema,
    }
|
151
libs/langchain/tests/unit_tests/utils/test_json_schema.py
Normal file
151
libs/langchain/tests/unit_tests/utils/test_json_schema.py
Normal file
@ -0,0 +1,151 @@
|
||||
import pytest
|
||||
|
||||
from langchain.utils.json_schema import dereference_refs
|
||||
|
||||
|
||||
def test_dereference_refs_no_refs() -> None:
    """A schema without any ``$ref`` comes back equal to the input."""
    schema = {
        "type": "object",
        "properties": {"first_name": {"type": "string"}},
    }
    assert dereference_refs(schema) == schema
|
||||
|
||||
|
||||
def test_dereference_refs_one_ref() -> None:
    """A single ``$ref`` is inlined while the ``$defs`` section is kept."""
    schema = {
        "type": "object",
        "properties": {"first_name": {"$ref": "#/$defs/name"}},
        "$defs": {"name": {"type": "string"}},
    }
    resolved = dereference_refs(schema)
    assert resolved == {
        "type": "object",
        "properties": {"first_name": {"type": "string"}},
        "$defs": {"name": {"type": "string"}},
    }
|
||||
|
||||
|
||||
def test_dereference_refs_multiple_refs() -> None:
    """Several sibling ``$ref`` entries are each replaced by their target."""
    schema = {
        "type": "object",
        "properties": {
            "first_name": {"$ref": "#/$defs/name"},
            "other": {"$ref": "#/$defs/other"},
        },
        "$defs": {
            "name": {"type": "string"},
            "other": {"type": "object", "properties": {"age": "int", "height": "int"}},
        },
    }
    resolved = dereference_refs(schema)
    assert resolved == {
        "type": "object",
        "properties": {
            "first_name": {"type": "string"},
            "other": {"type": "object", "properties": {"age": "int", "height": "int"}},
        },
        "$defs": {
            "name": {"type": "string"},
            "other": {"type": "object", "properties": {"age": "int", "height": "int"}},
        },
    }
|
||||
|
||||
|
||||
def test_dereference_refs_nested_refs_skip() -> None:
    """With inferred skip keys, nested refs are inlined but ``$defs`` is untouched."""
    schema = {
        "type": "object",
        "properties": {"info": {"$ref": "#/$defs/info"}},
        "$defs": {
            "name": {"type": "string"},
            "info": {
                "type": "object",
                "properties": {"age": "int", "name": {"$ref": "#/$defs/name"}},
            },
        },
    }
    resolved = dereference_refs(schema)
    assert resolved == {
        "type": "object",
        "properties": {
            "info": {
                "type": "object",
                "properties": {"age": "int", "name": {"type": "string"}},
            },
        },
        "$defs": {
            "name": {"type": "string"},
            "info": {
                "type": "object",
                "properties": {"age": "int", "name": {"$ref": "#/$defs/name"}},
            },
        },
    }
|
||||
|
||||
|
||||
def test_dereference_refs_nested_refs_no_skip() -> None:
    """With ``skip_keys=()``, refs inside ``$defs`` are inlined as well."""
    schema = {
        "type": "object",
        "properties": {"info": {"$ref": "#/$defs/info"}},
        "$defs": {
            "name": {"type": "string"},
            "info": {
                "type": "object",
                "properties": {"age": "int", "name": {"$ref": "#/$defs/name"}},
            },
        },
    }
    resolved = dereference_refs(schema, skip_keys=())
    assert resolved == {
        "type": "object",
        "properties": {
            "info": {
                "type": "object",
                "properties": {"age": "int", "name": {"type": "string"}},
            },
        },
        "$defs": {
            "name": {"type": "string"},
            "info": {
                "type": "object",
                "properties": {"age": "int", "name": {"type": "string"}},
            },
        },
    }
|
||||
|
||||
|
||||
def test_dereference_refs_missing_ref() -> None:
    """A ``$ref`` whose target is absent from ``$defs`` raises ``KeyError``."""
    dangling = {
        "type": "object",
        "properties": {"first_name": {"$ref": "#/$defs/name"}},
        "$defs": {},
    }
    with pytest.raises(KeyError):
        dereference_refs(dangling)
|
||||
|
||||
|
||||
def test_dereference_refs_remote_ref() -> None:
    """A ``$ref`` that is not a ``#`` fragment raises ``ValueError``."""
    remote = {
        "type": "object",
        "properties": {"first_name": {"$ref": "https://somewhere/else/name"}},
    }
    with pytest.raises(ValueError):
        dereference_refs(remote)
|
@ -0,0 +1,79 @@
|
||||
from langchain.pydantic_v1 import BaseModel, Field
|
||||
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
|
||||
|
||||
|
||||
def test_convert_pydantic_to_openai_function() -> None:
    """A flat model maps to its name, docstring and JSON schema."""

    class Data(BaseModel):
        """The data to return."""

        key: str = Field(..., description="API key")
        days: int = Field(default=0, description="Number of days to forecast")

    expected_parameters = {
        "title": "Data",
        "description": "The data to return.",
        "type": "object",
        "properties": {
            "key": {"title": "Key", "description": "API key", "type": "string"},
            "days": {
                "title": "Days",
                "description": "Number of days to forecast",
                "default": 0,
                "type": "integer",
            },
        },
        "required": ["key"],
    }
    converted = convert_pydantic_to_openai_function(Data)
    assert converted == {
        "name": "Data",
        "description": "The data to return.",
        "parameters": expected_parameters,
    }
|
||||
|
||||
|
||||
def test_convert_pydantic_to_openai_function_nested() -> None:
    """A nested model has its inner schema inlined under the field."""

    class Data(BaseModel):
        """The data to return."""

        key: str = Field(..., description="API key")
        days: int = Field(default=0, description="Number of days to forecast")

    class Model(BaseModel):
        """The model to return."""

        data: Data

    expected_data_schema = {
        "title": "Data",
        "description": "The data to return.",
        "type": "object",
        "properties": {
            "key": {
                "title": "Key",
                "description": "API key",
                "type": "string",
            },
            "days": {
                "title": "Days",
                "description": "Number of days to forecast",
                "default": 0,
                "type": "integer",
            },
        },
        "required": ["key"],
    }
    converted = convert_pydantic_to_openai_function(Model)
    assert converted == {
        "name": "Model",
        "description": "The model to return.",
        "parameters": {
            "title": "Model",
            "description": "The model to return.",
            "type": "object",
            "properties": {"data": expected_data_schema},
            "required": ["data"],
        },
    }
|
Loading…
Reference in New Issue
Block a user