core[minor],community[minor]: Upgrade all @root_validator() to @pre_init (#23841)

This PR introduces a @pre_init decorator that's a @root_validator(pre=True) but with all the defaults populated!
This commit is contained in:
Eugene Yurtsev
2024-07-08 16:09:29 -04:00
committed by GitHub
parent f152d6ed3d
commit 2c180d645e
114 changed files with 439 additions and 276 deletions

View File

@@ -16,6 +16,7 @@ from langchain_core.utils.input import (
)
from langchain_core.utils.iter import batch_iterate
from langchain_core.utils.loading import try_load_from_hub
from langchain_core.utils.pydantic import pre_init
from langchain_core.utils.strings import comma_list, stringify_dict, stringify_value
from langchain_core.utils.utils import (
build_extra_kwargs,
@@ -50,6 +51,7 @@ __all__ = [
"stringify_dict",
"comma_list",
"stringify_value",
"pre_init",
"batch_iterate",
"abatch_iterate",
]

View File

@@ -1,5 +1,10 @@
"""Utilities for tests."""
from functools import wraps
from typing import Any, Callable, Dict, Type
from langchain_core.pydantic_v1 import BaseModel, root_validator
def get_pydantic_major_version() -> int:
"""Get the major version of Pydantic."""
@@ -12,3 +17,35 @@ def get_pydantic_major_version() -> int:
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
# How to type hint this?
def pre_init(func: Callable) -> Any:
"""Decorator to run a function before model initialization."""
@root_validator(pre=True)
@wraps(func)
def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]:
"""Decorator to run a function before model initialization."""
# Insert default values
fields = cls.__fields__
for name, field_info in fields.items():
# Check if allow_population_by_field_name is enabled
# If yes, then set the field name to the alias
if hasattr(cls, "Config"):
if hasattr(cls.Config, "allow_population_by_field_name"):
if cls.Config.allow_population_by_field_name:
if field_info.alias in values:
values[name] = values.pop(field_info.alias)
if name not in values or values[name] is None:
if not field_info.required:
if field_info.default_factory is not None:
values[name] = field_info.default_factory()
else:
values[name] = field_info.default
# Call the decorated function
return func(cls, values)
return wrapper

View File

@@ -24,6 +24,7 @@ EXPECTED_ALL = [
"stringify_dict",
"comma_list",
"stringify_value",
"pre_init",
]

View File

@@ -0,0 +1,75 @@
"""Test for some custom pydantic decorators."""
from typing import Any, Dict, Optional
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.pydantic import pre_init
def test_pre_init_decorator() -> None:
class Foo(BaseModel):
x: int = 5
y: int
@pre_init
def validator(cls, v: Dict[str, Any]) -> Dict[str, Any]:
v["y"] = v["x"] + 1
return v
# Type ignore initialization b/c y is marked as required
foo = Foo() # type: ignore
assert foo.y == 6
foo = Foo(x=10) # type: ignore
assert foo.y == 11
def test_pre_init_decorator_with_more_defaults() -> None:
class Foo(BaseModel):
a: int = 1
b: Optional[int] = None
c: int = Field(default=2)
d: int = Field(default_factory=lambda: 3)
@pre_init
def validator(cls, v: Dict[str, Any]) -> Dict[str, Any]:
assert v["a"] == 1
assert v["b"] is None
assert v["c"] == 2
assert v["d"] == 3
return v
# Try to create an instance of Foo
# nothing is required, but mypy can't track the default for `c`
Foo() # type: ignore
def test_with_aliases() -> None:
class Foo(BaseModel):
x: int = Field(default=1, alias="y")
z: int
class Config:
allow_population_by_field_name = True
@pre_init
def validator(cls, v: Dict[str, Any]) -> Dict[str, Any]:
v["z"] = v["x"]
return v
# Based on defaults
# z is required
foo = Foo() # type: ignore
assert foo.x == 1
assert foo.z == 1
# Based on field name
# z is required
foo = Foo(x=2) # type: ignore
assert foo.x == 2
assert foo.z == 2
# Based on alias
# z is required
foo = Foo(y=2) # type: ignore
assert foo.x == 2
assert foo.z == 2