Pydantic v2 support for OpenAPI Specs (#11936)

- **Description:** Adding Pydantic v2 support for OpenAPI Specs 

- **Issue:**
  - OpenAPI spec support was disabled because `openapi-schema-pydantic`
    doesn't support Pydantic v2: #9205
    - This caused errors in `get_openapi_chain`.
  - This may be the cause of #9520.

- **Tag maintainer:** @eyurtsev
- **Twitter handle:** kreneskyp


The root cause is that `openapi-schema-pydantic` has not been updated in
some time. The
[openapi-pydantic](https://github.com/mike-oakley/openapi-pydantic)
project forked and updated it, so this change switches to that fork.
This commit is contained in:
Peter Krenesky 2023-10-19 08:06:11 -07:00 committed by GitHub
parent 4adabd33ac
commit 8425f33363
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 1878 additions and 1815 deletions

View File

@ -22,7 +22,7 @@ from langchain.utilities.openapi import OpenAPISpec
from langchain.utils.input import get_colored_text from langchain.utils.input import get_colored_text
if TYPE_CHECKING: if TYPE_CHECKING:
from openapi_schema_pydantic import Parameter from openapi_pydantic import Parameter
def _get_description(o: Any, prefer_short: bool) -> Optional[str]: def _get_description(o: Any, prefer_short: bool) -> Optional[str]:

View File

@ -15,7 +15,7 @@ from typing import (
Union, Union,
) )
from langchain.pydantic_v1 import _PYDANTIC_MAJOR_VERSION, BaseModel, Field from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools.openapi.utils.openapi_utils import HTTPVerb, OpenAPISpec from langchain.tools.openapi.utils.openapi_utils import HTTPVerb, OpenAPISpec
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -95,16 +95,16 @@ class APIPropertyBase(BaseModel):
"""The description of the property.""" """The description of the property."""
if _PYDANTIC_MAJOR_VERSION == 1: if TYPE_CHECKING:
if TYPE_CHECKING: from openapi_pydantic import (
from openapi_schema_pydantic import (
MediaType, MediaType,
Parameter, Parameter,
RequestBody, RequestBody,
Schema, Schema,
) )
class APIProperty(APIPropertyBase):
class APIProperty(APIPropertyBase):
"""A model for a property in the query, path, header, or cookie params.""" """A model for a property in the query, path, header, or cookie params."""
location: APIPropertyLocation = Field(alias="location") location: APIPropertyLocation = Field(alias="location")
@ -130,7 +130,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
def _get_schema_type_for_array( def _get_schema_type_for_array(
schema: Schema, schema: Schema,
) -> Optional[Union[str, Tuple[str, ...]]]: ) -> Optional[Union[str, Tuple[str, ...]]]:
from openapi_schema_pydantic import ( from openapi_pydantic import (
Reference, Reference,
Schema, Schema,
) )
@ -151,9 +151,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
return schema_type return schema_type
@staticmethod @staticmethod
def _get_schema_type( def _get_schema_type(parameter: Parameter, schema: Optional[Schema]) -> SCHEMA_TYPE:
parameter: Parameter, schema: Optional[Schema]
) -> SCHEMA_TYPE:
if schema is None: if schema is None:
return None return None
schema_type: SCHEMA_TYPE = APIProperty._cast_schema_list_type(schema) schema_type: SCHEMA_TYPE = APIProperty._cast_schema_list_type(schema)
@ -164,9 +162,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
raise NotImplementedError("Objects not yet supported") raise NotImplementedError("Objects not yet supported")
elif schema_type in PRIMITIVE_TYPES: elif schema_type in PRIMITIVE_TYPES:
if schema.enum: if schema.enum:
schema_type = APIProperty._get_schema_type_for_enum( schema_type = APIProperty._get_schema_type_for_enum(parameter, schema)
parameter, schema
)
else: else:
# Directly use the primitive type # Directly use the primitive type
pass pass
@ -192,7 +188,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
@staticmethod @staticmethod
def _get_schema(parameter: Parameter, spec: OpenAPISpec) -> Optional[Schema]: def _get_schema(parameter: Parameter, spec: OpenAPISpec) -> Optional[Schema]:
from openapi_schema_pydantic import ( from openapi_pydantic import (
Reference, Reference,
Schema, Schema,
) )
@ -216,9 +212,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
return False return False
@classmethod @classmethod
def from_parameter( def from_parameter(cls, parameter: Parameter, spec: OpenAPISpec) -> "APIProperty":
cls, parameter: Parameter, spec: OpenAPISpec
) -> "APIProperty":
"""Instantiate from an OpenAPI Parameter.""" """Instantiate from an OpenAPI Parameter."""
location = APIPropertyLocation.from_str(parameter.param_in) location = APIPropertyLocation.from_str(parameter.param_in)
cls._validate_location( cls._validate_location(
@ -238,7 +232,8 @@ if _PYDANTIC_MAJOR_VERSION == 1:
type=schema_type, type=schema_type,
) )
class APIRequestBodyProperty(APIPropertyBase):
class APIRequestBodyProperty(APIPropertyBase):
"""A model for a request body property.""" """A model for a request body property."""
properties: List["APIRequestBodyProperty"] = Field(alias="properties") properties: List["APIRequestBodyProperty"] = Field(alias="properties")
@ -253,7 +248,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
def _process_object_schema( def _process_object_schema(
cls, schema: Schema, spec: OpenAPISpec, references_used: List[str] cls, schema: Schema, spec: OpenAPISpec, references_used: List[str]
) -> Tuple[Union[str, List[str], None], List["APIRequestBodyProperty"]]: ) -> Tuple[Union[str, List[str], None], List["APIRequestBodyProperty"]]:
from openapi_schema_pydantic import ( from openapi_pydantic import (
Reference, Reference,
) )
@ -291,7 +286,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
spec: OpenAPISpec, spec: OpenAPISpec,
references_used: List[str], references_used: List[str],
) -> str: ) -> str:
from openapi_schema_pydantic import Reference, Schema from openapi_pydantic import Reference, Schema
items = schema.items items = schema.items
if items is not None: if items is not None:
@ -338,9 +333,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
schema, spec, references_used schema, spec, references_used
) )
elif schema_type == "array": elif schema_type == "array":
schema_type = cls._process_array_schema( schema_type = cls._process_array_schema(schema, name, spec, references_used)
schema, name, spec, references_used
)
elif schema_type in PRIMITIVE_TYPES: elif schema_type in PRIMITIVE_TYPES:
# Use the primitive type directly # Use the primitive type directly
pass pass
@ -360,8 +353,9 @@ if _PYDANTIC_MAJOR_VERSION == 1:
references_used=references_used, references_used=references_used,
) )
# class APIRequestBodyProperty(APIPropertyBase):
class APIRequestBody(BaseModel): # class APIRequestBodyProperty(APIPropertyBase):
class APIRequestBody(BaseModel):
"""A model for a request body.""" """A model for a request body."""
description: Optional[str] = Field(alias="description") description: Optional[str] = Field(alias="description")
@ -380,7 +374,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
spec: OpenAPISpec, spec: OpenAPISpec,
) -> List[APIRequestBodyProperty]: ) -> List[APIRequestBodyProperty]:
"""Process the media type of the request body.""" """Process the media type of the request body."""
from openapi_schema_pydantic import Reference from openapi_pydantic import Reference
references_used = [] references_used = []
schema = media_type_obj.media_type_schema schema = media_type_obj.media_type_schema
@ -442,9 +436,10 @@ if _PYDANTIC_MAJOR_VERSION == 1:
media_type=media_type, media_type=media_type,
) )
# class APIRequestBodyProperty(APIPropertyBase):
# class APIRequestBody(BaseModel): # class APIRequestBodyProperty(APIPropertyBase):
class APIOperation(BaseModel): # class APIRequestBody(BaseModel):
class APIOperation(BaseModel):
"""A model for a single API operation.""" """A model for a single API operation."""
operation_id: str = Field(alias="operation_id") operation_id: str = Field(alias="operation_id")
@ -600,18 +595,16 @@ if _PYDANTIC_MAJOR_VERSION == 1:
prop_type = self.ts_type_from_python(prop.type) prop_type = self.ts_type_from_python(prop.type)
prop_required = "" if prop.required else "?" prop_required = "" if prop.required else "?"
prop_desc = f"/* {prop.description} */" if prop.description else "" prop_desc = f"/* {prop.description} */" if prop.description else ""
params.append( params.append(f"{prop_desc}\n\t\t{prop_name}{prop_required}: {prop_type},")
f"{prop_desc}\n\t\t{prop_name}{prop_required}: {prop_type},"
)
formatted_params = "\n".join(params).strip() formatted_params = "\n".join(params).strip()
description_str = f"/* {self.description} */" if self.description else "" description_str = f"/* {self.description} */" if self.description else ""
typescript_definition = f""" typescript_definition = f"""
{description_str} {description_str}
type {operation_name} = (_: {{ type {operation_name} = (_: {{
{formatted_params} {formatted_params}
}}) => any; }}) => any;
""" """
return typescript_definition.strip() return typescript_definition.strip()
@property @property
@ -635,21 +628,3 @@ if _PYDANTIC_MAJOR_VERSION == 1:
if self.request_body is None: if self.request_body is None:
return [] return []
return [prop.name for prop in self.request_body.properties] return [prop.name for prop in self.request_body.properties]
else:
class APIProperty(APIPropertyBase): # type: ignore[no-redef]
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("Only supported for pydantic v1")
class APIRequestBodyProperty(APIPropertyBase): # type: ignore[no-redef]
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("Only supported for pydantic v1")
class APIRequestBody(BaseModel): # type: ignore[no-redef]
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("Only supported for pydantic v1")
class APIOperation(BaseModel): # type: ignore[no-redef]
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("Only supported for pydantic v1")

