mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-13 21:47:12 +00:00
core[minor],community[minor]: Upgrade all @root_validator() to @pre_init (#23841)
This PR introduces a @pre_init decorator that's a @root_validator(pre=True) but with all the defaults populated!
This commit is contained in:
@@ -16,6 +16,7 @@ from langchain_core.utils.input import (
|
||||
)
|
||||
from langchain_core.utils.iter import batch_iterate
|
||||
from langchain_core.utils.loading import try_load_from_hub
|
||||
from langchain_core.utils.pydantic import pre_init
|
||||
from langchain_core.utils.strings import comma_list, stringify_dict, stringify_value
|
||||
from langchain_core.utils.utils import (
|
||||
build_extra_kwargs,
|
||||
@@ -50,6 +51,7 @@ __all__ = [
|
||||
"stringify_dict",
|
||||
"comma_list",
|
||||
"stringify_value",
|
||||
"pre_init",
|
||||
"batch_iterate",
|
||||
"abatch_iterate",
|
||||
]
|
||||
|
@@ -1,5 +1,10 @@
|
||||
"""Utilities for tests."""
|
||||
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, Type
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||||
|
||||
|
||||
def get_pydantic_major_version() -> int:
|
||||
"""Get the major version of Pydantic."""
|
||||
@@ -12,3 +17,35 @@ def get_pydantic_major_version() -> int:
|
||||
|
||||
|
||||
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()
|
||||
|
||||
|
||||
# How to type hint this?
|
||||
def pre_init(func: Callable) -> Any:
|
||||
"""Decorator to run a function before model initialization."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
@wraps(func)
|
||||
def wrapper(cls: Type[BaseModel], values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Decorator to run a function before model initialization."""
|
||||
# Insert default values
|
||||
fields = cls.__fields__
|
||||
for name, field_info in fields.items():
|
||||
# Check if allow_population_by_field_name is enabled
|
||||
# If yes, then set the field name to the alias
|
||||
if hasattr(cls, "Config"):
|
||||
if hasattr(cls.Config, "allow_population_by_field_name"):
|
||||
if cls.Config.allow_population_by_field_name:
|
||||
if field_info.alias in values:
|
||||
values[name] = values.pop(field_info.alias)
|
||||
|
||||
if name not in values or values[name] is None:
|
||||
if not field_info.required:
|
||||
if field_info.default_factory is not None:
|
||||
values[name] = field_info.default_factory()
|
||||
else:
|
||||
values[name] = field_info.default
|
||||
|
||||
# Call the decorated function
|
||||
return func(cls, values)
|
||||
|
||||
return wrapper
|
||||
|
@@ -24,6 +24,7 @@ EXPECTED_ALL = [
|
||||
"stringify_dict",
|
||||
"comma_list",
|
||||
"stringify_value",
|
||||
"pre_init",
|
||||
]
|
||||
|
||||
|
||||
|
75
libs/core/tests/unit_tests/utils/test_pydantic.py
Normal file
75
libs/core/tests/unit_tests/utils/test_pydantic.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Test for some custom pydantic decorators."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.utils.pydantic import pre_init
|
||||
|
||||
|
||||
def test_pre_init_decorator() -> None:
|
||||
class Foo(BaseModel):
|
||||
x: int = 5
|
||||
y: int
|
||||
|
||||
@pre_init
|
||||
def validator(cls, v: Dict[str, Any]) -> Dict[str, Any]:
|
||||
v["y"] = v["x"] + 1
|
||||
return v
|
||||
|
||||
# Type ignore initialization b/c y is marked as required
|
||||
foo = Foo() # type: ignore
|
||||
assert foo.y == 6
|
||||
foo = Foo(x=10) # type: ignore
|
||||
assert foo.y == 11
|
||||
|
||||
|
||||
def test_pre_init_decorator_with_more_defaults() -> None:
|
||||
class Foo(BaseModel):
|
||||
a: int = 1
|
||||
b: Optional[int] = None
|
||||
c: int = Field(default=2)
|
||||
d: int = Field(default_factory=lambda: 3)
|
||||
|
||||
@pre_init
|
||||
def validator(cls, v: Dict[str, Any]) -> Dict[str, Any]:
|
||||
assert v["a"] == 1
|
||||
assert v["b"] is None
|
||||
assert v["c"] == 2
|
||||
assert v["d"] == 3
|
||||
return v
|
||||
|
||||
# Try to create an instance of Foo
|
||||
# nothing is required, but mypy can't track the default for `c`
|
||||
Foo() # type: ignore
|
||||
|
||||
|
||||
def test_with_aliases() -> None:
|
||||
class Foo(BaseModel):
|
||||
x: int = Field(default=1, alias="y")
|
||||
z: int
|
||||
|
||||
class Config:
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@pre_init
|
||||
def validator(cls, v: Dict[str, Any]) -> Dict[str, Any]:
|
||||
v["z"] = v["x"]
|
||||
return v
|
||||
|
||||
# Based on defaults
|
||||
# z is required
|
||||
foo = Foo() # type: ignore
|
||||
assert foo.x == 1
|
||||
assert foo.z == 1
|
||||
|
||||
# Based on field name
|
||||
# z is required
|
||||
foo = Foo(x=2) # type: ignore
|
||||
assert foo.x == 2
|
||||
assert foo.z == 2
|
||||
|
||||
# Based on alias
|
||||
# z is required
|
||||
foo = Foo(y=2) # type: ignore
|
||||
assert foo.x == 2
|
||||
assert foo.z == 2
|
Reference in New Issue
Block a user