mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 22:59:05 +00:00
Wrap OpenAPI features in conditionals for pydantic v2 compatibility (#9205)
Wrap OpenAPI in conditionals for pydantic v2 compatibility.
This commit is contained in:
parent
89be10f6b4
commit
4f1feaca83
@ -27,6 +27,12 @@ if "pydantic_v1" not in sys.modules:
|
|||||||
# and may run prior to langchain core package.
|
# and may run prior to langchain core package.
|
||||||
sys.modules["pydantic_v1"] = pydantic_v1
|
sys.modules["pydantic_v1"] = pydantic_v1
|
||||||
|
|
||||||
|
try:
|
||||||
|
_PYDANTIC_MAJOR_VERSION: int = int(metadata.version("pydantic").split(".")[0])
|
||||||
|
except metadata.PackageNotFoundError:
|
||||||
|
_PYDANTIC_MAJOR_VERSION = 0
|
||||||
|
|
||||||
|
|
||||||
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
|
from langchain.agents import MRKLChain, ReActChain, SelfAskWithSearchChain
|
||||||
from langchain.cache import BaseCache
|
from langchain.cache import BaseCache
|
||||||
from langchain.chains import (
|
from langchain.chains import (
|
||||||
|
@ -3,9 +3,9 @@ import logging
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
|
||||||
|
|
||||||
from openapi_schema_pydantic import MediaType, Parameter, Reference, RequestBody, Schema
|
|
||||||
from pydantic_v1 import BaseModel, Field
|
from pydantic_v1 import BaseModel, Field
|
||||||
|
|
||||||
|
from langchain import _PYDANTIC_MAJOR_VERSION
|
||||||
from langchain.tools.openapi.utils.openapi_utils import HTTPVerb, OpenAPISpec
|
from langchain.tools.openapi.utils.openapi_utils import HTTPVerb, OpenAPISpec
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -85,14 +85,25 @@ class APIPropertyBase(BaseModel):
|
|||||||
"""The description of the property."""
|
"""The description of the property."""
|
||||||
|
|
||||||
|
|
||||||
class APIProperty(APIPropertyBase):
|
if _PYDANTIC_MAJOR_VERSION == 1:
|
||||||
|
from openapi_schema_pydantic import (
|
||||||
|
MediaType,
|
||||||
|
Parameter,
|
||||||
|
Reference,
|
||||||
|
RequestBody,
|
||||||
|
Schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
class APIProperty(APIPropertyBase):
|
||||||
"""A model for a property in the query, path, header, or cookie params."""
|
"""A model for a property in the query, path, header, or cookie params."""
|
||||||
|
|
||||||
location: APIPropertyLocation = Field(alias="location")
|
location: APIPropertyLocation = Field(alias="location")
|
||||||
"""The path/how it's being passed to the endpoint."""
|
"""The path/how it's being passed to the endpoint."""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _cast_schema_list_type(schema: Schema) -> Optional[Union[str, Tuple[str, ...]]]:
|
def _cast_schema_list_type(
|
||||||
|
schema: Schema,
|
||||||
|
) -> Optional[Union[str, Tuple[str, ...]]]:
|
||||||
type_ = schema.type
|
type_ = schema.type
|
||||||
if not isinstance(type_, list):
|
if not isinstance(type_, list):
|
||||||
return type_
|
return type_
|
||||||
@ -125,7 +136,9 @@ class APIProperty(APIPropertyBase):
|
|||||||
return schema_type
|
return schema_type
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_schema_type(parameter: Parameter, schema: Optional[Schema]) -> SCHEMA_TYPE:
|
def _get_schema_type(
|
||||||
|
parameter: Parameter, schema: Optional[Schema]
|
||||||
|
) -> SCHEMA_TYPE:
|
||||||
if schema is None:
|
if schema is None:
|
||||||
return None
|
return None
|
||||||
schema_type: SCHEMA_TYPE = APIProperty._cast_schema_list_type(schema)
|
schema_type: SCHEMA_TYPE = APIProperty._cast_schema_list_type(schema)
|
||||||
@ -136,7 +149,9 @@ class APIProperty(APIPropertyBase):
|
|||||||
raise NotImplementedError("Objects not yet supported")
|
raise NotImplementedError("Objects not yet supported")
|
||||||
elif schema_type in PRIMITIVE_TYPES:
|
elif schema_type in PRIMITIVE_TYPES:
|
||||||
if schema.enum:
|
if schema.enum:
|
||||||
schema_type = APIProperty._get_schema_type_for_enum(parameter, schema)
|
schema_type = APIProperty._get_schema_type_for_enum(
|
||||||
|
parameter, schema
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Directly use the primitive type
|
# Directly use the primitive type
|
||||||
pass
|
pass
|
||||||
@ -181,7 +196,9 @@ class APIProperty(APIPropertyBase):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_parameter(cls, parameter: Parameter, spec: OpenAPISpec) -> "APIProperty":
|
def from_parameter(
|
||||||
|
cls, parameter: Parameter, spec: OpenAPISpec
|
||||||
|
) -> "APIProperty":
|
||||||
"""Instantiate from an OpenAPI Parameter."""
|
"""Instantiate from an OpenAPI Parameter."""
|
||||||
location = APIPropertyLocation.from_str(parameter.param_in)
|
location = APIPropertyLocation.from_str(parameter.param_in)
|
||||||
cls._validate_location(
|
cls._validate_location(
|
||||||
@ -201,8 +218,7 @@ class APIProperty(APIPropertyBase):
|
|||||||
type=schema_type,
|
type=schema_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class APIRequestBodyProperty(APIPropertyBase):
|
||||||
class APIRequestBodyProperty(APIPropertyBase):
|
|
||||||
"""A model for a request body property."""
|
"""A model for a request body property."""
|
||||||
|
|
||||||
properties: List["APIRequestBodyProperty"] = Field(alias="properties")
|
properties: List["APIRequestBodyProperty"] = Field(alias="properties")
|
||||||
@ -245,7 +261,11 @@ class APIRequestBodyProperty(APIPropertyBase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process_array_schema(
|
def _process_array_schema(
|
||||||
cls, schema: Schema, name: str, spec: OpenAPISpec, references_used: List[str]
|
cls,
|
||||||
|
schema: Schema,
|
||||||
|
name: str,
|
||||||
|
spec: OpenAPISpec,
|
||||||
|
references_used: List[str],
|
||||||
) -> str:
|
) -> str:
|
||||||
items = schema.items
|
items = schema.items
|
||||||
if items is not None:
|
if items is not None:
|
||||||
@ -292,7 +312,9 @@ class APIRequestBodyProperty(APIPropertyBase):
|
|||||||
schema, spec, references_used
|
schema, spec, references_used
|
||||||
)
|
)
|
||||||
elif schema_type == "array":
|
elif schema_type == "array":
|
||||||
schema_type = cls._process_array_schema(schema, name, spec, references_used)
|
schema_type = cls._process_array_schema(
|
||||||
|
schema, name, spec, references_used
|
||||||
|
)
|
||||||
elif schema_type in PRIMITIVE_TYPES:
|
elif schema_type in PRIMITIVE_TYPES:
|
||||||
# Use the primitive type directly
|
# Use the primitive type directly
|
||||||
pass
|
pass
|
||||||
@ -312,8 +334,8 @@ class APIRequestBodyProperty(APIPropertyBase):
|
|||||||
references_used=references_used,
|
references_used=references_used,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# class APIRequestBodyProperty(APIPropertyBase):
|
||||||
class APIRequestBody(BaseModel):
|
class APIRequestBody(BaseModel):
|
||||||
"""A model for a request body."""
|
"""A model for a request body."""
|
||||||
|
|
||||||
description: Optional[str] = Field(alias="description")
|
description: Optional[str] = Field(alias="description")
|
||||||
@ -392,8 +414,9 @@ class APIRequestBody(BaseModel):
|
|||||||
media_type=media_type,
|
media_type=media_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# class APIRequestBodyProperty(APIPropertyBase):
|
||||||
class APIOperation(BaseModel):
|
# class APIRequestBody(BaseModel):
|
||||||
|
class APIOperation(BaseModel):
|
||||||
"""A model for a single API operation."""
|
"""A model for a single API operation."""
|
||||||
|
|
||||||
operation_id: str = Field(alias="operation_id")
|
operation_id: str = Field(alias="operation_id")
|
||||||
@ -527,7 +550,8 @@ class APIOperation(BaseModel):
|
|||||||
prop_type = f"{{\n{nested_props}\n{' ' * indent}}}"
|
prop_type = f"{{\n{nested_props}\n{' ' * indent}}}"
|
||||||
|
|
||||||
formatted_props.append(
|
formatted_props.append(
|
||||||
f"{prop_desc}\n{' ' * indent}{prop_name}{prop_required}: {prop_type},"
|
f"{prop_desc}\n{' ' * indent}{prop_name}"
|
||||||
|
f"{prop_required}: {prop_type},"
|
||||||
)
|
)
|
||||||
|
|
||||||
return "\n".join(formatted_props)
|
return "\n".join(formatted_props)
|
||||||
@ -548,16 +572,18 @@ class APIOperation(BaseModel):
|
|||||||
prop_type = self.ts_type_from_python(prop.type)
|
prop_type = self.ts_type_from_python(prop.type)
|
||||||
prop_required = "" if prop.required else "?"
|
prop_required = "" if prop.required else "?"
|
||||||
prop_desc = f"/* {prop.description} */" if prop.description else ""
|
prop_desc = f"/* {prop.description} */" if prop.description else ""
|
||||||
params.append(f"{prop_desc}\n\t\t{prop_name}{prop_required}: {prop_type},")
|
params.append(
|
||||||
|
f"{prop_desc}\n\t\t{prop_name}{prop_required}: {prop_type},"
|
||||||
|
)
|
||||||
|
|
||||||
formatted_params = "\n".join(params).strip()
|
formatted_params = "\n".join(params).strip()
|
||||||
description_str = f"/* {self.description} */" if self.description else ""
|
description_str = f"/* {self.description} */" if self.description else ""
|
||||||
typescript_definition = f"""
|
typescript_definition = f"""
|
||||||
{description_str}
|
{description_str}
|
||||||
type {operation_name} = (_: {{
|
type {operation_name} = (_: {{
|
||||||
{formatted_params}
|
{formatted_params}
|
||||||
}}) => any;
|
}}) => any;
|
||||||
"""
|
"""
|
||||||
return typescript_definition.strip()
|
return typescript_definition.strip()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -581,3 +607,21 @@ type {operation_name} = (_: {{
|
|||||||
if self.request_body is None:
|
if self.request_body is None:
|
||||||
return []
|
return []
|
||||||
return [prop.name for prop in self.request_body.properties]
|
return [prop.name for prop in self.request_body.properties]
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
class APIProperty(APIPropertyBase): # type: ignore[no-redef]
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
raise NotImplementedError("Only supported for pydantic v1")
|
||||||
|
|
||||||
|
class APIRequestBodyProperty(APIPropertyBase): # type: ignore[no-redef]
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
raise NotImplementedError("Only supported for pydantic v1")
|
||||||
|
|
||||||
|
class APIRequestBody(BaseModel): # type: ignore[no-redef]
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
raise NotImplementedError("Only supported for pydantic v1")
|
||||||
|
|
||||||
|
class APIOperation(BaseModel): # type: ignore[no-redef]
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
raise NotImplementedError("Only supported for pydantic v1")
|
||||||
|
@ -1,2 +1,4 @@
|
|||||||
"""Utility functions for parsing an OpenAPI spec. Kept for backwards compat."""
|
"""Utility functions for parsing an OpenAPI spec. Kept for backwards compat."""
|
||||||
from langchain.utilities.openapi import HTTPVerb, OpenAPISpec # noqa: F401
|
from langchain.utilities.openapi import HTTPVerb, OpenAPISpec
|
||||||
|
|
||||||
|
__all__ = ["HTTPVerb", "OpenAPISpec"]
|
||||||
|
@ -2,9 +2,8 @@
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from pydantic_v1 import Extra, root_validator
|
from pydantic_v1 import BaseModel, Extra, root_validator
|
||||||
|
|
||||||
from langchain.tools.base import BaseModel
|
|
||||||
from langchain.utils import get_from_dict_or_env
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
"""Utility functions for parsing an OpenAPI spec."""
|
"""Utility functions for parsing an OpenAPI spec."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@ -9,19 +11,10 @@ from typing import Dict, List, Optional, Union
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
import yaml
|
import yaml
|
||||||
from openapi_schema_pydantic import (
|
|
||||||
Components,
|
|
||||||
OpenAPI,
|
|
||||||
Operation,
|
|
||||||
Parameter,
|
|
||||||
PathItem,
|
|
||||||
Paths,
|
|
||||||
Reference,
|
|
||||||
RequestBody,
|
|
||||||
Schema,
|
|
||||||
)
|
|
||||||
from pydantic_v1 import ValidationError
|
from pydantic_v1 import ValidationError
|
||||||
|
|
||||||
|
from langchain import _PYDANTIC_MAJOR_VERSION
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -38,7 +31,7 @@ class HTTPVerb(str, Enum):
|
|||||||
TRACE = "trace"
|
TRACE = "trace"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_str(cls, verb: str) -> "HTTPVerb":
|
def from_str(cls, verb: str) -> HTTPVerb:
|
||||||
"""Parse an HTTP verb."""
|
"""Parse an HTTP verb."""
|
||||||
try:
|
try:
|
||||||
return cls(verb)
|
return cls(verb)
|
||||||
@ -46,8 +39,21 @@ class HTTPVerb(str, Enum):
|
|||||||
raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}")
|
raise ValueError(f"Invalid HTTP verb. Valid values are {cls.__members__}")
|
||||||
|
|
||||||
|
|
||||||
class OpenAPISpec(OpenAPI):
|
if _PYDANTIC_MAJOR_VERSION == 1:
|
||||||
"""OpenAPI Model that removes misformatted parts of the spec."""
|
from openapi_schema_pydantic import (
|
||||||
|
Components,
|
||||||
|
OpenAPI,
|
||||||
|
Operation,
|
||||||
|
Parameter,
|
||||||
|
PathItem,
|
||||||
|
Paths,
|
||||||
|
Reference,
|
||||||
|
RequestBody,
|
||||||
|
Schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
class OpenAPISpec(OpenAPI):
|
||||||
|
"""OpenAPI Model that removes mis-formatted parts of the spec."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _paths_strict(self) -> Paths:
|
def _paths_strict(self) -> Paths:
|
||||||
@ -92,7 +98,9 @@ class OpenAPISpec(OpenAPI):
|
|||||||
raise ValueError("No request body found in spec. ")
|
raise ValueError("No request body found in spec. ")
|
||||||
return request_bodies
|
return request_bodies
|
||||||
|
|
||||||
def _get_referenced_parameter(self, ref: Reference) -> Union[Parameter, Reference]:
|
def _get_referenced_parameter(
|
||||||
|
self, ref: Reference
|
||||||
|
) -> Union[Parameter, Reference]:
|
||||||
"""Get a parameter (or nested reference) or err."""
|
"""Get a parameter (or nested reference) or err."""
|
||||||
ref_name = ref.ref.split("/")[-1]
|
ref_name = ref.ref.split("/")[-1]
|
||||||
parameters = self._parameters_strict
|
parameters = self._parameters_strict
|
||||||
@ -176,13 +184,13 @@ class OpenAPISpec(OpenAPI):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def parse_obj(cls, obj: dict) -> "OpenAPISpec":
|
def parse_obj(cls, obj: dict) -> OpenAPISpec:
|
||||||
try:
|
try:
|
||||||
cls._alert_unsupported_spec(obj)
|
cls._alert_unsupported_spec(obj)
|
||||||
return super().parse_obj(obj)
|
return super().parse_obj(obj)
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
# We are handling possibly misconfigured specs and want to do a best-effort
|
# We are handling possibly misconfigured specs and
|
||||||
# job to get a reasonable interface out of it.
|
# want to do a best-effort job to get a reasonable interface out of it.
|
||||||
new_obj = copy.deepcopy(obj)
|
new_obj = copy.deepcopy(obj)
|
||||||
for error in e.errors():
|
for error in e.errors():
|
||||||
keys = error["loc"]
|
keys = error["loc"]
|
||||||
@ -193,12 +201,12 @@ class OpenAPISpec(OpenAPI):
|
|||||||
return cls.parse_obj(new_obj)
|
return cls.parse_obj(new_obj)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_spec_dict(cls, spec_dict: dict) -> "OpenAPISpec":
|
def from_spec_dict(cls, spec_dict: dict) -> OpenAPISpec:
|
||||||
"""Get an OpenAPI spec from a dict."""
|
"""Get an OpenAPI spec from a dict."""
|
||||||
return cls.parse_obj(spec_dict)
|
return cls.parse_obj(spec_dict)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_text(cls, text: str) -> "OpenAPISpec":
|
def from_text(cls, text: str) -> OpenAPISpec:
|
||||||
"""Get an OpenAPI spec from a text."""
|
"""Get an OpenAPI spec from a text."""
|
||||||
try:
|
try:
|
||||||
spec_dict = json.loads(text)
|
spec_dict = json.loads(text)
|
||||||
@ -207,7 +215,7 @@ class OpenAPISpec(OpenAPI):
|
|||||||
return cls.from_spec_dict(spec_dict)
|
return cls.from_spec_dict(spec_dict)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_file(cls, path: Union[str, Path]) -> "OpenAPISpec":
|
def from_file(cls, path: Union[str, Path]) -> OpenAPISpec:
|
||||||
"""Get an OpenAPI spec from a file path."""
|
"""Get an OpenAPI spec from a file path."""
|
||||||
path_ = path if isinstance(path, Path) else Path(path)
|
path_ = path if isinstance(path, Path) else Path(path)
|
||||||
if not path_.exists():
|
if not path_.exists():
|
||||||
@ -216,7 +224,7 @@ class OpenAPISpec(OpenAPI):
|
|||||||
return cls.from_text(f.read())
|
return cls.from_text(f.read())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_url(cls, url: str) -> "OpenAPISpec":
|
def from_url(cls, url: str) -> OpenAPISpec:
|
||||||
"""Get an OpenAPI spec from a URL."""
|
"""Get an OpenAPI spec from a URL."""
|
||||||
response = requests.get(url)
|
response = requests.get(url)
|
||||||
return cls.from_text(response.text)
|
return cls.from_text(response.text)
|
||||||
@ -275,7 +283,9 @@ class OpenAPISpec(OpenAPI):
|
|||||||
return request_body
|
return request_body
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_cleaned_operation_id(operation: Operation, path: str, method: str) -> str:
|
def get_cleaned_operation_id(
|
||||||
|
operation: Operation, path: str, method: str
|
||||||
|
) -> str:
|
||||||
"""Get a cleaned operation id from an operation id."""
|
"""Get a cleaned operation id from an operation id."""
|
||||||
operation_id = operation.operationId
|
operation_id = operation.operationId
|
||||||
if operation_id is None:
|
if operation_id is None:
|
||||||
@ -283,3 +293,11 @@ class OpenAPISpec(OpenAPI):
|
|||||||
path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/"))
|
path = re.sub(r"[^a-zA-Z0-9]", "_", path.lstrip("/"))
|
||||||
operation_id = f"{path}_{method}"
|
operation_id = f"{path}_{method}"
|
||||||
return operation_id.replace("-", "_").replace(".", "_").replace("/", "_")
|
return operation_id.replace("-", "_").replace(".", "_").replace("/", "_")
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
class OpenAPISpec: # type: ignore[no-redef]
|
||||||
|
"""Shim for pydantic version >=2"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
raise NotImplementedError("Only supported for pydantic version 1")
|
||||||
|
@ -4,6 +4,18 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Iterable, List, Tuple
|
from typing import Iterable, List, Tuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Keep at top of file to ensure that pydantic test can be skipped before
|
||||||
|
# pydantic v1 related imports are attempted by openapi_schema_pydantic.
|
||||||
|
from langchain import _PYDANTIC_MAJOR_VERSION
|
||||||
|
|
||||||
|
if _PYDANTIC_MAJOR_VERSION != 1:
|
||||||
|
pytest.skip(
|
||||||
|
f"Pydantic major version {_PYDANTIC_MAJOR_VERSION} is not supported.",
|
||||||
|
allow_module_level=True,
|
||||||
|
)
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import yaml
|
import yaml
|
||||||
from openapi_schema_pydantic import (
|
from openapi_schema_pydantic import (
|
||||||
|
Loading…
Reference in New Issue
Block a user