core[patch]: fix deprecated pydantic code (#26161)

This commit is contained in:
Bagatur
2024-09-06 17:14:17 -04:00
committed by GitHub
parent b2c8f2de4c
commit b2ba4f4072
12 changed files with 118 additions and 55 deletions

View File

@@ -32,7 +32,7 @@ class InMemoryDocumentIndex(DocumentIndex):
for item in items:
if item.id is None:
id_ = str(uuid.uuid4())
item_ = item.copy()
item_ = item.model_copy()
item_.id = id_
else:
item_ = item

View File

@@ -275,10 +275,14 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
else:
fn_name = _result["name"]
_args = _result["arguments"]
if hasattr(self.pydantic_schema, "model_validate_json"):
pydantic_args = self.pydantic_schema[fn_name].model_validate_json(_args) # type: ignore
if isinstance(self.pydantic_schema, dict):
pydantic_schema = self.pydantic_schema[fn_name]
else:
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore
pydantic_schema = self.pydantic_schema
if hasattr(pydantic_schema, "model_validate_json"):
pydantic_args = pydantic_schema.model_validate_json(_args) # type: ignore
else:
pydantic_args = pydantic_schema.parse_raw(_args) # type: ignore
return pydantic_args

View File

@@ -310,7 +310,7 @@ class BasePromptTemplate(
Raises:
NotImplementedError: If the prompt type is not implemented.
"""
prompt_dict = super().dict(**kwargs)
prompt_dict = super().model_dump(**kwargs)
try:
prompt_dict["_type"] = self._prompt_type
except NotImplementedError:

View File

@@ -6,6 +6,7 @@ import ast
import asyncio
import inspect
import textwrap
import warnings
from functools import lru_cache
from inspect import signature
from itertools import groupby
@@ -31,13 +32,14 @@ from typing import (
cast,
)
from pydantic import BaseModel, ConfigDict, RootModel
from pydantic import BaseModel, ConfigDict, PydanticDeprecationWarning, RootModel
from pydantic import create_model as _create_model_base # pydantic :ignore
from pydantic.json_schema import (
DEFAULT_REF_TEMPLATE,
GenerateJsonSchema,
JsonSchemaMode,
)
from pydantic.v1 import BaseModel as BaseModelV1
from typing_extensions import TypeGuard
from langchain_core.runnables.schema import StreamEvent
@@ -754,7 +756,12 @@ def create_base_class(
if default_ is not NO_DEFAULT:
base_class_attributes["root"] = default_
custom_root_type = type(name, (RootModel,), base_class_attributes)
with warnings.catch_warnings():
if isinstance(type_, type) and issubclass(type_, BaseModelV1):
warnings.filterwarnings(
action="ignore", category=PydanticDeprecationWarning
)
custom_root_type = type(name, (RootModel,), base_class_attributes)
return cast(Type[BaseModel], custom_root_type)

View File

@@ -476,7 +476,10 @@ class ChildTool(BaseTool):
if isinstance(tool_input, str):
if input_args is not None:
key_ = next(iter(get_fields(input_args).keys()))
input_args.validate({key_: tool_input})
if hasattr(input_args, "model_validate"):
input_args.model_validate({key_: tool_input})
else:
input_args.parse_obj({key_: tool_input})
return tool_input
else:
if input_args is not None:

View File

@@ -283,7 +283,7 @@ class _TracerCore(ABC):
def _complete_llm_run(self, response: LLMResult, run_id: UUID) -> Run:
llm_run = self._get_run(run_id, run_type={"llm", "chat_model"})
llm_run.outputs = response.dict()
llm_run.outputs = response.model_dump()
for i, generations in enumerate(response.generations):
for j, generation in enumerate(generations):
output_generation = llm_run.outputs["generations"][i][j]

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import logging
import warnings
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
@@ -10,6 +11,7 @@ from uuid import UUID
from langsmith import Client
from langsmith import utils as ls_utils
from pydantic import PydanticDeprecationWarning
from tenacity import (
Retrying,
retry_if_exception_type,
@@ -69,11 +71,16 @@ def _get_executor() -> ThreadPoolExecutor:
def _run_to_dict(run: Run) -> dict:
return {
**run.dict(exclude={"child_runs", "inputs", "outputs"}),
"inputs": run.inputs.copy() if run.inputs is not None else None,
"outputs": run.outputs.copy() if run.outputs is not None else None,
}
# TODO: Update once langsmith moves to Pydantic V2 and we can swap run.dict for
# run.model_dump
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
return {
**run.dict(exclude={"child_runs", "inputs", "outputs"}),
"inputs": run.inputs.copy() if run.inputs is not None else None,
"outputs": run.outputs.copy() if run.outputs is not None else None,
}
class LangChainTracer(BaseTracer):
@@ -152,7 +159,12 @@ class LangChainTracer(BaseTracer):
return chat_model_run
def _persist_run(self, run: Run) -> None:
run_ = run.copy()
# TODO: Update once langsmith moves to Pydantic V2 and we can swap run.copy for
# run.model_copy
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
run_ = run.copy()
run_.reference_example_id = self.example_id
self.latest_run = run_

View File

@@ -9,6 +9,7 @@ from uuid import UUID
from langsmith.schemas import RunBase as BaseRunV2
from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
from pydantic import PydanticDeprecationWarning
from langchain_core._api import deprecated
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
@@ -142,9 +143,14 @@ class Run(BaseRunV2):
return values
ChainRun.update_forward_refs()
ToolRun.update_forward_refs()
Run.update_forward_refs()
# TODO: Update once langsmith moves to Pydantic V2 and we can swap Run.model_rebuild
# for Run.update_forward_refs
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=PydanticDeprecationWarning)
ChainRun.update_forward_refs()
ToolRun.update_forward_refs()
Run.update_forward_refs()
__all__ = [
"BaseRun",

View File

@@ -4,11 +4,12 @@ from __future__ import annotations
import inspect
import textwrap
import warnings
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload
import pydantic
from pydantic import BaseModel, root_validator
from pydantic import BaseModel, PydanticDeprecationWarning, root_validator
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from pydantic_core import core_schema
@@ -134,42 +135,45 @@ def pre_init(func: Callable) -> Any:
Any: The decorated function.
"""
@root_validator(pre=True)
@wraps(func)
def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]:
"""Decorator to run a function before model initialization.
with warnings.catch_warnings():
warnings.filterwarnings(action="ignore", category=PydanticDeprecationWarning)
Args:
cls (Type[BaseModel]): The model class.
values (Dict[str, Any]): The values to initialize the model with.
@root_validator(pre=True)
@wraps(func)
def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]:
"""Decorator to run a function before model initialization.
Returns:
Dict[str, Any]: The values to initialize the model with.
"""
# Insert default values
fields = cls.model_fields
for name, field_info in fields.items():
# Check if allow_population_by_field_name is enabled
# If yes, then set the field name to the alias
if hasattr(cls, "Config"):
if hasattr(cls.Config, "allow_population_by_field_name"):
if cls.Config.allow_population_by_field_name:
Args:
cls (Type[BaseModel]): The model class.
values (Dict[str, Any]): The values to initialize the model with.
Returns:
Dict[str, Any]: The values to initialize the model with.
"""
# Insert default values
fields = cls.model_fields
for name, field_info in fields.items():
# Check if allow_population_by_field_name is enabled
# If yes, then set the field name to the alias
if hasattr(cls, "Config"):
if hasattr(cls.Config, "allow_population_by_field_name"):
if cls.Config.allow_population_by_field_name:
if field_info.alias in values:
values[name] = values.pop(field_info.alias)
if hasattr(cls, "model_config"):
if cls.model_config.get("populate_by_name"):
if field_info.alias in values:
values[name] = values.pop(field_info.alias)
if hasattr(cls, "model_config"):
if cls.model_config.get("populate_by_name"):
if field_info.alias in values:
values[name] = values.pop(field_info.alias)
if name not in values or values[name] is None:
if not field_info.is_required():
if field_info.default_factory is not None:
values[name] = field_info.default_factory()
else:
values[name] = field_info.default
if name not in values or values[name] is None:
if not field_info.is_required():
if field_info.default_factory is not None:
values[name] = field_info.default_factory()
else:
values[name] = field_info.default
# Call the decorated function
return func(cls, values)
# Call the decorated function
return func(cls, values)
return wrapper

View File

@@ -56,6 +56,9 @@ markers = [
"compile: mark placeholder test used to compile integration tests without running them",
]
asyncio_mode = "auto"
filterwarnings = [
"ignore::langchain_core._api.beta_decorator.LangChainBetaWarning",
]
[tool.poetry.group.lint]
optional = true

View File

@@ -293,7 +293,7 @@ def test_structured_tool_types_parsed() -> None:
assert isinstance(structured_tool, StructuredTool)
args = {
"some_enum": SomeEnum.A.value,
"some_base_model": SomeBaseModel(foo="bar").dict(),
"some_base_model": SomeBaseModel(foo="bar").model_dump(),
}
result = structured_tool.run(json.loads(json.dumps(args)))
expected = {
@@ -1600,7 +1600,13 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
name="some_tool", description="some description", args_schema=pydantic_model
)
assert tool.get_input_schema().schema() == {
input_schema = tool.get_input_schema()
input_json_schema = (
input_schema.model_json_schema()
if hasattr(input_schema, "model_json_schema")
else input_schema.schema()
)
assert input_json_schema == {
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"title": "B", "type": "string"},
@@ -1610,7 +1616,13 @@ def test_args_schema_as_pydantic(pydantic_model: Any) -> None:
"type": "object",
}
assert tool.tool_call_schema.schema() == {
tool_schema = tool.tool_call_schema
tool_json_schema = (
tool_schema.model_json_schema()
if hasattr(tool_schema, "model_json_schema")
else tool_schema.schema()
)
assert tool_json_schema == {
"description": "some description",
"properties": {
"a": {"title": "A", "type": "integer"},
@@ -1684,7 +1696,13 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"
assert foo_tool.args_schema.schema() == {
args_schema = foo_tool.args_schema
args_json_schema = (
args_schema.model_json_schema()
if hasattr(args_schema, "model_json_schema")
else args_schema.schema()
)
assert args_json_schema == {
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"title": "B", "type": "string"},
@@ -1694,7 +1712,13 @@ def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -
"type": "object",
}
assert foo_tool.get_input_schema().schema() == {
input_schema = foo_tool.get_input_schema()
input_json_schema = (
input_schema.model_json_schema()
if hasattr(input_schema, "model_json_schema")
else input_schema.schema()
)
assert input_json_schema == {
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"title": "B", "type": "string"},

View File

@@ -171,7 +171,7 @@ def test_fields_pydantic_v1() -> None:
x: int
fields = get_fields(Foo)
assert fields == {"x": Foo.__fields__["x"]} # type: ignore[index]
assert fields == {"x": Foo.model_fields["x"]} # type: ignore[index]
@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="Only tests Pydantic v2")