Wrap OpenAPI features in conditionals for pydantic v2 compatibility (#9205)

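Wrap the OpenAPI models and utilities in pydantic-version conditionals: the full implementations are defined only when pydantic v1 is installed, and placeholder classes that raise NotImplementedError are defined otherwise.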
This commit is contained in:
Eugene Yurtsev 2023-08-14 13:40:58 -04:00 committed by GitHub
parent 89be10f6b4
commit 4f1feaca83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 754 additions and 673 deletions

View File

@ -27,6 +27,12 @@ if "pydantic_v1" not in sys.modules:
# and may run prior to langchain core package.
sys.modules["pydantic_v1"] = pydantic_v1
# Record the installed pydantic major version once, at import time.
# A value of 0 means no "pydantic" distribution metadata was found.
try:
    _PYDANTIC_MAJOR_VERSION: int = int(
        metadata.version("pydantic").partition(".")[0]
    )
except metadata.PackageNotFoundError:
    _PYDANTIC_MAJOR_VERSION = 0
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
from langchain.cache import BaseCache
from langchain.chains import (

View File

@ -3,9 +3,9 @@ import logging
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
from openapi_schema_pydantic import MediaType, Parameter, Reference, RequestBody, Schema
from pydantic_v1 import BaseModel, Field
from langchain import _PYDANTIC_MAJOR_VERSION
from langchain.tools.openapi.utils.openapi_utils import HTTPVerb, OpenAPISpec
logger = logging.getLogger(__name__)
@ -85,14 +85,25 @@ class APIPropertyBase(BaseModel):
"""The description of the property."""
class APIProperty(APIPropertyBase):
if _PYDANTIC_MAJOR_VERSION == 1:
from openapi_schema_pydantic import (
MediaType,
Parameter,
Reference,
RequestBody,
Schema,
)
class APIProperty(APIPropertyBase):
"""A model for a property in the query, path, header, or cookie params."""
location: APIPropertyLocation = Field(alias="location")
"""The path/how it's being passed to the endpoint."""
@staticmethod
def _cast_schema_list_type(schema: Schema) -> Optional[Union[str, Tuple[str, ...]]]:
def _cast_schema_list_type(
schema: Schema,
) -> Optional[Union[str, Tuple[str, ...]]]:
type_ = schema.type
if not isinstance(type_, list):
return type_
@ -125,7 +136,9 @@ class APIProperty(APIPropertyBase):
return schema_type
@staticmethod
def _get_schema_type(parameter: Parameter, schema: Optional[Schema]) -> SCHEMA_TYPE:
def _get_schema_type(
parameter: Parameter, schema: Optional[Schema]
) -> SCHEMA_TYPE:
if schema is None:
return None
schema_type: SCHEMA_TYPE = APIProperty._cast_schema_list_type(schema)
@ -136,7 +149,9 @@ class APIProperty(APIPropertyBase):
raise NotImplementedError("Objects not yet supported")
elif schema_type in PRIMITIVE_TYPES:
if schema.enum:
schema_type = APIProperty._get_schema_type_for_enum(parameter, schema)
schema_type = APIProperty._get_schema_type_for_enum(
parameter, schema
)
else:
# Directly use the primitive type
pass
@ -181,7 +196,9 @@ class APIProperty(APIPropertyBase):
return False
@classmethod
def from_parameter(cls, parameter: Parameter, spec: OpenAPISpec) -> "APIProperty":
def from_parameter(
cls, parameter: Parameter, spec: OpenAPISpec
) -> "APIProperty":
"""Instantiate from an OpenAPI Parameter."""
location = APIPropertyLocation.from_str(parameter.param_in)
cls._validate_location(
@ -201,8 +218,7 @@ class APIProperty(APIPropertyBase):
type=schema_type,
)
class APIRequestBodyProperty(APIPropertyBase):
class APIRequestBodyProperty(APIPropertyBase):
"""A model for a request body property."""
properties: List["APIRequestBodyProperty"] = Field(alias="properties")
@ -245,7 +261,11 @@ class APIRequestBodyProperty(APIPropertyBase):
@classmethod
def _process_array_schema(
cls, schema: Schema, name: str, spec: OpenAPISpec, references_used: List[str]
cls,
schema: Schema,
name: str,
spec: OpenAPISpec,
references_used: List[str],
) -> str:
items = schema.items
if items is not None:
@ -292,7 +312,9 @@ class APIRequestBodyProperty(APIPropertyBase):
schema, spec, references_used
)
elif schema_type == "array":
schema_type = cls._process_array_schema(schema, name, spec, references_used)
schema_type = cls._process_array_schema(
schema, name, spec, references_used
)
elif schema_type in PRIMITIVE_TYPES:
# Use the primitive type directly
pass
@ -312,8 +334,8 @@ class APIRequestBodyProperty(APIPropertyBase):
references_used=references_used,
)
class APIRequestBody(BaseModel):
# class APIRequestBodyProperty(APIPropertyBase):
class APIRequestBody(BaseModel):
"""A model for a request body."""
description: Optional[str] = Field(alias="description")
@ -392,8 +414,9 @@ class APIRequestBody(BaseModel):
media_type=media_type,
)
class APIOperation(BaseModel):
# class APIRequestBodyProperty(APIPropertyBase):
# class APIRequestBody(BaseModel):
class APIOperation(BaseModel):
"""A model for a single API operation."""
operation_id: str = Field(alias="operation_id")
@ -527,7 +550,8 @@ class APIOperation(BaseModel):
prop_type = f"{{\n{nested_props}\n{' ' * indent}}}"
formatted_props.append(
f"{prop_desc}\n{' ' * indent}{prop_name}{prop_required}: {prop_type},"
f"{prop_desc}\n{' ' * indent}{prop_name}"
f"{prop_required}: {prop_type},"
)
return "\n".join(formatted_props)
@ -548,16 +572,18 @@ class APIOperation(BaseModel):
prop_type = self.ts_type_from_python(prop.type)
prop_required = "" if prop.required else "?"
prop_desc = f"/* {prop.description} */" if prop.description else ""
params.append(f"{prop_desc}\n\t\t{prop_name}{prop_required}: {prop_type},")
params.append(
f"{prop_desc}\n\t\t{prop_name}{prop_required}: {prop_type},"
)
formatted_params = "\n".join(params).strip()
description_str = f"/* {self.description} */" if self.description else ""
typescript_definition = f"""
{description_str}
type {operation_name} = (_: {{
{formatted_params}
}}) => any;
"""
{description_str}
type {operation_name} = (_: {{
{formatted_params}
}}) => any;
"""
return typescript_definition.strip()
@property
@ -581,3 +607,21 @@ type {operation_name} = (_: {{
if self.request_body is None:
return []
return [prop.name for prop in self.request_body.properties]
else:
class APIProperty(APIPropertyBase):  # type: ignore[no-redef]
    """Placeholder defined when the pydantic major version is not 1."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Refuse construction, whatever arguments are supplied.

        Raises:
            NotImplementedError: Always.
        """
        reason = "Only supported for pydantic v1"
        raise NotImplementedError(reason)
class APIRequestBodyProperty(APIPropertyBase):  # type: ignore[no-redef]
    """Placeholder defined when the pydantic major version is not 1."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Refuse construction, whatever arguments are supplied.

        Raises:
            NotImplementedError: Always.
        """
        reason = "Only supported for pydantic v1"
        raise NotImplementedError(reason)
class APIRequestBody(BaseModel):  # type: ignore[no-redef]
    """Placeholder defined when the pydantic major version is not 1."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Refuse construction, whatever arguments are supplied.

        Raises:
            NotImplementedError: Always.
        """
        reason = "Only supported for pydantic v1"
        raise NotImplementedError(reason)
class APIOperation(BaseModel):  # type: ignore[no-redef]
    """Placeholder defined when the pydantic major version is not 1."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Refuse construction, whatever arguments are supplied.

        Raises:
            NotImplementedError: Always.
        """
        reason = "Only supported for pydantic v1"
        raise NotImplementedError(reason)

View File

@ -1,2 +1,4 @@
"""Utility functions for parsing an OpenAPI spec. Kept for backwards compat."""
from langchain.utilities.openapi import HTTPVerb, OpenAPISpec # noqa: F401
from langchain.utilities.openapi import HTTPVerb, OpenAPISpec
__all__ = ["HTTPVerb", "OpenAPISpec"]

View File

@ -2,9 +2,8 @@
from typing import Any, Dict, List, Optional
import requests
from pydantic_v1 import Extra, root_validator
from pydantic_v1 import BaseModel, Extra, root_validator
from langchain.tools.base import BaseModel
from langchain.utils import get_from_dict_or_env

View File

@ -1,4 +1,6 @@
"""Utility functions for parsing an OpenAPI spec."""
from __future__ import annotations
import copy
import json
import logging
@ -9,19 +11,10 @@ from typing import Dict, List, Optional, Union
import requests
import yaml
from openapi_schema_pydantic import (
Components,
OpenAPI,
Operation,
Parameter,
PathItem,
Paths,
Reference,
RequestBody,
Schema,
)
from pydantic_v1 import ValidationError
from langchain import _PYDANTIC_MAJOR_VERSION
logger = logging.getLogger(__name__)
@ -38,7 +31,7 @@ class HTTPVerb(str, Enum):
TRACE = "trace"
@classmethod
def from_str(cls, verb: str) -> "HTTPVerb":
def from_str(cls, verb: str) -> HTTPVerb:
"""Parse an HTTP verb."""
try:
return cls(verb)
@ -46,8 +39,21 @@ class HTTPVerb(str, Enum):
raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}")
class OpenAPISpec(OpenAPI):
"""OpenAPI Model that removes misformatted parts of the spec."""
if _PYDANTIC_MAJOR_VERSION == 1:
from openapi_schema_pydantic import (
Components,
OpenAPI,
Operation,
Parameter,
PathItem,
Paths,
Reference,
RequestBody,
Schema,
)
class OpenAPISpec(OpenAPI):
"""OpenAPI Model that removes mis-formatted parts of the spec."""
@property
def _paths_strict(self) -> Paths:
@ -92,7 +98,9 @@ class OpenAPISpec(OpenAPI):
raise ValueError("No request body found in spec. ")
return request_bodies
def _get_referenced_parameter(self, ref: Reference) -> Union[Parameter, Reference]:
def _get_referenced_parameter(
self, ref: Reference
) -> Union[Parameter, Reference]:
"""Get a parameter (or nested reference) or err."""
ref_name = ref.ref.split("/")[-1]
parameters = self._parameters_strict
@ -176,13 +184,13 @@ class OpenAPISpec(OpenAPI):
)
@classmethod
def parse_obj(cls, obj: dict) -> "OpenAPISpec":
def parse_obj(cls, obj: dict) -> OpenAPISpec:
try:
cls._alert_unsupported_spec(obj)
return super().parse_obj(obj)
except ValidationError as e:
# We are handling possibly misconfigured specs and want to do a best-effort
# job to get a reasonable interface out of it.
# We are handling possibly misconfigured specs and
# want to do a best-effort job to get a reasonable interface out of it.
new_obj = copy.deepcopy(obj)
for error in e.errors():
keys = error["loc"]
@ -193,12 +201,12 @@ class OpenAPISpec(OpenAPI):
return cls.parse_obj(new_obj)
@classmethod
def from_spec_dict(cls, spec_dict: dict) -> "OpenAPISpec":
def from_spec_dict(cls, spec_dict: dict) -> OpenAPISpec:
"""Get an OpenAPI spec from a dict."""
return cls.parse_obj(spec_dict)
@classmethod
def from_text(cls, text: str) -> "OpenAPISpec":
def from_text(cls, text: str) -> OpenAPISpec:
"""Get an OpenAPI spec from a text."""
try:
spec_dict = json.loads(text)
@ -207,7 +215,7 @@ class OpenAPISpec(OpenAPI):
return cls.from_spec_dict(spec_dict)
@classmethod
def from_file(cls, path: Union[str, Path]) -> "OpenAPISpec":
def from_file(cls, path: Union[str, Path]) -> OpenAPISpec:
"""Get an OpenAPI spec from a file path."""
path_ = path if isinstance(path, Path) else Path(path)
if not path_.exists():
@ -216,7 +224,7 @@ class OpenAPISpec(OpenAPI):
return cls.from_text(f.read())
@classmethod
def from_url(cls, url: str) -> "OpenAPISpec":
def from_url(cls, url: str) -> OpenAPISpec:
"""Get an OpenAPI spec from a URL."""
response = requests.get(url)
return cls.from_text(response.text)
@ -275,7 +283,9 @@ class OpenAPISpec(OpenAPI):
return request_body
@staticmethod
def get_cleaned_operation_id(operation: Operation, path: str, method: str) -> str:
def get_cleaned_operation_id(
operation: Operation, path: str, method: str
) -> str:
"""Get a cleaned operation id from an operation id."""
operation_id = operation.operationId
if operation_id is None:
@ -283,3 +293,11 @@ class OpenAPISpec(OpenAPI):
path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/"))
operation_id = f"{path}_{method}"
return operation_id.replace("-", "_").replace(".", "_").replace("/", "_")
else:
class OpenAPISpec:  # type: ignore[no-redef]
    """Shim for pydantic version >=2.

    This placeholder is defined when the pydantic major version is not 1.
    Any attempt to instantiate it fails with ``NotImplementedError``.
    """

    def __init__(self, *args: object, **kwargs: object) -> None:
        """Reject instantiation regardless of the arguments supplied.

        Arbitrary positional and keyword arguments are accepted so that
        callers constructing the spec with data receive the explanatory
        ``NotImplementedError`` rather than a ``TypeError`` about
        unexpected arguments.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Only supported for pydantic version 1")

View File

@ -4,6 +4,18 @@ import os
from pathlib import Path
from typing import Iterable, List, Tuple
import pytest
# Keep at top of file to ensure that pydantic test can be skipped before
# pydantic v1 related imports are attempted by openapi_schema_pydantic.
from langchain import _PYDANTIC_MAJOR_VERSION
if _PYDANTIC_MAJOR_VERSION != 1:
pytest.skip(
f"Pydantic major version {_PYDANTIC_MAJOR_VERSION} is not supported.",
allow_module_level=True,
)
import pytest
import yaml
from openapi_schema_pydantic import (