Pydantic v2 support for OpenAPI Specs (#11936)

- **Description:** Adding Pydantic v2 support for OpenAPI Specs 

- **Issue:**
- OpenAPI spec support was disabled because `openapi-schema-pydantic`
doesn't support Pydantic v2:
     #9205
     
     - Caused errors in `get_openapi_chain`
   
    - This may be the cause of #9520.

- **Tag maintainer:** @eyurtsev
- **Twitter handle:** kreneskyp


The root cause was that `openapi-schema-pydantic` hasn't been updated in
some time but
[openapi-pydantic](https://github.com/mike-oakley/openapi-pydantic)
forked and updated the project.
This commit is contained in:
Peter Krenesky 2023-10-19 08:06:11 -07:00 committed by GitHub
parent 4adabd33ac
commit 8425f33363
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 1878 additions and 1815 deletions

View File

@ -22,7 +22,7 @@ from langchain.utilities.openapi import OpenAPISpec
from langchain.utils.input import get_colored_text
if TYPE_CHECKING:
from openapi_schema_pydantic import Parameter
from openapi_pydantic import Parameter
def _get_description(o: Any, prefer_short: bool) -> Optional[str]:

View File

@ -15,7 +15,7 @@ from typing import (
Union,
)
from langchain.pydantic_v1 import _PYDANTIC_MAJOR_VERSION, BaseModel, Field
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools.openapi.utils.openapi_utils import HTTPVerb, OpenAPISpec
logger = logging.getLogger(__name__)
@ -95,16 +95,16 @@ class APIPropertyBase(BaseModel):
"""The description of the property."""
if _PYDANTIC_MAJOR_VERSION == 1:
if TYPE_CHECKING:
from openapi_schema_pydantic import (
if TYPE_CHECKING:
from openapi_pydantic import (
MediaType,
Parameter,
RequestBody,
Schema,
)
class APIProperty(APIPropertyBase):
class APIProperty(APIPropertyBase):
"""A model for a property in the query, path, header, or cookie params."""
location: APIPropertyLocation = Field(alias="location")
@ -130,7 +130,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
def _get_schema_type_for_array(
schema: Schema,
) -> Optional[Union[str, Tuple[str, ...]]]:
from openapi_schema_pydantic import (
from openapi_pydantic import (
Reference,
Schema,
)
@ -151,9 +151,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
return schema_type
@staticmethod
def _get_schema_type(
parameter: Parameter, schema: Optional[Schema]
) -> SCHEMA_TYPE:
def _get_schema_type(parameter: Parameter, schema: Optional[Schema]) -> SCHEMA_TYPE:
if schema is None:
return None
schema_type: SCHEMA_TYPE = APIProperty._cast_schema_list_type(schema)
@ -164,9 +162,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
raise NotImplementedError("Objects not yet supported")
elif schema_type in PRIMITIVE_TYPES:
if schema.enum:
schema_type = APIProperty._get_schema_type_for_enum(
parameter, schema
)
schema_type = APIProperty._get_schema_type_for_enum(parameter, schema)
else:
# Directly use the primitive type
pass
@ -192,7 +188,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
@staticmethod
def _get_schema(parameter: Parameter, spec: OpenAPISpec) -> Optional[Schema]:
from openapi_schema_pydantic import (
from openapi_pydantic import (
Reference,
Schema,
)
@ -216,9 +212,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
return False
@classmethod
def from_parameter(
cls, parameter: Parameter, spec: OpenAPISpec
) -> "APIProperty":
def from_parameter(cls, parameter: Parameter, spec: OpenAPISpec) -> "APIProperty":
"""Instantiate from an OpenAPI Parameter."""
location = APIPropertyLocation.from_str(parameter.param_in)
cls._validate_location(
@ -238,7 +232,8 @@ if _PYDANTIC_MAJOR_VERSION == 1:
type=schema_type,
)
class APIRequestBodyProperty(APIPropertyBase):
class APIRequestBodyProperty(APIPropertyBase):
"""A model for a request body property."""
properties: List["APIRequestBodyProperty"] = Field(alias="properties")
@ -253,7 +248,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
def _process_object_schema(
cls, schema: Schema, spec: OpenAPISpec, references_used: List[str]
) -> Tuple[Union[str, List[str], None], List["APIRequestBodyProperty"]]:
from openapi_schema_pydantic import (
from openapi_pydantic import (
Reference,
)
@ -291,7 +286,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
spec: OpenAPISpec,
references_used: List[str],
) -> str:
from openapi_schema_pydantic import Reference, Schema
from openapi_pydantic import Reference, Schema
items = schema.items
if items is not None:
@ -338,9 +333,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
schema, spec, references_used
)
elif schema_type == "array":
schema_type = cls._process_array_schema(
schema, name, spec, references_used
)
schema_type = cls._process_array_schema(schema, name, spec, references_used)
elif schema_type in PRIMITIVE_TYPES:
# Use the primitive type directly
pass
@ -360,8 +353,9 @@ if _PYDANTIC_MAJOR_VERSION == 1:
references_used=references_used,
)
# class APIRequestBodyProperty(APIPropertyBase):
class APIRequestBody(BaseModel):
# class APIRequestBodyProperty(APIPropertyBase):
class APIRequestBody(BaseModel):
"""A model for a request body."""
description: Optional[str] = Field(alias="description")
@ -380,7 +374,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
spec: OpenAPISpec,
) -> List[APIRequestBodyProperty]:
"""Process the media type of the request body."""
from openapi_schema_pydantic import Reference
from openapi_pydantic import Reference
references_used = []
schema = media_type_obj.media_type_schema
@ -442,9 +436,10 @@ if _PYDANTIC_MAJOR_VERSION == 1:
media_type=media_type,
)
# class APIRequestBodyProperty(APIPropertyBase):
# class APIRequestBody(BaseModel):
class APIOperation(BaseModel):
# class APIRequestBodyProperty(APIPropertyBase):
# class APIRequestBody(BaseModel):
class APIOperation(BaseModel):
"""A model for a single API operation."""
operation_id: str = Field(alias="operation_id")
@ -600,18 +595,16 @@ if _PYDANTIC_MAJOR_VERSION == 1:
prop_type = self.ts_type_from_python(prop.type)
prop_required = "" if prop.required else "?"
prop_desc = f"/* {prop.description} */" if prop.description else ""
params.append(
f"{prop_desc}\n\t\t{prop_name}{prop_required}: {prop_type},"
)
params.append(f"{prop_desc}\n\t\t{prop_name}{prop_required}: {prop_type},")
formatted_params = "\n".join(params).strip()
description_str = f"/* {self.description} */" if self.description else ""
typescript_definition = f"""
{description_str}
type {operation_name} = (_: {{
{formatted_params}
}}) => any;
"""
{description_str}
type {operation_name} = (_: {{
{formatted_params}
}}) => any;
"""
return typescript_definition.strip()
@property
@ -635,21 +628,3 @@ if _PYDANTIC_MAJOR_VERSION == 1:
if self.request_body is None:
return []
return [prop.name for prop in self.request_body.properties]
else:
class APIProperty(APIPropertyBase): # type: ignore[no-redef]
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("Only supported for pydantic v1")
class APIRequestBodyProperty(APIPropertyBase): # type: ignore[no-redef]
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("Only supported for pydantic v1")
class APIRequestBody(BaseModel): # type: ignore[no-redef]
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("Only supported for pydantic v1")
class APIOperation(BaseModel): # type: ignore[no-redef]
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("Only supported for pydantic v1")

View File

@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union
import requests
import yaml
from langchain.pydantic_v1 import _PYDANTIC_MAJOR_VERSION, ValidationError
from langchain.pydantic_v1 import ValidationError
logger = logging.getLogger(__name__)
@ -38,9 +38,8 @@ class HTTPVerb(str, Enum):
raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}")
if _PYDANTIC_MAJOR_VERSION == 1:
if TYPE_CHECKING:
from openapi_schema_pydantic import (
if TYPE_CHECKING:
from openapi_pydantic import (
Components,
Operation,
Parameter,
@ -51,14 +50,17 @@ if _PYDANTIC_MAJOR_VERSION == 1:
Schema,
)
try:
from openapi_schema_pydantic import OpenAPI
except ImportError:
try:
from openapi_pydantic import OpenAPI
except ImportError:
OpenAPI = object # type: ignore
class OpenAPISpec(OpenAPI):
class OpenAPISpec(OpenAPI):
"""OpenAPI Model that removes mis-formatted parts of the spec."""
openapi: str = "3.1.0" # overriding overly restrictive type from parent class
@property
def _paths_strict(self) -> Paths:
if not self.paths:
@ -102,9 +104,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
raise ValueError("No request body found in spec. ")
return request_bodies
def _get_referenced_parameter(
self, ref: Reference
) -> Union[Parameter, Reference]:
def _get_referenced_parameter(self, ref: Reference) -> Union[Parameter, Reference]:
"""Get a parameter (or nested reference) or err."""
ref_name = ref.ref.split("/")[-1]
parameters = self._parameters_strict
@ -114,7 +114,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
def _get_root_referenced_parameter(self, ref: Reference) -> Parameter:
"""Get the root reference or err."""
from openapi_schema_pydantic import Reference
from openapi_pydantic import Reference
parameter = self._get_referenced_parameter(ref)
while isinstance(parameter, Reference):
@ -130,7 +130,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
return schemas[ref_name]
def get_schema(self, schema: Union[Reference, Schema]) -> Schema:
from openapi_schema_pydantic import Reference
from openapi_pydantic import Reference
if isinstance(schema, Reference):
return self.get_referenced_schema(schema)
@ -138,7 +138,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
def _get_root_referenced_schema(self, ref: Reference) -> Schema:
"""Get the root reference or err."""
from openapi_schema_pydantic import Reference
from openapi_pydantic import Reference
schema = self.get_referenced_schema(ref)
while isinstance(schema, Reference):
@ -159,7 +159,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
self, ref: Reference
) -> Optional[RequestBody]:
"""Get the root request Body or err."""
from openapi_schema_pydantic import Reference
from openapi_pydantic import Reference
request_body = self._get_referenced_request_body(ref)
while isinstance(request_body, Reference):
@ -248,7 +248,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
def get_methods_for_path(self, path: str) -> List[str]:
"""Return a list of valid methods for the specified path."""
from openapi_schema_pydantic import Operation
from openapi_pydantic import Operation
path_item = self._get_path_strict(path)
results = []
@ -259,7 +259,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
return results
def get_parameters_for_path(self, path: str) -> List[Parameter]:
from openapi_schema_pydantic import Reference
from openapi_pydantic import Reference
path_item = self._get_path_strict(path)
parameters = []
@ -273,7 +273,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
def get_operation(self, path: str, method: str) -> Operation:
"""Get the operation object for a given path and HTTP method."""
from openapi_schema_pydantic import Operation
from openapi_pydantic import Operation
path_item = self._get_path_strict(path)
operation_obj = getattr(path_item, method, None)
@ -283,7 +283,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
def get_parameters_for_operation(self, operation: Operation) -> List[Parameter]:
"""Get the components for a given operation."""
from openapi_schema_pydantic import Reference
from openapi_pydantic import Reference
parameters = []
if operation.parameters:
@ -297,7 +297,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
self, operation: Operation
) -> Optional[RequestBody]:
"""Get the request body for a given operation."""
from openapi_schema_pydantic import Reference
from openapi_pydantic import Reference
request_body = operation.requestBody
if isinstance(request_body, Reference):
@ -305,9 +305,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
return request_body
@staticmethod
def get_cleaned_operation_id(
operation: Operation, path: str, method: str
) -> str:
def get_cleaned_operation_id(operation: Operation, path: str, method: str) -> str:
"""Get a cleaned operation id from an operation id."""
operation_id = operation.operationId
if operation_id is None:
@ -315,11 +313,3 @@ if _PYDANTIC_MAJOR_VERSION == 1:
path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/"))
operation_id = f"{path}_{method}"
return operation_id.replace("-", "_").replace(".", "_").replace("/", "_")
else:
class OpenAPISpec: # type: ignore[no-redef]
"""Shim for pydantic version >=2"""
def __init__(self) -> None:
raise NotImplementedError("Only supported for pydantic version 1")

File diff suppressed because it is too large Load Diff

View File

@ -20,7 +20,7 @@ PyYAML = ">=5.3"
numpy = "^1"
azure-core = {version = "^1.26.4", optional=true}
tqdm = {version = ">=4.48.0", optional = true}
openapi-schema-pydantic = {version = "^1.2", optional = true}
openapi-pydantic = {version = "^0.3.2", optional = true}
faiss-cpu = {version = "^1", optional = true}
wikipedia = {version = "^1", optional = true}
elasticsearch = {version = "^8", optional = true}
@ -359,7 +359,7 @@ extended_testing = [
"xata",
"xmltodict",
"faiss-cpu",
"openapi-schema-pydantic",
"openapi-pydantic",
"markdownify",
"arxiv",
"dashvector",

View File

@ -7,7 +7,7 @@ from typing import Iterable, List, Tuple
import pytest
# Keep at top of file to ensure that pydantic test can be skipped before
# pydantic v1 related imports are attempted by openapi_schema_pydantic.
# pydantic v1 related imports are attempted by openapi_pydantic.
from langchain.pydantic_v1 import _PYDANTIC_MAJOR_VERSION
if _PYDANTIC_MAJOR_VERSION != 1:
@ -78,7 +78,7 @@ def http_paths_and_methods() -> List[Tuple[str, OpenAPISpec, str, str]]:
return http_paths_and_methods
@pytest.mark.requires("openapi_schema_pydantic")
@pytest.mark.requires("openapi_pydantic")
def test_parse_api_operations() -> None:
"""Test the APIOperation class."""
for spec_name, spec, path, method in http_paths_and_methods():
@ -88,21 +88,21 @@ def test_parse_api_operations() -> None:
raise AssertionError(f"Error processing {spec_name}: {e} ") from e
@pytest.mark.requires("openapi_schema_pydantic")
@pytest.mark.requires("openapi_pydantic")
@pytest.fixture
def raw_spec() -> OpenAPISpec:
"""Return a raw OpenAPI spec."""
from openapi_schema_pydantic import Info
from openapi_pydantic import Info
return OpenAPISpec(
info=Info(title="Test API", version="1.0.0"),
)
@pytest.mark.requires("openapi_schema_pydantic")
@pytest.mark.requires("openapi_pydantic")
def test_api_request_body_from_request_body_with_ref(raw_spec: OpenAPISpec) -> None:
"""Test instantiating APIRequestBody from RequestBody with a reference."""
from openapi_schema_pydantic import (
from openapi_pydantic import (
Components,
MediaType,
Reference,
@ -140,10 +140,10 @@ def test_api_request_body_from_request_body_with_ref(raw_spec: OpenAPISpec) -> N
assert api_request_body.media_type == "application/json"
@pytest.mark.requires("openapi_schema_pydantic")
@pytest.mark.requires("openapi_pydantic")
def test_api_request_body_from_request_body_with_schema(raw_spec: OpenAPISpec) -> None:
"""Test instantiating APIRequestBody from RequestBody with a schema."""
from openapi_schema_pydantic import (
from openapi_pydantic import (
MediaType,
RequestBody,
Schema,
@ -171,9 +171,9 @@ def test_api_request_body_from_request_body_with_schema(raw_spec: OpenAPISpec) -
assert api_request_body.media_type == "application/json"
@pytest.mark.requires("openapi_schema_pydantic")
@pytest.mark.requires("openapi_pydantic")
def test_api_request_body_property_from_schema(raw_spec: OpenAPISpec) -> None:
from openapi_schema_pydantic import (
from openapi_pydantic import (
Components,
Reference,
Schema,