View File

@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union
import requests import requests
import yaml import yaml
from langchain.pydantic_v1 import _PYDANTIC_MAJOR_VERSION, ValidationError from langchain.pydantic_v1 import ValidationError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -38,9 +38,8 @@ class HTTPVerb(str, Enum):
raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}") raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}")
if _PYDANTIC_MAJOR_VERSION == 1: if TYPE_CHECKING:
if TYPE_CHECKING: from openapi_pydantic import (
from openapi_schema_pydantic import (
Components, Components,
Operation, Operation,
Parameter, Parameter,
@ -51,14 +50,17 @@ if _PYDANTIC_MAJOR_VERSION == 1:
Schema, Schema,
) )
try: try:
from openapi_schema_pydantic import OpenAPI from openapi_pydantic import OpenAPI
except ImportError: except ImportError:
OpenAPI = object # type: ignore OpenAPI = object # type: ignore
class OpenAPISpec(OpenAPI):
class OpenAPISpec(OpenAPI):
"""OpenAPI Model that removes mis-formatted parts of the spec.""" """OpenAPI Model that removes mis-formatted parts of the spec."""
openapi: str = "3.1.0" # overriding overly restrictive type from parent class
@property @property
def _paths_strict(self) -> Paths: def _paths_strict(self) -> Paths:
if not self.paths: if not self.paths:
@ -102,9 +104,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
raise ValueError("No request body found in spec. ") raise ValueError("No request body found in spec. ")
return request_bodies return request_bodies
def _get_referenced_parameter( def _get_referenced_parameter(self, ref: Reference) -> Union[Parameter, Reference]:
self, ref: Reference
) -> Union[Parameter, Reference]:
"""Get a parameter (or nested reference) or err.""" """Get a parameter (or nested reference) or err."""
ref_name = ref.ref.split("/")[-1] ref_name = ref.ref.split("/")[-1]
parameters = self._parameters_strict parameters = self._parameters_strict
@ -114,7 +114,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
def _get_root_referenced_parameter(self, ref: Reference) -> Parameter: def _get_root_referenced_parameter(self, ref: Reference) -> Parameter:
"""Get the root reference or err.""" """Get the root reference or err."""
from openapi_schema_pydantic import Reference from openapi_pydantic import Reference
parameter = self._get_referenced_parameter(ref) parameter = self._get_referenced_parameter(ref)
while isinstance(parameter, Reference): while isinstance(parameter, Reference):
@ -130,7 +130,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
return schemas[ref_name] return schemas[ref_name]
def get_schema(self, schema: Union[Reference, Schema]) -> Schema: def get_schema(self, schema: Union[Reference, Schema]) -> Schema:
from openapi_schema_pydantic import Reference from openapi_pydantic import Reference
if isinstance(schema, Reference): if isinstance(schema, Reference):
return self.get_referenced_schema(schema) return self.get_referenced_schema(schema)
@ -138,7 +138,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
def _get_root_referenced_schema(self, ref: Reference) -> Schema: def _get_root_referenced_schema(self, ref: Reference) -> Schema:
"""Get the root reference or err.""" """Get the root reference or err."""
from openapi_schema_pydantic import Reference from openapi_pydantic import Reference
schema = self.get_referenced_schema(ref) schema = self.get_referenced_schema(ref)
while isinstance(schema, Reference): while isinstance(schema, Reference):
@ -159,7 +159,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
self, ref: Reference self, ref: Reference
) -> Optional[RequestBody]: ) -> Optional[RequestBody]:
"""Get the root request Body or err.""" """Get the root request Body or err."""
from openapi_schema_pydantic import Reference from openapi_pydantic import Reference
request_body = self._get_referenced_request_body(ref) request_body = self._get_referenced_request_body(ref)
while isinstance(request_body, Reference): while isinstance(request_body, Reference):
@ -248,7 +248,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
def get_methods_for_path(self, path: str) -> List[str]: def get_methods_for_path(self, path: str) -> List[str]:
"""Return a list of valid methods for the specified path.""" """Return a list of valid methods for the specified path."""
from openapi_schema_pydantic import Operation from openapi_pydantic import Operation
path_item = self._get_path_strict(path) path_item = self._get_path_strict(path)
results = [] results = []
@ -259,7 +259,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
return results return results
def get_parameters_for_path(self, path: str) -> List[Parameter]: def get_parameters_for_path(self, path: str) -> List[Parameter]:
from openapi_schema_pydantic import Reference from openapi_pydantic import Reference
path_item = self._get_path_strict(path) path_item = self._get_path_strict(path)
parameters = [] parameters = []
@ -273,7 +273,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
def get_operation(self, path: str, method: str) -> Operation: def get_operation(self, path: str, method: str) -> Operation:
"""Get the operation object for a given path and HTTP method.""" """Get the operation object for a given path and HTTP method."""
from openapi_schema_pydantic import Operation from openapi_pydantic import Operation
path_item = self._get_path_strict(path) path_item = self._get_path_strict(path)
operation_obj = getattr(path_item, method, None) operation_obj = getattr(path_item, method, None)
@ -283,7 +283,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
def get_parameters_for_operation(self, operation: Operation) -> List[Parameter]: def get_parameters_for_operation(self, operation: Operation) -> List[Parameter]:
"""Get the components for a given operation.""" """Get the components for a given operation."""
from openapi_schema_pydantic import Reference from openapi_pydantic import Reference
parameters = [] parameters = []
if operation.parameters: if operation.parameters:
@ -297,7 +297,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
self, operation: Operation self, operation: Operation
) -> Optional[RequestBody]: ) -> Optional[RequestBody]:
"""Get the request body for a given operation.""" """Get the request body for a given operation."""
from openapi_schema_pydantic import Reference from openapi_pydantic import Reference
request_body = operation.requestBody request_body = operation.requestBody
if isinstance(request_body, Reference): if isinstance(request_body, Reference):
@ -305,9 +305,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
return request_body return request_body
@staticmethod @staticmethod
def get_cleaned_operation_id( def get_cleaned_operation_id(operation: Operation, path: str, method: str) -> str:
operation: Operation, path: str, method: str
) -> str:
"""Get a cleaned operation id from an operation id.""" """Get a cleaned operation id from an operation id."""
operation_id = operation.operationId operation_id = operation.operationId
if operation_id is None: if operation_id is None:
@ -315,11 +313,3 @@ if _PYDANTIC_MAJOR_VERSION == 1:
path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/")) path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/"))
operation_id = f"{path}_{method}" operation_id = f"{path}_{method}"
return operation_id.replace("-", "_").replace(".", "_").replace("/", "_") return operation_id.replace("-", "_").replace(".", "_").replace("/", "_")
else:
class OpenAPISpec: # type: ignore[no-redef]
"""Shim for pydantic version >=2"""
def __init__(self) -> None:
raise NotImplementedError("Only supported for pydantic version 1")

