Wrap OpenAPI features in conditionals for pydantic v2 compatibility (#9205)

Wrap OpenAPI in conditionals for pydantic v2 compatibility.
This commit is contained in:
Eugene Yurtsev 2023-08-14 13:40:58 -04:00 committed by GitHub
parent 89be10f6b4
commit 4f1feaca83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 754 additions and 673 deletions

View File

@ -27,6 +27,12 @@ if "pydantic_v1" not in sys.modules:
# and may run prior to langchain core package. # and may run prior to langchain core package.
sys.modules["pydantic_v1"] = pydantic_v1 sys.modules["pydantic_v1"] = pydantic_v1
try:
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
except metadata.PackageNotFoundError:
_PYDANTIC_MAJOR_VERSION = 0
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.cache import BaseCache from langchain.cache import BaseCache
from langchain.chains import ( from langchain.chains import (

File diff suppressed because it is too large Load Diff

View File

@ -1,2 +1,4 @@
"""Utility functions for parsing an OpenAPI spec. Kept for backwards compat.""" """Utility functions for parsing an OpenAPI spec. Kept for backwards compat."""
from langchain.utilities.openapi import HTTPVerb, OpenAPISpec # noqa: F401 from langchain.utilities.openapi import HTTPVerb, OpenAPISpec
__all__ = ["HTTPVerb", "OpenAPISpec"]

View File

@ -2,9 +2,8 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import requests import requests
from pydantic_v1 import Extra, root_validator from pydantic_v1 import BaseModel, Extra, root_validator
from langchain.tools.base import BaseModel
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env

View File

@ -1,4 +1,6 @@
"""Utility functions for parsing an OpenAPI spec.""" """Utility functions for parsing an OpenAPI spec."""
from __future__ import annotations
import copy import copy
import json import json
import logging import logging
@ -9,19 +11,10 @@ from typing import Dict, List, Optional, Union
import requests import requests
import yaml import yaml
from openapi_schema_pydantic import (
Components,
OpenAPI,
Operation,
Parameter,
PathItem,
Paths,
Reference,
RequestBody,
Schema,
)
from pydantic_v1 import ValidationError from pydantic_v1 import ValidationError
from langchain import _PYDANTIC_MAJOR_VERSION
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -38,7 +31,7 @@ class HTTPVerb(str, Enum):
TRACE = "trace" TRACE = "trace"
@classmethod @classmethod
def from_str(cls, verb: str) -> "HTTPVerb": def from_str(cls, verb: str) -> HTTPVerb:
"""Parse an HTTP verb.""" """Parse an HTTP verb."""
try: try:
return cls(verb) return cls(verb)
@ -46,240 +39,265 @@ class HTTPVerb(str, Enum):
raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}") raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}")
class OpenAPISpec(OpenAPI): if _PYDANTIC_MAJOR_VERSION == 1:
"""OpenAPI Model that removes misformatted parts of the spec.""" from openapi_schema_pydantic import (
Components,
OpenAPI,
Operation,
Parameter,
PathItem,
Paths,
Reference,
RequestBody,
Schema,
)
@property class OpenAPISpec(OpenAPI):
def _paths_strict(self) -> Paths: """OpenAPI Model that removes mis-formatted parts of the spec."""
if not self.paths:
raise ValueError("No paths found in spec")
return self.paths
def _get_path_strict(self, path: str) -> PathItem: @property
path_item = self._paths_strict.get(path) def _paths_strict(self) -> Paths:
if not path_item: if not self.paths:
raise ValueError(f"No path found for {path}") raise ValueError("No paths found in spec")
return path_item return self.paths
@property def _get_path_strict(self, path: str) -> PathItem:
def _components_strict(self) -> Components: path_item = self._paths_strict.get(path)
"""Get components or err.""" if not path_item:
if self.components is None: raise ValueError(f"No path found for {path}")
raise ValueError("No components found in spec. ") return path_item
return self.components
@property @property
def _parameters_strict(self) -> Dict[str, Union[Parameter, Reference]]: def _components_strict(self) -> Components:
"""Get parameters or err.""" """Get components or err."""
parameters = self._components_strict.parameters if self.components is None:
if parameters is None: raise ValueError("No components found in spec. ")
raise ValueError("No parameters found in spec. ") return self.components
return parameters
@property @property
def _schemas_strict(self) -> Dict[str, Schema]: def _parameters_strict(self) -> Dict[str, Union[Parameter, Reference]]:
"""Get the dictionary of schemas or err.""" """Get parameters or err."""
schemas = self._components_strict.schemas parameters = self._components_strict.parameters
if schemas is None: if parameters is None:
raise ValueError("No schemas found in spec. ") raise ValueError("No parameters found in spec. ")
return schemas return parameters
@property @property
def _request_bodies_strict(self) -> Dict[str, Union[RequestBody, Reference]]: def _schemas_strict(self) -> Dict[str, Schema]:
"""Get the request body or err.""" """Get the dictionary of schemas or err."""
request_bodies = self._components_strict.requestBodies schemas = self._components_strict.schemas
if request_bodies is None: if schemas is None:
raise ValueError("No request body found in spec. ") raise ValueError("No schemas found in spec. ")
return request_bodies return schemas
def _get_referenced_parameter(self, ref: Reference) -> Union[Parameter, Reference]: @property
"""Get a parameter (or nested reference) or err.""" def _request_bodies_strict(self) -> Dict[str, Union[RequestBody, Reference]]:
ref_name = ref.ref.split("/")[-1] """Get the request body or err."""
parameters = self._parameters_strict request_bodies = self._components_strict.requestBodies
if ref_name not in parameters: if request_bodies is None:
raise ValueError(f"No parameter found for {ref_name}") raise ValueError("No request body found in spec. ")
return parameters[ref_name] return request_bodies
def _get_root_referenced_parameter(self, ref: Reference) -> Parameter: def _get_referenced_parameter(
"""Get the root reference or err.""" self, ref: Reference
parameter = self._get_referenced_parameter(ref) ) -> Union[Parameter, Reference]:
while isinstance(parameter, Reference): """Get a parameter (or nested reference) or err."""
parameter = self._get_referenced_parameter(parameter) ref_name = ref.ref.split("/")[-1]
return parameter parameters = self._parameters_strict
if ref_name not in parameters:
raise ValueError(f"No parameter found for {ref_name}")
return parameters[ref_name]
def get_referenced_schema(self, ref: Reference) -> Schema: def _get_root_referenced_parameter(self, ref: Reference) -> Parameter:
"""Get a schema (or nested reference) or err.""" """Get the root reference or err."""
ref_name = ref.ref.split("/")[-1] parameter = self._get_referenced_parameter(ref)
schemas = self._schemas_strict while isinstance(parameter, Reference):
if ref_name not in schemas: parameter = self._get_referenced_parameter(parameter)
raise ValueError(f"No schema found for {ref_name}") return parameter
return schemas[ref_name]
def get_schema(self, schema: Union[Reference, Schema]) -> Schema: def get_referenced_schema(self, ref: Reference) -> Schema:
if isinstance(schema, Reference): """Get a schema (or nested reference) or err."""
return self.get_referenced_schema(schema) ref_name = ref.ref.split("/")[-1]
return schema schemas = self._schemas_strict
if ref_name not in schemas:
raise ValueError(f"No schema found for {ref_name}")
return schemas[ref_name]
def _get_root_referenced_schema(self, ref: Reference) -> Schema: def get_schema(self, schema: Union[Reference, Schema]) -> Schema:
"""Get the root reference or err.""" if isinstance(schema, Reference):
schema = self.get_referenced_schema(ref) return self.get_referenced_schema(schema)
while isinstance(schema, Reference): return schema
schema = self.get_referenced_schema(schema)
return schema
def _get_referenced_request_body( def _get_root_referenced_schema(self, ref: Reference) -> Schema:
self, ref: Reference """Get the root reference or err."""
) -> Optional[Union[Reference, RequestBody]]: schema = self.get_referenced_schema(ref)
"""Get a request body (or nested reference) or err.""" while isinstance(schema, Reference):
ref_name = ref.ref.split("/")[-1] schema = self.get_referenced_schema(schema)
request_bodies = self._request_bodies_strict return schema
if ref_name not in request_bodies:
raise ValueError(f"No request body found for {ref_name}")
return request_bodies[ref_name]
def _get_root_referenced_request_body( def _get_referenced_request_body(
self, ref: Reference self, ref: Reference
) -> Optional[RequestBody]: ) -> Optional[Union[Reference, RequestBody]]:
"""Get the root request Body or err.""" """Get a request body (or nested reference) or err."""
request_body = self._get_referenced_request_body(ref) ref_name = ref.ref.split("/")[-1]
while isinstance(request_body, Reference): request_bodies = self._request_bodies_strict
request_body = self._get_referenced_request_body(request_body) if ref_name not in request_bodies:
return request_body raise ValueError(f"No request body found for {ref_name}")
return request_bodies[ref_name]
@staticmethod def _get_root_referenced_request_body(
def _alert_unsupported_spec(obj: dict) -> None: self, ref: Reference
"""Alert if the spec is not supported.""" ) -> Optional[RequestBody]:
warning_message = ( """Get the root request Body or err."""
" This may result in degraded performance." request_body = self._get_referenced_request_body(ref)
+ " Convert your OpenAPI spec to 3.1.* spec" while isinstance(request_body, Reference):
+ " for better support." request_body = self._get_referenced_request_body(request_body)
) return request_body
swagger_version = obj.get("swagger")
openapi_version = obj.get("openapi") @staticmethod
if isinstance(openapi_version, str): def _alert_unsupported_spec(obj: dict) -> None:
if openapi_version != "3.1.0": """Alert if the spec is not supported."""
warning_message = (
" This may result in degraded performance."
+ " Convert your OpenAPI spec to 3.1.* spec"
+ " for better support."
)
swagger_version = obj.get("swagger")
openapi_version = obj.get("openapi")
if isinstance(openapi_version, str):
if openapi_version != "3.1.0":
logger.warning(
f"Attempting to load an OpenAPI {openapi_version}"
f" spec. {warning_message}"
)
else:
pass
elif isinstance(swagger_version, str):
logger.warning( logger.warning(
f"Attempting to load an OpenAPI {openapi_version}" f"Attempting to load a Swagger {swagger_version}"
f" spec. {warning_message}" f" spec. {warning_message}"
) )
else: else:
pass raise ValueError(
elif isinstance(swagger_version, str): "Attempting to load an unsupported spec:"
logger.warning( f"\n\n{obj}\n{warning_message}"
f"Attempting to load a Swagger {swagger_version}" )
f" spec. {warning_message}"
)
else:
raise ValueError(
"Attempting to load an unsupported spec:"
f"\n\n{obj}\n{warning_message}"
)
@classmethod @classmethod
def parse_obj(cls, obj: dict) -> "OpenAPISpec": def parse_obj(cls, obj: dict) -> OpenAPISpec:
try: try:
cls._alert_unsupported_spec(obj) cls._alert_unsupported_spec(obj)
return super().parse_obj(obj) return super().parse_obj(obj)
except ValidationError as e: except ValidationError as e:
# We are handling possibly misconfigured specs and want to do a best-effort # We are handling possibly misconfigured specs and
# job to get a reasonable interface out of it. # want to do a best-effort job to get a reasonable interface out of it.
new_obj = copy.deepcopy(obj) new_obj = copy.deepcopy(obj)
for error in e.errors(): for error in e.errors():
keys = error["loc"] keys = error["loc"]
item = new_obj item = new_obj
for key in keys[:-1]: for key in keys[:-1]:
item = item[key] item = item[key]
item.pop(keys[-1], None) item.pop(keys[-1], None)
return cls.parse_obj(new_obj) return cls.parse_obj(new_obj)
@classmethod @classmethod
def from_spec_dict(cls, spec_dict: dict) -> "OpenAPISpec": def from_spec_dict(cls, spec_dict: dict) -> OpenAPISpec:
"""Get an OpenAPI spec from a dict.""" """Get an OpenAPI spec from a dict."""
return cls.parse_obj(spec_dict) return cls.parse_obj(spec_dict)
@classmethod @classmethod
def from_text(cls, text: str) -> "OpenAPISpec": def from_text(cls, text: str) -> OpenAPISpec:
"""Get an OpenAPI spec from a text.""" """Get an OpenAPI spec from a text."""
try: try:
spec_dict = json.loads(text) spec_dict = json.loads(text)
except json.JSONDecodeError: except json.JSONDecodeError:
spec_dict = yaml.safe_load(text) spec_dict = yaml.safe_load(text)
return cls.from_spec_dict(spec_dict) return cls.from_spec_dict(spec_dict)
@classmethod @classmethod
def from_file(cls, path: Union[str, Path]) -> "OpenAPISpec": def from_file(cls, path: Union[str, Path]) -> OpenAPISpec:
"""Get an OpenAPI spec from a file path.""" """Get an OpenAPI spec from a file path."""
path_ = path if isinstance(path, Path) else Path(path) path_ = path if isinstance(path, Path) else Path(path)
if not path_.exists(): if not path_.exists():
raise FileNotFoundError(f"{path} does not exist") raise FileNotFoundError(f"{path} does not exist")
with path_.open("r") as f: with path_.open("r") as f:
return cls.from_text(f.read()) return cls.from_text(f.read())
@classmethod @classmethod
def from_url(cls, url: str) -> "OpenAPISpec": def from_url(cls, url: str) -> OpenAPISpec:
"""Get an OpenAPI spec from a URL.""" """Get an OpenAPI spec from a URL."""
response = requests.get(url) response = requests.get(url)
return cls.from_text(response.text) return cls.from_text(response.text)
@property @property
def base_url(self) -> str: def base_url(self) -> str:
"""Get the base url.""" """Get the base url."""
return self.servers[0].url return self.servers[0].url
def get_methods_for_path(self, path: str) -> List[str]: def get_methods_for_path(self, path: str) -> List[str]:
"""Return a list of valid methods for the specified path.""" """Return a list of valid methods for the specified path."""
path_item = self._get_path_strict(path) path_item = self._get_path_strict(path)
results = [] results = []
for method in HTTPVerb: for method in HTTPVerb:
operation = getattr(path_item, method.value, None) operation = getattr(path_item, method.value, None)
if isinstance(operation, Operation): if isinstance(operation, Operation):
results.append(method.value) results.append(method.value)
return results return results
def get_parameters_for_path(self, path: str) -> List[Parameter]: def get_parameters_for_path(self, path: str) -> List[Parameter]:
path_item = self._get_path_strict(path) path_item = self._get_path_strict(path)
parameters = [] parameters = []
if not path_item.parameters: if not path_item.parameters:
return [] return []
for parameter in path_item.parameters: for parameter in path_item.parameters:
if isinstance(parameter, Reference):
parameter = self._get_root_referenced_parameter(parameter)
parameters.append(parameter)
return parameters
def get_operation(self, path: str, method: str) -> Operation:
"""Get the operation object for a given path and HTTP method."""
path_item = self._get_path_strict(path)
operation_obj = getattr(path_item, method, None)
if not isinstance(operation_obj, Operation):
raise ValueError(f"No {method} method found for {path}")
return operation_obj
def get_parameters_for_operation(self, operation: Operation) -> List[Parameter]:
"""Get the components for a given operation."""
parameters = []
if operation.parameters:
for parameter in operation.parameters:
if isinstance(parameter, Reference): if isinstance(parameter, Reference):
parameter = self._get_root_referenced_parameter(parameter) parameter = self._get_root_referenced_parameter(parameter)
parameters.append(parameter) parameters.append(parameter)
return parameters return parameters
def get_request_body_for_operation( def get_operation(self, path: str, method: str) -> Operation:
self, operation: Operation """Get the operation object for a given path and HTTP method."""
) -> Optional[RequestBody]: path_item = self._get_path_strict(path)
"""Get the request body for a given operation.""" operation_obj = getattr(path_item, method, None)
request_body = operation.requestBody if not isinstance(operation_obj, Operation):
if isinstance(request_body, Reference): raise ValueError(f"No {method} method found for {path}")
request_body = self._get_root_referenced_request_body(request_body) return operation_obj
return request_body
@staticmethod def get_parameters_for_operation(self, operation: Operation) -> List[Parameter]:
def get_cleaned_operation_id(operation: Operation, path: str, method: str) -> str: """Get the components for a given operation."""
"""Get a cleaned operation id from an operation id.""" parameters = []
operation_id = operation.operationId if operation.parameters:
if operation_id is None: for parameter in operation.parameters:
# Replace all punctuation of any kind with underscore if isinstance(parameter, Reference):
path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/")) parameter = self._get_root_referenced_parameter(parameter)
operation_id = f"{path}_{method}" parameters.append(parameter)
return operation_id.replace("-", "_").replace(".", "_").replace("/", "_") return parameters
def get_request_body_for_operation(
self, operation: Operation
) -> Optional[RequestBody]:
"""Get the request body for a given operation."""
request_body = operation.requestBody
if isinstance(request_body, Reference):
request_body = self._get_root_referenced_request_body(request_body)
return request_body
@staticmethod
def get_cleaned_operation_id(
operation: Operation, path: str, method: str
) -> str:
"""Get a cleaned operation id from an operation id."""
operation_id = operation.operationId
if operation_id is None:
# Replace all punctuation of any kind with underscore
path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/"))
operation_id = f"{path}_{method}"
return operation_id.replace("-", "_").replace(".", "_").replace("/", "_")
else:
class OpenAPISpec: # type: ignore[no-redef]
"""Shim for pydantic version >=2"""
def __init__(self) -> None:
raise NotImplementedError("Only supported for pydantic version 1")

View File

@ -4,6 +4,18 @@ import os
from pathlib import Path from pathlib import Path
from typing import Iterable, List, Tuple from typing import Iterable, List, Tuple
import pytest
# Keep at top of file to ensure that pydantic test can be skipped before
# pydantic v1 related imports are attempted by openapi_schema_pydantic.
from langchain import _PYDANTIC_MAJOR_VERSION
if _PYDANTIC_MAJOR_VERSION != 1:
pytest.skip(
f"Pydantic major version {_PYDANTIC_MAJOR_VERSION} is not supported.",
allow_module_level=True,
)
import pytest import pytest
import yaml import yaml
from openapi_schema_pydantic import ( from openapi_schema_pydantic import (