Wrap OpenAPI features in conditionals for pydantic v2 compatibility (#9205)

Wrap OpenAPI in conditionals for pydantic v2 compatibility.
This commit is contained in:
Eugene Yurtsev 2023-08-14 13:40:58 -04:00 committed by GitHub
parent 89be10f6b4
commit 4f1feaca83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 754 additions and 673 deletions

View File

@ -27,6 +27,12 @@ if "pydantic_v1" not in sys.modules:
# and may run prior to langchain core package. # and may run prior to langchain core package.
sys.modules["pydantic_v1"] = pydantic_v1 sys.modules["pydantic_v1"] = pydantic_v1
try:
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
except metadata.PackageNotFoundError:
_PYDANTIC_MAJOR_VERSION = 0
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.cache import BaseCache from langchain.cache import BaseCache
from langchain.chains import ( from langchain.chains import (

View File

@ -3,9 +3,9 @@ import logging
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from openapi_schema_pydantic import MediaType, Parameter, Reference, RequestBody, Schema
from pydantic_v1 import BaseModel, Field from pydantic_v1 import BaseModel, Field
from langchain import _PYDANTIC_MAJOR_VERSION
from langchain.tools.openapi.utils.openapi_utils import HTTPVerb, OpenAPISpec from langchain.tools.openapi.utils.openapi_utils import HTTPVerb, OpenAPISpec
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -85,14 +85,25 @@ class APIPropertyBase(BaseModel):
"""The description of the property.""" """The description of the property."""
class APIProperty(APIPropertyBase): if _PYDANTIC_MAJOR_VERSION == 1:
from openapi_schema_pydantic import (
MediaType,
Parameter,
Reference,
RequestBody,
Schema,
)
class APIProperty(APIPropertyBase):
"""A model for a property in the query, path, header, or cookie params.""" """A model for a property in the query, path, header, or cookie params."""
location: APIPropertyLocation = Field(alias="location") location: APIPropertyLocation = Field(alias="location")
"""The path/how it's being passed to the endpoint.""" """The path/how it's being passed to the endpoint."""
@staticmethod @staticmethod
def _cast_schema_list_type(schema: Schema) -> Optional[Union[str, Tuple[str, ...]]]: def _cast_schema_list_type(
schema: Schema,
) -> Optional[Union[str, Tuple[str, ...]]]:
type_ = schema.type type_ = schema.type
if not isinstance(type_, list): if not isinstance(type_, list):
return type_ return type_
@ -125,7 +136,9 @@ class APIProperty(APIPropertyBase):
return schema_type return schema_type
@staticmethod @staticmethod
def _get_schema_type(parameter: Parameter, schema: Optional[Schema]) -> SCHEMA_TYPE: def _get_schema_type(
parameter: Parameter, schema: Optional[Schema]
) -> SCHEMA_TYPE:
if schema is None: if schema is None:
return None return None
schema_type: SCHEMA_TYPE = APIProperty._cast_schema_list_type(schema) schema_type: SCHEMA_TYPE = APIProperty._cast_schema_list_type(schema)
@ -136,7 +149,9 @@ class APIProperty(APIPropertyBase):
raise NotImplementedError("Objects not yet supported") raise NotImplementedError("Objects not yet supported")
elif schema_type in PRIMITIVE_TYPES: elif schema_type in PRIMITIVE_TYPES:
if schema.enum: if schema.enum:
schema_type = APIProperty._get_schema_type_for_enum(parameter, schema) schema_type = APIProperty._get_schema_type_for_enum(
parameter, schema
)
else: else:
# Directly use the primitive type # Directly use the primitive type
pass pass
@ -181,7 +196,9 @@ class APIProperty(APIPropertyBase):
return False return False
@classmethod @classmethod
def from_parameter(cls, parameter: Parameter, spec: OpenAPISpec) -> "APIProperty": def from_parameter(
cls, parameter: Parameter, spec: OpenAPISpec
) -> "APIProperty":
"""Instantiate from an OpenAPI Parameter.""" """Instantiate from an OpenAPI Parameter."""
location = APIPropertyLocation.from_str(parameter.param_in) location = APIPropertyLocation.from_str(parameter.param_in)
cls._validate_location( cls._validate_location(
@ -201,8 +218,7 @@ class APIProperty(APIPropertyBase):
type=schema_type, type=schema_type,
) )
class APIRequestBodyProperty(APIPropertyBase):
class APIRequestBodyProperty(APIPropertyBase):
"""A model for a request body property.""" """A model for a request body property."""
properties: List["APIRequestBodyProperty"] = Field(alias="properties") properties: List["APIRequestBodyProperty"] = Field(alias="properties")
@ -245,7 +261,11 @@ class APIRequestBodyProperty(APIPropertyBase):
@classmethod @classmethod
def _process_array_schema( def _process_array_schema(
cls, schema: Schema, name: str, spec: OpenAPISpec, references_used: List[str] cls,
schema: Schema,
name: str,
spec: OpenAPISpec,
references_used: List[str],
) -> str: ) -> str:
items = schema.items items = schema.items
if items is not None: if items is not None:
@ -292,7 +312,9 @@ class APIRequestBodyProperty(APIPropertyBase):
schema, spec, references_used schema, spec, references_used
) )
elif schema_type == "array": elif schema_type == "array":
schema_type = cls._process_array_schema(schema, name, spec, references_used) schema_type = cls._process_array_schema(
schema, name, spec, references_used
)
elif schema_type in PRIMITIVE_TYPES: elif schema_type in PRIMITIVE_TYPES:
# Use the primitive type directly # Use the primitive type directly
pass pass
@ -312,8 +334,8 @@ class APIRequestBodyProperty(APIPropertyBase):
references_used=references_used, references_used=references_used,
) )
# class APIRequestBodyProperty(APIPropertyBase):
class APIRequestBody(BaseModel): class APIRequestBody(BaseModel):
"""A model for a request body.""" """A model for a request body."""
description: Optional[str] = Field(alias="description") description: Optional[str] = Field(alias="description")
@ -392,8 +414,9 @@ class APIRequestBody(BaseModel):
media_type=media_type, media_type=media_type,
) )
# class APIRequestBodyProperty(APIPropertyBase):
class APIOperation(BaseModel): # class APIRequestBody(BaseModel):
class APIOperation(BaseModel):
"""A model for a single API operation.""" """A model for a single API operation."""
operation_id: str = Field(alias="operation_id") operation_id: str = Field(alias="operation_id")
@ -527,7 +550,8 @@ class APIOperation(BaseModel):
prop_type = f"{{\n{nested_props}\n{' ' * indent}}}" prop_type = f"{{\n{nested_props}\n{' ' * indent}}}"
formatted_props.append( formatted_props.append(
f"{prop_desc}\n{' ' * indent}{prop_name}{prop_required}: {prop_type}," f"{prop_desc}\n{' ' * indent}{prop_name}"
f"{prop_required}: {prop_type},"
) )
return "\n".join(formatted_props) return "\n".join(formatted_props)
@ -548,16 +572,18 @@ class APIOperation(BaseModel):
prop_type = self.ts_type_from_python(prop.type) prop_type = self.ts_type_from_python(prop.type)
prop_required = "" if prop.required else "?" prop_required = "" if prop.required else "?"
prop_desc = f"/* {prop.description} */" if prop.description else "" prop_desc = f"/* {prop.description} */" if prop.description else ""
params.append(f"{prop_desc}\n\t\t{prop_name}{prop_required}: {prop_type},") params.append(
f"{prop_desc}\n\t\t{prop_name}{prop_required}: {prop_type},"
)
formatted_params = "\n".join(params).strip() formatted_params = "\n".join(params).strip()
description_str = f"/* {self.description} */" if self.description else "" description_str = f"/* {self.description} */" if self.description else ""
typescript_definition = f""" typescript_definition = f"""
{description_str} {description_str}
type {operation_name} = (_: {{ type {operation_name} = (_: {{
{formatted_params} {formatted_params}
}}) => any; }}) => any;
""" """
return typescript_definition.strip() return typescript_definition.strip()
@property @property
@ -581,3 +607,21 @@ type {operation_name} = (_: {{
if self.request_body is None: if self.request_body is None:
return [] return []
return [prop.name for prop in self.request_body.properties] return [prop.name for prop in self.request_body.properties]
else:
class APIProperty(APIPropertyBase): # type: ignore[no-redef]
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("Only supported for pydantic v1")
class APIRequestBodyProperty(APIPropertyBase): # type: ignore[no-redef]
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("Only supported for pydantic v1")
class APIRequestBody(BaseModel): # type: ignore[no-redef]
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("Only supported for pydantic v1")
class APIOperation(BaseModel): # type: ignore[no-redef]
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("Only supported for pydantic v1")

View File

@ -1,2 +1,4 @@
"""Utility functions for parsing an OpenAPI spec. Kept for backwards compat.""" """Utility functions for parsing an OpenAPI spec. Kept for backwards compat."""
from langchain.utilities.openapi import HTTPVerb, OpenAPISpec # noqa: F401 from langchain.utilities.openapi import HTTPVerb, OpenAPISpec
__all__ = ["HTTPVerb", "OpenAPISpec"]

View File

@ -2,9 +2,8 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import requests import requests
from pydantic_v1 import Extra, root_validator from pydantic_v1 import BaseModel, Extra, root_validator
from langchain.tools.base import BaseModel
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env

View File

@ -1,4 +1,6 @@
"""Utility functions for parsing an OpenAPI spec.""" """Utility functions for parsing an OpenAPI spec."""
from __future__ import annotations
import copy import copy
import json import json
import logging import logging
@ -9,19 +11,10 @@ from typing import Dict, List, Optional, Union
import requests import requests
import yaml import yaml
from openapi_schema_pydantic import (
Components,
OpenAPI,
Operation,
Parameter,
PathItem,
Paths,
Reference,
RequestBody,
Schema,
)
from pydantic_v1 import ValidationError from pydantic_v1 import ValidationError
from langchain import _PYDANTIC_MAJOR_VERSION
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -38,7 +31,7 @@ class HTTPVerb(str, Enum):
TRACE = "trace" TRACE = "trace"
@classmethod @classmethod
def from_str(cls, verb: str) -> "HTTPVerb": def from_str(cls, verb: str) -> HTTPVerb:
"""Parse an HTTP verb.""" """Parse an HTTP verb."""
try: try:
return cls(verb) return cls(verb)
@ -46,8 +39,21 @@ class HTTPVerb(str, Enum):
raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}") raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}")
class OpenAPISpec(OpenAPI): if _PYDANTIC_MAJOR_VERSION == 1:
"""OpenAPI Model that removes misformatted parts of the spec.""" from openapi_schema_pydantic import (
Components,
OpenAPI,
Operation,
Parameter,
PathItem,
Paths,
Reference,
RequestBody,
Schema,
)
class OpenAPISpec(OpenAPI):
"""OpenAPI Model that removes mis-formatted parts of the spec."""
@property @property
def _paths_strict(self) -> Paths: def _paths_strict(self) -> Paths:
@ -92,7 +98,9 @@ class OpenAPISpec(OpenAPI):
raise ValueError("No request body found in spec. ") raise ValueError("No request body found in spec. ")
return request_bodies return request_bodies
def _get_referenced_parameter(self, ref: Reference) -> Union[Parameter, Reference]: def _get_referenced_parameter(
self, ref: Reference
) -> Union[Parameter, Reference]:
"""Get a parameter (or nested reference) or err.""" """Get a parameter (or nested reference) or err."""
ref_name = ref.ref.split("/")[-1] ref_name = ref.ref.split("/")[-1]
parameters = self._parameters_strict parameters = self._parameters_strict
@ -176,13 +184,13 @@ class OpenAPISpec(OpenAPI):
) )
@classmethod @classmethod
def parse_obj(cls, obj: dict) -> "OpenAPISpec": def parse_obj(cls, obj: dict) -> OpenAPISpec:
try: try:
cls._alert_unsupported_spec(obj) cls._alert_unsupported_spec(obj)
return super().parse_obj(obj) return super().parse_obj(obj)
except ValidationError as e: except ValidationError as e:
# We are handling possibly misconfigured specs and want to do a best-effort # We are handling possibly misconfigured specs and
# job to get a reasonable interface out of it. # want to do a best-effort job to get a reasonable interface out of it.
new_obj = copy.deepcopy(obj) new_obj = copy.deepcopy(obj)
for error in e.errors(): for error in e.errors():
keys = error["loc"] keys = error["loc"]
@ -193,12 +201,12 @@ class OpenAPISpec(OpenAPI):
return cls.parse_obj(new_obj) return cls.parse_obj(new_obj)
@classmethod @classmethod
def from_spec_dict(cls, spec_dict: dict) -> "OpenAPISpec": def from_spec_dict(cls, spec_dict: dict) -> OpenAPISpec:
"""Get an OpenAPI spec from a dict.""" """Get an OpenAPI spec from a dict."""
return cls.parse_obj(spec_dict) return cls.parse_obj(spec_dict)
@classmethod @classmethod
def from_text(cls, text: str) -> "OpenAPISpec": def from_text(cls, text: str) -> OpenAPISpec:
"""Get an OpenAPI spec from a text.""" """Get an OpenAPI spec from a text."""
try: try:
spec_dict = json.loads(text) spec_dict = json.loads(text)
@ -207,7 +215,7 @@ class OpenAPISpec(OpenAPI):
return cls.from_spec_dict(spec_dict) return cls.from_spec_dict(spec_dict)
@classmethod @classmethod
def from_file(cls, path: Union[str, Path]) -> "OpenAPISpec": def from_file(cls, path: Union[str, Path]) -> OpenAPISpec:
"""Get an OpenAPI spec from a file path.""" """Get an OpenAPI spec from a file path."""
path_ = path if isinstance(path, Path) else Path(path) path_ = path if isinstance(path, Path) else Path(path)
if not path_.exists(): if not path_.exists():
@ -216,7 +224,7 @@ class OpenAPISpec(OpenAPI):
return cls.from_text(f.read()) return cls.from_text(f.read())
@classmethod @classmethod
def from_url(cls, url: str) -> "OpenAPISpec": def from_url(cls, url: str) -> OpenAPISpec:
"""Get an OpenAPI spec from a URL.""" """Get an OpenAPI spec from a URL."""
response = requests.get(url) response = requests.get(url)
return cls.from_text(response.text) return cls.from_text(response.text)
@ -275,7 +283,9 @@ class OpenAPISpec(OpenAPI):
return request_body return request_body
@staticmethod @staticmethod
def get_cleaned_operation_id(operation: Operation, path: str, method: str) -> str: def get_cleaned_operation_id(
operation: Operation, path: str, method: str
) -> str:
"""Get a cleaned operation id from an operation id.""" """Get a cleaned operation id from an operation id."""
operation_id = operation.operationId operation_id = operation.operationId
if operation_id is None: if operation_id is None:
@ -283,3 +293,11 @@ class OpenAPISpec(OpenAPI):
path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/")) path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/"))
operation_id = f"{path}_{method}" operation_id = f"{path}_{method}"
return operation_id.replace("-", "_").replace(".", "_").replace("/", "_") return operation_id.replace("-", "_").replace(".", "_").replace("/", "_")
else:
class OpenAPISpec: # type: ignore[no-redef]
"""Shim for pydantic version >=2"""
def __init__(self) -> None:
raise NotImplementedError("Only supported for pydantic version 1")

View File

@ -4,6 +4,18 @@ import os
from pathlib import Path from pathlib import Path
from typing import Iterable, List, Tuple from typing import Iterable, List, Tuple
import pytest
# Keep at top of file to ensure that pydantic test can be skipped before
# pydantic v1 related imports are attempted by openapi_schema_pydantic.
from langchain import _PYDANTIC_MAJOR_VERSION
if _PYDANTIC_MAJOR_VERSION != 1:
pytest.skip(
f"Pydantic major version {_PYDANTIC_MAJOR_VERSION} is not supported.",
allow_module_level=True,
)
import pytest import pytest
import yaml import yaml
from openapi_schema_pydantic import ( from openapi_schema_pydantic import (