File diff suppressed because it is too large Load Diff

View File

@ -20,7 +20,7 @@ PyYAML = ">=5.3"
numpy = "^1" numpy = "^1"
azure-core = {version = "^1.26.4", optional=true} azure-core = {version = "^1.26.4", optional=true}
tqdm = {version = ">=4.48.0", optional = true} tqdm = {version = ">=4.48.0", optional = true}
openapi-schema-pydantic = {version = "^1.2", optional = true} openapi-pydantic = {version = "^0.3.2", optional = true}
faiss-cpu = {version = "^1", optional = true} faiss-cpu = {version = "^1", optional = true}
wikipedia = {version = "^1", optional = true} wikipedia = {version = "^1", optional = true}
elasticsearch = {version = "^8", optional = true} elasticsearch = {version = "^8", optional = true}
@ -359,7 +359,7 @@ extended_testing = [
"xata", "xata",
"xmltodict", "xmltodict",
"faiss-cpu", "faiss-cpu",
"openapi-schema-pydantic", "openapi-pydantic",
"markdownify", "markdownify",
"arxiv", "arxiv",
"dashvector", "dashvector",

View File

@ -7,7 +7,7 @@ from typing import Iterable, List, Tuple
import pytest import pytest
# Keep at top of file to ensure that pydantic test can be skipped before # Keep at top of file to ensure that pydantic test can be skipped before
# pydantic v1 related imports are attempted by openapi_schema_pydantic. # pydantic v1 related imports are attempted by openapi_pydantic.
from langchain.pydantic_v1 import _PYDANTIC_MAJOR_VERSION from langchain.pydantic_v1 import _PYDANTIC_MAJOR_VERSION
if _PYDANTIC_MAJOR_VERSION != 1: if _PYDANTIC_MAJOR_VERSION != 1:
@ -78,7 +78,7 @@ def http_paths_and_methods() -> List[Tuple[str, OpenAPISpec, str, str]]:
return http_paths_and_methods return http_paths_and_methods
@pytest.mark.requires("openapi_schema_pydantic") @pytest.mark.requires("openapi_pydantic")
def test_parse_api_operations() -> None: def test_parse_api_operations() -> None:
"""Test the APIOperation class.""" """Test the APIOperation class."""
for spec_name, spec, path, method in http_paths_and_methods(): for spec_name, spec, path, method in http_paths_and_methods():
@ -88,21 +88,21 @@ def test_parse_api_operations() -> None:
raise AssertionError(f"Error processing {spec_name}: {e} ") from e raise AssertionError(f"Error processing {spec_name}: {e} ") from e
@pytest.mark.requires("openapi_schema_pydantic") @pytest.mark.requires("openapi_pydantic")
@pytest.fixture @pytest.fixture
def raw_spec() -> OpenAPISpec: def raw_spec() -> OpenAPISpec:
"""Return a raw OpenAPI spec.""" """Return a raw OpenAPI spec."""
from openapi_schema_pydantic import Info from openapi_pydantic import Info
return OpenAPISpec( return OpenAPISpec(
info=Info(title="Test API", version="1.0.0"), info=Info(title="Test API", version="1.0.0"),
) )
@pytest.mark.requires("openapi_schema_pydantic") @pytest.mark.requires("openapi_pydantic")
def test_api_request_body_from_request_body_with_ref(raw_spec: OpenAPISpec) -> None: def test_api_request_body_from_request_body_with_ref(raw_spec: OpenAPISpec) -> None:
"""Test instantiating APIRequestBody from RequestBody with a reference.""" """Test instantiating APIRequestBody from RequestBody with a reference."""
from openapi_schema_pydantic import ( from openapi_pydantic import (
Components, Components,
MediaType, MediaType,
Reference, Reference,
@ -140,10 +140,10 @@ def test_api_request_body_from_request_body_with_ref(raw_spec: OpenAPISpec) -> N
assert api_request_body.media_type == "application/json" assert api_request_body.media_type == "application/json"
@pytest.mark.requires("openapi_schema_pydantic") @pytest.mark.requires("openapi_pydantic")
def test_api_request_body_from_request_body_with_schema(raw_spec: OpenAPISpec) -> None: def test_api_request_body_from_request_body_with_schema(raw_spec: OpenAPISpec) -> None:
"""Test instantiating APIRequestBody from RequestBody with a schema.""" """Test instantiating APIRequestBody from RequestBody with a schema."""
from openapi_schema_pydantic import ( from openapi_pydantic import (
MediaType, MediaType,
RequestBody, RequestBody,
Schema, Schema,
@ -171,9 +171,9 @@ def test_api_request_body_from_request_body_with_schema(raw_spec: OpenAPISpec) -
assert api_request_body.media_type == "application/json" assert api_request_body.media_type == "application/json"
@pytest.mark.requires("openapi_schema_pydantic") @pytest.mark.requires("openapi_pydantic")
def test_api_request_body_property_from_schema(raw_spec: OpenAPISpec) -> None: def test_api_request_body_property_from_schema(raw_spec: OpenAPISpec) -> None:
from openapi_schema_pydantic import ( from openapi_pydantic import (
Components, Components,
Reference, Reference,
Schema, Schema,