mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 23:29:21 +00:00
Wrap OpenAPI features in conditionals for pydantic v2 compatibility (#9205)
Wrap OpenAPI in conditionals for pydantic v2 compatibility.
This commit is contained in:
parent
89be10f6b4
commit
4f1feaca83
@ -27,6 +27,12 @@ if "pydantic_v1" not in sys.modules:
|
||||
# and may run prior to langchain core package.
|
||||
sys.modules["pydantic_v1"] = pydantic_v1
|
||||
|
||||
try:
|
||||
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
|
||||
except metadata.PackageNotFoundError:
|
||||
_PYDANTIC_MAJOR_VERSION = 0
|
||||
|
||||
|
||||
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
|
||||
from langchain.cache import BaseCache
|
||||
from langchain.chains import (
|
||||
|
@ -3,9 +3,9 @@ import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
from openapi_schema_pydantic import MediaType, Parameter, Reference, RequestBody, Schema
|
||||
from pydantic_v1 import BaseModel, Field
|
||||
|
||||
from langchain import _PYDANTIC_MAJOR_VERSION
|
||||
from langchain.tools.openapi.utils.openapi_utils import HTTPVerb, OpenAPISpec
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -85,14 +85,25 @@ class APIPropertyBase(BaseModel):
|
||||
"""The description of the property."""
|
||||
|
||||
|
||||
class APIProperty(APIPropertyBase):
|
||||
if _PYDANTIC_MAJOR_VERSION == 1:
|
||||
from openapi_schema_pydantic import (
|
||||
MediaType,
|
||||
Parameter,
|
||||
Reference,
|
||||
RequestBody,
|
||||
Schema,
|
||||
)
|
||||
|
||||
class APIProperty(APIPropertyBase):
|
||||
"""A model for a property in the query, path, header, or cookie params."""
|
||||
|
||||
location: APIPropertyLocation = Field(alias="location")
|
||||
"""The path/how it's being passed to the endpoint."""
|
||||
|
||||
@staticmethod
|
||||
def _cast_schema_list_type(schema: Schema) -> Optional[Union[str, Tuple[str, ...]]]:
|
||||
def _cast_schema_list_type(
|
||||
schema: Schema,
|
||||
) -> Optional[Union[str, Tuple[str, ...]]]:
|
||||
type_ = schema.type
|
||||
if not isinstance(type_, list):
|
||||
return type_
|
||||
@ -125,7 +136,9 @@ class APIProperty(APIPropertyBase):
|
||||
return schema_type
|
||||
|
||||
@staticmethod
|
||||
def _get_schema_type(parameter: Parameter, schema: Optional[Schema]) -> SCHEMA_TYPE:
|
||||
def _get_schema_type(
|
||||
parameter: Parameter, schema: Optional[Schema]
|
||||
) -> SCHEMA_TYPE:
|
||||
if schema is None:
|
||||
return None
|
||||
schema_type: SCHEMA_TYPE = APIProperty._cast_schema_list_type(schema)
|
||||
@ -136,7 +149,9 @@ class APIProperty(APIPropertyBase):
|
||||
raise NotImplementedError("Objects not yet supported")
|
||||
elif schema_type in PRIMITIVE_TYPES:
|
||||
if schema.enum:
|
||||
schema_type = APIProperty._get_schema_type_for_enum(parameter, schema)
|
||||
schema_type = APIProperty._get_schema_type_for_enum(
|
||||
parameter, schema
|
||||
)
|
||||
else:
|
||||
# Directly use the primitive type
|
||||
pass
|
||||
@ -181,7 +196,9 @@ class APIProperty(APIPropertyBase):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def from_parameter(cls, parameter: Parameter, spec: OpenAPISpec) -> "APIProperty":
|
||||
def from_parameter(
|
||||
cls, parameter: Parameter, spec: OpenAPISpec
|
||||
) -> "APIProperty":
|
||||
"""Instantiate from an OpenAPI Parameter."""
|
||||
location = APIPropertyLocation.from_str(parameter.param_in)
|
||||
cls._validate_location(
|
||||
@ -201,8 +218,7 @@ class APIProperty(APIPropertyBase):
|
||||
type=schema_type,
|
||||
)
|
||||
|
||||
|
||||
class APIRequestBodyProperty(APIPropertyBase):
|
||||
class APIRequestBodyProperty(APIPropertyBase):
|
||||
"""A model for a request body property."""
|
||||
|
||||
properties: List["APIRequestBodyProperty"] = Field(alias="properties")
|
||||
@ -245,7 +261,11 @@ class APIRequestBodyProperty(APIPropertyBase):
|
||||
|
||||
@classmethod
|
||||
def _process_array_schema(
|
||||
cls, schema: Schema, name: str, spec: OpenAPISpec, references_used: List[str]
|
||||
cls,
|
||||
schema: Schema,
|
||||
name: str,
|
||||
spec: OpenAPISpec,
|
||||
references_used: List[str],
|
||||
) -> str:
|
||||
items = schema.items
|
||||
if items is not None:
|
||||
@ -292,7 +312,9 @@ class APIRequestBodyProperty(APIPropertyBase):
|
||||
schema, spec, references_used
|
||||
)
|
||||
elif schema_type == "array":
|
||||
schema_type = cls._process_array_schema(schema, name, spec, references_used)
|
||||
schema_type = cls._process_array_schema(
|
||||
schema, name, spec, references_used
|
||||
)
|
||||
elif schema_type in PRIMITIVE_TYPES:
|
||||
# Use the primitive type directly
|
||||
pass
|
||||
@ -312,8 +334,8 @@ class APIRequestBodyProperty(APIPropertyBase):
|
||||
references_used=references_used,
|
||||
)
|
||||
|
||||
|
||||
class APIRequestBody(BaseModel):
|
||||
# class APIRequestBodyProperty(APIPropertyBase):
|
||||
class APIRequestBody(BaseModel):
|
||||
"""A model for a request body."""
|
||||
|
||||
description: Optional[str] = Field(alias="description")
|
||||
@ -392,8 +414,9 @@ class APIRequestBody(BaseModel):
|
||||
media_type=media_type,
|
||||
)
|
||||
|
||||
|
||||
class APIOperation(BaseModel):
|
||||
# class APIRequestBodyProperty(APIPropertyBase):
|
||||
# class APIRequestBody(BaseModel):
|
||||
class APIOperation(BaseModel):
|
||||
"""A model for a single API operation."""
|
||||
|
||||
operation_id: str = Field(alias="operation_id")
|
||||
@ -527,7 +550,8 @@ class APIOperation(BaseModel):
|
||||
prop_type = f"{{\n{nested_props}\n{' ' * indent}}}"
|
||||
|
||||
formatted_props.append(
|
||||
f"{prop_desc}\n{' ' * indent}{prop_name}{prop_required}: {prop_type},"
|
||||
f"{prop_desc}\n{' ' * indent}{prop_name}"
|
||||
f"{prop_required}: {prop_type},"
|
||||
)
|
||||
|
||||
return "\n".join(formatted_props)
|
||||
@ -548,16 +572,18 @@ class APIOperation(BaseModel):
|
||||
prop_type = self.ts_type_from_python(prop.type)
|
||||
prop_required = "" if prop.required else "?"
|
||||
prop_desc = f"/* {prop.description} */" if prop.description else ""
|
||||
params.append(f"{prop_desc}\n\t\t{prop_name}{prop_required}: {prop_type},")
|
||||
params.append(
|
||||
f"{prop_desc}\n\t\t{prop_name}{prop_required}: {prop_type},"
|
||||
)
|
||||
|
||||
formatted_params = "\n".join(params).strip()
|
||||
description_str = f"/* {self.description} */" if self.description else ""
|
||||
typescript_definition = f"""
|
||||
{description_str}
|
||||
type {operation_name} = (_: {{
|
||||
{formatted_params}
|
||||
}}) => any;
|
||||
"""
|
||||
{description_str}
|
||||
type {operation_name} = (_: {{
|
||||
{formatted_params}
|
||||
}}) => any;
|
||||
"""
|
||||
return typescript_definition.strip()
|
||||
|
||||
@property
|
||||
@ -581,3 +607,21 @@ type {operation_name} = (_: {{
|
||||
if self.request_body is None:
|
||||
return []
|
||||
return [prop.name for prop in self.request_body.properties]
|
||||
|
||||
else:
|
||||
|
||||
class APIProperty(APIPropertyBase): # type: ignore[no-redef]
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
raise NotImplementedError("Only supported for pydantic v1")
|
||||
|
||||
class APIRequestBodyProperty(APIPropertyBase): # type: ignore[no-redef]
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
raise NotImplementedError("Only supported for pydantic v1")
|
||||
|
||||
class APIRequestBody(BaseModel): # type: ignore[no-redef]
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
raise NotImplementedError("Only supported for pydantic v1")
|
||||
|
||||
class APIOperation(BaseModel): # type: ignore[no-redef]
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
raise NotImplementedError("Only supported for pydantic v1")
|
||||
|
@ -1,2 +1,4 @@
|
||||
"""Utility functions for parsing an OpenAPI spec. Kept for backwards compat."""
|
||||
from langchain.utilities.openapi import HTTPVerb, OpenAPISpec # noqa: F401
|
||||
from langchain.utilities.openapi import HTTPVerb, OpenAPISpec
|
||||
|
||||
__all__ = ["HTTPVerb", "OpenAPISpec"]
|
||||
|
@ -2,9 +2,8 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from pydantic_v1 import Extra, root_validator
|
||||
from pydantic_v1 import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.tools.base import BaseModel
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
|
@ -1,4 +1,6 @@
|
||||
"""Utility functions for parsing an OpenAPI spec."""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
@ -9,19 +11,10 @@ from typing import Dict, List, Optional, Union
|
||||
|
||||
import requests
|
||||
import yaml
|
||||
from openapi_schema_pydantic import (
|
||||
Components,
|
||||
OpenAPI,
|
||||
Operation,
|
||||
Parameter,
|
||||
PathItem,
|
||||
Paths,
|
||||
Reference,
|
||||
RequestBody,
|
||||
Schema,
|
||||
)
|
||||
from pydantic_v1 import ValidationError
|
||||
|
||||
from langchain import _PYDANTIC_MAJOR_VERSION
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -38,7 +31,7 @@ class HTTPVerb(str, Enum):
|
||||
TRACE = "trace"
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, verb: str) -> "HTTPVerb":
|
||||
def from_str(cls, verb: str) -> HTTPVerb:
|
||||
"""Parse an HTTP verb."""
|
||||
try:
|
||||
return cls(verb)
|
||||
@ -46,8 +39,21 @@ class HTTPVerb(str, Enum):
|
||||
raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}")
|
||||
|
||||
|
||||
class OpenAPISpec(OpenAPI):
|
||||
"""OpenAPI Model that removes misformatted parts of the spec."""
|
||||
if _PYDANTIC_MAJOR_VERSION == 1:
|
||||
from openapi_schema_pydantic import (
|
||||
Components,
|
||||
OpenAPI,
|
||||
Operation,
|
||||
Parameter,
|
||||
PathItem,
|
||||
Paths,
|
||||
Reference,
|
||||
RequestBody,
|
||||
Schema,
|
||||
)
|
||||
|
||||
class OpenAPISpec(OpenAPI):
|
||||
"""OpenAPI Model that removes mis-formatted parts of the spec."""
|
||||
|
||||
@property
|
||||
def _paths_strict(self) -> Paths:
|
||||
@ -92,7 +98,9 @@ class OpenAPISpec(OpenAPI):
|
||||
raise ValueError("No request body found in spec. ")
|
||||
return request_bodies
|
||||
|
||||
def _get_referenced_parameter(self, ref: Reference) -> Union[Parameter, Reference]:
|
||||
def _get_referenced_parameter(
|
||||
self, ref: Reference
|
||||
) -> Union[Parameter, Reference]:
|
||||
"""Get a parameter (or nested reference) or err."""
|
||||
ref_name = ref.ref.split("/")[-1]
|
||||
parameters = self._parameters_strict
|
||||
@ -176,13 +184,13 @@ class OpenAPISpec(OpenAPI):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def parse_obj(cls, obj: dict) -> "OpenAPISpec":
|
||||
def parse_obj(cls, obj: dict) -> OpenAPISpec:
|
||||
try:
|
||||
cls._alert_unsupported_spec(obj)
|
||||
return super().parse_obj(obj)
|
||||
except ValidationError as e:
|
||||
# We are handling possibly misconfigured specs and want to do a best-effort
|
||||
# job to get a reasonable interface out of it.
|
||||
# We are handling possibly misconfigured specs and
|
||||
# want to do a best-effort job to get a reasonable interface out of it.
|
||||
new_obj = copy.deepcopy(obj)
|
||||
for error in e.errors():
|
||||
keys = error["loc"]
|
||||
@ -193,12 +201,12 @@ class OpenAPISpec(OpenAPI):
|
||||
return cls.parse_obj(new_obj)
|
||||
|
||||
@classmethod
|
||||
def from_spec_dict(cls, spec_dict: dict) -> "OpenAPISpec":
|
||||
def from_spec_dict(cls, spec_dict: dict) -> OpenAPISpec:
|
||||
"""Get an OpenAPI spec from a dict."""
|
||||
return cls.parse_obj(spec_dict)
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, text: str) -> "OpenAPISpec":
|
||||
def from_text(cls, text: str) -> OpenAPISpec:
|
||||
"""Get an OpenAPI spec from a text."""
|
||||
try:
|
||||
spec_dict = json.loads(text)
|
||||
@ -207,7 +215,7 @@ class OpenAPISpec(OpenAPI):
|
||||
return cls.from_spec_dict(spec_dict)
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, path: Union[str, Path]) -> "OpenAPISpec":
|
||||
def from_file(cls, path: Union[str, Path]) -> OpenAPISpec:
|
||||
"""Get an OpenAPI spec from a file path."""
|
||||
path_ = path if isinstance(path, Path) else Path(path)
|
||||
if not path_.exists():
|
||||
@ -216,7 +224,7 @@ class OpenAPISpec(OpenAPI):
|
||||
return cls.from_text(f.read())
|
||||
|
||||
@classmethod
|
||||
def from_url(cls, url: str) -> "OpenAPISpec":
|
||||
def from_url(cls, url: str) -> OpenAPISpec:
|
||||
"""Get an OpenAPI spec from a URL."""
|
||||
response = requests.get(url)
|
||||
return cls.from_text(response.text)
|
||||
@ -275,7 +283,9 @@ class OpenAPISpec(OpenAPI):
|
||||
return request_body
|
||||
|
||||
@staticmethod
|
||||
def get_cleaned_operation_id(operation: Operation, path: str, method: str) -> str:
|
||||
def get_cleaned_operation_id(
|
||||
operation: Operation, path: str, method: str
|
||||
) -> str:
|
||||
"""Get a cleaned operation id from an operation id."""
|
||||
operation_id = operation.operationId
|
||||
if operation_id is None:
|
||||
@ -283,3 +293,11 @@ class OpenAPISpec(OpenAPI):
|
||||
path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/"))
|
||||
operation_id = f"{path}_{method}"
|
||||
return operation_id.replace("-", "_").replace(".", "_").replace("/", "_")
|
||||
|
||||
else:
|
||||
|
||||
class OpenAPISpec: # type: ignore[no-redef]
|
||||
"""Shim for pydantic version >=2"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
raise NotImplementedError("Only supported for pydantic version 1")
|
||||
|
@ -4,6 +4,18 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
# Keep at top of file to ensure that pydantic test can be skipped before
|
||||
# pydantic v1 related imports are attempted by openapi_schema_pydantic.
|
||||
from langchain import _PYDANTIC_MAJOR_VERSION
|
||||
|
||||
if _PYDANTIC_MAJOR_VERSION != 1:
|
||||
pytest.skip(
|
||||
f"Pydantic major version {_PYDANTIC_MAJOR_VERSION} is not supported.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from openapi_schema_pydantic import (
|
||||
|
Loading…
Reference in New Issue
Block a user