mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-29 07:19:59 +00:00
- **Description:** This PR is intended to improve the generation of payloads for OpenAI functions when converting from an OpenAPI spec file. The solution is to recursively resolve `$refs`. Currently when converting OpenAPI specs into OpenAI functions using `openapi_spec_to_openai_fn`, if the schemas have nested references, the generated functions contain `$ref` that causes the LLM to generate payloads with an incorrect schema. For example, for the for OpenAPI spec: ``` text = """ { "openapi": "3.0.3", "info": { "title": "Swagger Petstore - OpenAPI 3.0", "termsOfService": "http://swagger.io/terms/", "contact": { "email": "apiteam@swagger.io" }, "license": { "name": "Apache 2.0", "url": "http://www.apache.org/licenses/LICENSE-2.0.html" }, "version": "1.0.11" }, "externalDocs": { "description": "Find out more about Swagger", "url": "http://swagger.io" }, "servers": [ { "url": "https://petstore3.swagger.io/api/v3" } ], "tags": [ { "name": "pet", "description": "Everything about your Pets", "externalDocs": { "description": "Find out more", "url": "http://swagger.io" } }, { "name": "store", "description": "Access to Petstore orders", "externalDocs": { "description": "Find out more about our store", "url": "http://swagger.io" } }, { "name": "user", "description": "Operations about user" } ], "paths": { "/pet": { "post": { "tags": [ "pet" ], "summary": "Add a new pet to the store", "description": "Add a new pet to the store", "operationId": "addPet", "requestBody": { "description": "Create a new pet in the store", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/Pet" } } }, "required": true }, "responses": { "200": { "description": "Successful operation", "content": { "application/json": { "schema": { "$ref": "#/components/schemas/Pet" } } } } } } } }, "components": { "schemas": { "Tag": { "type": "object", "properties": { "id": { "type": "integer", "format": "int64" }, "model_type": { "type": "number" } } }, "Category": { "type": "object", "required": [ "model", "year", "age" ], "properties": { "year": { "type": "integer", "format": "int64", "example": 1 }, "model": { "type": "string", "example": "Ford" }, "age": { "type": "integer", "example": 42 } } }, "Pet": { "required": [ "name" ], "type": "object", "properties": { "id": { "type": "integer", "format": "int64", "example": 10 }, "name": { "type": "string", "example": "doggie" }, "category": { "$ref": "#/components/schemas/Category" }, "tags": { "type": "array", "items": { "$ref": "#/components/schemas/Tag" } }, "status": { "type": "string", "description": "pet status in the store", "enum": [ "available", "pending", "sold" ] } } } } } } """ ``` Executing: ``` spec = OpenAPISpec.from_text(text) pet_openai_functions, pet_callables = openapi_spec_to_openai_fn(spec) response = model.invoke("Create a pet named Scott", functions=pet_openai_functions) ``` `pet_open_functions` contains unresolved `$refs`: ``` [ { "name": "addPet", "description": "Add a new pet to the store", "parameters": { "type": "object", "properties": { "json": { "properties": { "id": { "type": "integer", "schema_format": "int64", "example": 10 }, "name": { "type": "string", "example": "doggie" }, "category": { "ref": "#/components/schemas/Category" }, "tags": { "items": { "ref": "#/components/schemas/Tag" }, "type": "array" }, "status": { "type": "string", "enum": [ "available", "pending", "sold" ], "description": "pet status in the store" } }, "type": "object", "required": [ "name", "photoUrls" ] } } } } ] ``` and the generated JSON has an incorrect schema (e.g. category is filled with `id` and `name` instead of `model`, `year` and `age`: ``` { "id": 1, "name": "Scott", "category": { "id": 1, "name": "Dogs" }, "tags": [ { "id": 1, "name": "tag1" } ], "status": "available" } ``` With this change, the generated JSON by the LLM becomes, `pet_openai_functions` becomes: ``` [ { "name": "addPet", "description": "Add a new pet to the store", "parameters": { "type": "object", "properties": { "json": { "properties": { "id": { "type": "integer", "schema_format": "int64", "example": 10 }, "name": { "type": "string", "example": "doggie" }, "category": { "properties": { "year": { "type": "integer", "schema_format": "int64", "example": 1 }, "model": { "type": "string", "example": "Ford" }, "age": { "type": "integer", "example": 42 } }, "type": "object", "required": [ "model", "year", "age" ] }, "tags": { "items": { "properties": { "id": { "type": "integer", "schema_format": "int64" }, "model_type": { "type": "number" } }, "type": "object" }, "type": "array" }, "status": { "type": "string", "enum": [ "available", "pending", "sold" ], "description": "pet status in the store" } }, "type": "object", "required": [ "name" ] } } } } ] ``` and the JSON generated by the LLM is: ``` { "id": 1, "name": "Scott", "category": { "year": 2022, "model": "Dog", "age": 42 }, "tags": [ { "id": 1, "model_type": 1 } ], "status": "available" } ``` which has the intended schema. - **Twitter handle:**: @brunoalvisio --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
336 lines
11 KiB
Python
336 lines
11 KiB
Python
"""Utility functions for parsing an OpenAPI spec."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import copy
|
|
import json
|
|
import logging
|
|
import re
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
|
|
|
import requests
|
|
import yaml
|
|
from langchain_core.pydantic_v1 import ValidationError
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class HTTPVerb(str, Enum):
|
|
"""Enumerator of the HTTP verbs."""
|
|
|
|
GET = "get"
|
|
PUT = "put"
|
|
POST = "post"
|
|
DELETE = "delete"
|
|
OPTIONS = "options"
|
|
HEAD = "head"
|
|
PATCH = "patch"
|
|
TRACE = "trace"
|
|
|
|
@classmethod
|
|
def from_str(cls, verb: str) -> HTTPVerb:
|
|
"""Parse an HTTP verb."""
|
|
try:
|
|
return cls(verb)
|
|
except ValueError:
|
|
raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}")
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from openapi_pydantic import (
|
|
Components,
|
|
Operation,
|
|
Parameter,
|
|
PathItem,
|
|
Paths,
|
|
Reference,
|
|
RequestBody,
|
|
Schema,
|
|
)
|
|
|
|
try:
|
|
from openapi_pydantic import OpenAPI
|
|
except ImportError:
|
|
OpenAPI = object # type: ignore
|
|
|
|
|
|
class OpenAPISpec(OpenAPI):
|
|
"""OpenAPI Model that removes mis-formatted parts of the spec."""
|
|
|
|
openapi: str = "3.1.0" # overriding overly restrictive type from parent class
|
|
|
|
@property
|
|
def _paths_strict(self) -> Paths:
|
|
if not self.paths:
|
|
raise ValueError("No paths found in spec")
|
|
return self.paths
|
|
|
|
def _get_path_strict(self, path: str) -> PathItem:
|
|
path_item = self._paths_strict.get(path)
|
|
if not path_item:
|
|
raise ValueError(f"No path found for {path}")
|
|
return path_item
|
|
|
|
@property
|
|
def _components_strict(self) -> Components:
|
|
"""Get components or err."""
|
|
if self.components is None:
|
|
raise ValueError("No components found in spec. ")
|
|
return self.components
|
|
|
|
@property
|
|
def _parameters_strict(self) -> Dict[str, Union[Parameter, Reference]]:
|
|
"""Get parameters or err."""
|
|
parameters = self._components_strict.parameters
|
|
if parameters is None:
|
|
raise ValueError("No parameters found in spec. ")
|
|
return parameters
|
|
|
|
@property
|
|
def _schemas_strict(self) -> Dict[str, Schema]:
|
|
"""Get the dictionary of schemas or err."""
|
|
schemas = self._components_strict.schemas
|
|
if schemas is None:
|
|
raise ValueError("No schemas found in spec. ")
|
|
return schemas
|
|
|
|
@property
|
|
def _request_bodies_strict(self) -> Dict[str, Union[RequestBody, Reference]]:
|
|
"""Get the request body or err."""
|
|
request_bodies = self._components_strict.requestBodies
|
|
if request_bodies is None:
|
|
raise ValueError("No request body found in spec. ")
|
|
return request_bodies
|
|
|
|
def _get_referenced_parameter(self, ref: Reference) -> Union[Parameter, Reference]:
|
|
"""Get a parameter (or nested reference) or err."""
|
|
ref_name = ref.ref.split("/")[-1]
|
|
parameters = self._parameters_strict
|
|
if ref_name not in parameters:
|
|
raise ValueError(f"No parameter found for {ref_name}")
|
|
return parameters[ref_name]
|
|
|
|
def _get_root_referenced_parameter(self, ref: Reference) -> Parameter:
|
|
"""Get the root reference or err."""
|
|
from openapi_pydantic import Reference
|
|
|
|
parameter = self._get_referenced_parameter(ref)
|
|
while isinstance(parameter, Reference):
|
|
parameter = self._get_referenced_parameter(parameter)
|
|
return parameter
|
|
|
|
def get_referenced_schema(self, ref: Reference) -> Schema:
|
|
"""Get a schema (or nested reference) or err."""
|
|
ref_name = ref.ref.split("/")[-1]
|
|
schemas = self._schemas_strict
|
|
if ref_name not in schemas:
|
|
raise ValueError(f"No schema found for {ref_name}")
|
|
return schemas[ref_name]
|
|
|
|
def get_schema(
|
|
self,
|
|
schema: Union[Reference, Schema],
|
|
depth: int = 0,
|
|
max_depth: Optional[int] = None,
|
|
) -> Schema:
|
|
if max_depth is not None and depth >= max_depth:
|
|
raise RecursionError(
|
|
f"Max depth of {max_depth} has been exceeded when resolving references."
|
|
)
|
|
|
|
from openapi_pydantic import Reference
|
|
|
|
if isinstance(schema, Reference):
|
|
schema = self.get_referenced_schema(schema)
|
|
|
|
# TODO: Resolve references on all fields of Schema ?
|
|
# (e.g. patternProperties, etc...)
|
|
if schema.properties is not None:
|
|
for p_name, p in schema.properties.items():
|
|
schema.properties[p_name] = self.get_schema(p, depth + 1, max_depth)
|
|
|
|
if schema.items is not None:
|
|
schema.items = self.get_schema(schema.items, depth + 1, max_depth)
|
|
|
|
return schema
|
|
|
|
def _get_root_referenced_schema(self, ref: Reference) -> Schema:
|
|
"""Get the root reference or err."""
|
|
from openapi_pydantic import Reference
|
|
|
|
schema = self.get_referenced_schema(ref)
|
|
while isinstance(schema, Reference):
|
|
schema = self.get_referenced_schema(schema)
|
|
return schema
|
|
|
|
def _get_referenced_request_body(
|
|
self, ref: Reference
|
|
) -> Optional[Union[Reference, RequestBody]]:
|
|
"""Get a request body (or nested reference) or err."""
|
|
ref_name = ref.ref.split("/")[-1]
|
|
request_bodies = self._request_bodies_strict
|
|
if ref_name not in request_bodies:
|
|
raise ValueError(f"No request body found for {ref_name}")
|
|
return request_bodies[ref_name]
|
|
|
|
def _get_root_referenced_request_body(
|
|
self, ref: Reference
|
|
) -> Optional[RequestBody]:
|
|
"""Get the root request Body or err."""
|
|
from openapi_pydantic import Reference
|
|
|
|
request_body = self._get_referenced_request_body(ref)
|
|
while isinstance(request_body, Reference):
|
|
request_body = self._get_referenced_request_body(request_body)
|
|
return request_body
|
|
|
|
@staticmethod
|
|
def _alert_unsupported_spec(obj: dict) -> None:
|
|
"""Alert if the spec is not supported."""
|
|
warning_message = (
|
|
" This may result in degraded performance."
|
|
+ " Convert your OpenAPI spec to 3.1.* spec"
|
|
+ " for better support."
|
|
)
|
|
swagger_version = obj.get("swagger")
|
|
openapi_version = obj.get("openapi")
|
|
if isinstance(openapi_version, str):
|
|
if openapi_version != "3.1.0":
|
|
logger.warning(
|
|
f"Attempting to load an OpenAPI {openapi_version}"
|
|
f" spec. {warning_message}"
|
|
)
|
|
else:
|
|
pass
|
|
elif isinstance(swagger_version, str):
|
|
logger.warning(
|
|
f"Attempting to load a Swagger {swagger_version}"
|
|
f" spec. {warning_message}"
|
|
)
|
|
else:
|
|
raise ValueError(
|
|
"Attempting to load an unsupported spec:"
|
|
f"\n\n{obj}\n{warning_message}"
|
|
)
|
|
|
|
@classmethod
|
|
def parse_obj(cls, obj: dict) -> OpenAPISpec:
|
|
try:
|
|
cls._alert_unsupported_spec(obj)
|
|
return super().parse_obj(obj)
|
|
except ValidationError as e:
|
|
# We are handling possibly misconfigured specs and
|
|
# want to do a best-effort job to get a reasonable interface out of it.
|
|
new_obj = copy.deepcopy(obj)
|
|
for error in e.errors():
|
|
keys = error["loc"]
|
|
item = new_obj
|
|
for key in keys[:-1]:
|
|
item = item[key]
|
|
item.pop(keys[-1], None)
|
|
return cls.parse_obj(new_obj)
|
|
|
|
@classmethod
|
|
def from_spec_dict(cls, spec_dict: dict) -> OpenAPISpec:
|
|
"""Get an OpenAPI spec from a dict."""
|
|
return cls.parse_obj(spec_dict)
|
|
|
|
@classmethod
|
|
def from_text(cls, text: str) -> OpenAPISpec:
|
|
"""Get an OpenAPI spec from a text."""
|
|
try:
|
|
spec_dict = json.loads(text)
|
|
except json.JSONDecodeError:
|
|
spec_dict = yaml.safe_load(text)
|
|
return cls.from_spec_dict(spec_dict)
|
|
|
|
@classmethod
|
|
def from_file(cls, path: Union[str, Path]) -> OpenAPISpec:
|
|
"""Get an OpenAPI spec from a file path."""
|
|
path_ = path if isinstance(path, Path) else Path(path)
|
|
if not path_.exists():
|
|
raise FileNotFoundError(f"{path} does not exist")
|
|
with path_.open("r") as f:
|
|
return cls.from_text(f.read())
|
|
|
|
@classmethod
|
|
def from_url(cls, url: str) -> OpenAPISpec:
|
|
"""Get an OpenAPI spec from a URL."""
|
|
response = requests.get(url)
|
|
return cls.from_text(response.text)
|
|
|
|
@property
|
|
def base_url(self) -> str:
|
|
"""Get the base url."""
|
|
return self.servers[0].url
|
|
|
|
def get_methods_for_path(self, path: str) -> List[str]:
|
|
"""Return a list of valid methods for the specified path."""
|
|
from openapi_pydantic import Operation
|
|
|
|
path_item = self._get_path_strict(path)
|
|
results = []
|
|
for method in HTTPVerb:
|
|
operation = getattr(path_item, method.value, None)
|
|
if isinstance(operation, Operation):
|
|
results.append(method.value)
|
|
return results
|
|
|
|
def get_parameters_for_path(self, path: str) -> List[Parameter]:
|
|
from openapi_pydantic import Reference
|
|
|
|
path_item = self._get_path_strict(path)
|
|
parameters = []
|
|
if not path_item.parameters:
|
|
return []
|
|
for parameter in path_item.parameters:
|
|
if isinstance(parameter, Reference):
|
|
parameter = self._get_root_referenced_parameter(parameter)
|
|
parameters.append(parameter)
|
|
return parameters
|
|
|
|
def get_operation(self, path: str, method: str) -> Operation:
|
|
"""Get the operation object for a given path and HTTP method."""
|
|
from openapi_pydantic import Operation
|
|
|
|
path_item = self._get_path_strict(path)
|
|
operation_obj = getattr(path_item, method, None)
|
|
if not isinstance(operation_obj, Operation):
|
|
raise ValueError(f"No {method} method found for {path}")
|
|
return operation_obj
|
|
|
|
def get_parameters_for_operation(self, operation: Operation) -> List[Parameter]:
|
|
"""Get the components for a given operation."""
|
|
from openapi_pydantic import Reference
|
|
|
|
parameters = []
|
|
if operation.parameters:
|
|
for parameter in operation.parameters:
|
|
if isinstance(parameter, Reference):
|
|
parameter = self._get_root_referenced_parameter(parameter)
|
|
parameters.append(parameter)
|
|
return parameters
|
|
|
|
def get_request_body_for_operation(
|
|
self, operation: Operation
|
|
) -> Optional[RequestBody]:
|
|
"""Get the request body for a given operation."""
|
|
from openapi_pydantic import Reference
|
|
|
|
request_body = operation.requestBody
|
|
if isinstance(request_body, Reference):
|
|
request_body = self._get_root_referenced_request_body(request_body)
|
|
return request_body
|
|
|
|
@staticmethod
|
|
def get_cleaned_operation_id(operation: Operation, path: str, method: str) -> str:
|
|
"""Get a cleaned operation id from an operation id."""
|
|
operation_id = operation.operationId
|
|
if operation_id is None:
|
|
# Replace all punctuation of any kind with underscore
|
|
path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/"))
|
|
operation_id = f"{path}_{method}"
|
|
return operation_id.replace("-", "_").replace(".", "_").replace("/", "_")
|