diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index b8473fcd09b..4ea77b65ffa 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -29,7 +29,7 @@ from langchain_core.prompt_values import ( ) from langchain_core.runnables import RunnableConfig, RunnableSerializable from langchain_core.runnables.config import ensure_config -from langchain_core.runnables.utils import create_model +from langchain_core.utils.pydantic import create_model_v2 if TYPE_CHECKING: from langchain_core.documents import Document @@ -125,8 +125,9 @@ class BasePromptTemplate( optional_input_variables = { k: (self.input_types.get(k, str), None) for k in self.optional_variables } - return create_model( - "PromptInput", **{**required_input_variables, **optional_input_variables} + return create_model_v2( + "PromptInput", + field_definitions={**required_input_variables, **optional_input_variables}, ) def _validate_input(self, inner_input: Any) -> Dict: diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 3ea4325f3c7..2759c4010a9 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -71,7 +71,6 @@ from langchain_core.runnables.utils import ( accepts_config, accepts_run_manager, asyncio_accepts_context, - create_model, gather_with_concurrency, get_function_first_arg_dict_keys, get_function_nonlocals, @@ -83,6 +82,7 @@ from langchain_core.runnables.utils import ( ) from langchain_core.utils.aiter import aclosing, atee, py_anext from langchain_core.utils.iter import safetee +from langchain_core.utils.pydantic import create_model_v2 if TYPE_CHECKING: from langchain_core.callbacks.manager import ( @@ -345,9 +345,9 @@ class Runnable(Generic[Input, Output], ABC): if inspect.isclass(root_type) and issubclass(root_type, BaseModel): return root_type - return create_model( + return create_model_v2( self.get_name("Input"), - __root__=root_type, + root=root_type, # create model needs access to appropriate type annotations to be # able to construct the pydantic model. # When we create the model, we pass information about the namespace @@ -355,7 +355,7 @@ class Runnable(Generic[Input, Output], ABC): # be resolved correctly as well. # self.__class__.__module__ handles the case when the Runnable is # being sub-classed in a different module. - __module_name=self.__class__.__module__, + module_name=self.__class__.__module__, ) def get_input_jsonschema( @@ -413,9 +413,9 @@ class Runnable(Generic[Input, Output], ABC): if inspect.isclass(root_type) and issubclass(root_type, BaseModel): return root_type - return create_model( + return create_model_v2( self.get_name("Output"), - __root__=root_type, + root=root_type, # create model needs access to appropriate type annotations to be # able to construct the pydantic model. # When we create the model, we pass information about the namespace @@ -423,7 +423,7 @@ class Runnable(Generic[Input, Output], ABC): # be resolved correctly as well. # self.__class__.__module__ handles the case when the Runnable is # being sub-classed in a different module. - __module_name=self.__class__.__module__, + module_name=self.__class__.__module__, ) def get_output_jsonschema( @@ -477,9 +477,9 @@ class Runnable(Generic[Input, Output], ABC): include = include or [] config_specs = self.config_specs configurable = ( - create_model( # type: ignore[call-overload] + create_model_v2( # type: ignore[call-overload] "Configurable", - **{ + field_definitions={ spec.id: ( spec.annotation, Field( @@ -502,8 +502,8 @@ class Runnable(Generic[Input, Output], ABC): if field_name in [i for i in include if i != "configurable"] }, } - model = create_model( # type: ignore[call-overload] - self.get_name("Config"), **all_fields + model = create_model_v2( # type: ignore[call-overload] + self.get_name("Config"), field_definitions=all_fields ) return model @@ -530,14 +530,14 @@ class Runnable(Generic[Input, Output], ABC): try: input_node = graph.add_node(self.get_input_schema(config)) except TypeError: - input_node = graph.add_node(create_model(self.get_name("Input"))) + input_node = graph.add_node(create_model_v2(self.get_name("Input"))) runnable_node = graph.add_node( self, metadata=config.get("metadata") if config else None ) try: output_node = graph.add_node(self.get_output_schema(config)) except TypeError: - output_node = graph.add_node(create_model(self.get_name("Output"))) + output_node = graph.add_node(create_model_v2(self.get_name("Output"))) graph.add_edge(input_node, runnable_node) graph.add_edge(runnable_node, output_node) return graph @@ -2583,9 +2583,9 @@ def _seq_input_schema( next_input_schema = _seq_input_schema(steps[1:], config) if not issubclass(next_input_schema, RootModel): # it's a dict as expected - return create_model( # type: ignore[call-overload] + return create_model_v2( # type: ignore[call-overload] "RunnableSequenceInput", - **{ + field_definitions={ k: (v.annotation, v.default) for k, v in next_input_schema.model_fields.items() if k not in first.mapper.steps__ @@ -2610,9 +2610,9 @@ def _seq_output_schema( prev_output_schema = _seq_output_schema(steps[:-1], config) if not issubclass(prev_output_schema, RootModel): # it's a dict as expected - return create_model( # type: ignore[call-overload] + return create_model_v2( # type: ignore[call-overload] "RunnableSequenceOutput", - **{ + field_definitions={ **{ k: (v.annotation, v.default) for k, v in prev_output_schema.model_fields.items() @@ -2628,9 +2628,9 @@ def _seq_output_schema( if not issubclass(prev_output_schema, RootModel): # it's a dict as expected if isinstance(last.keys, list): - return create_model( # type: ignore[call-overload] + return create_model_v2( # type: ignore[call-overload] "RunnableSequenceOutput", - **{ + field_definitions={ k: (v.annotation, v.default) for k, v in prev_output_schema.model_fields.items() if k in last.keys @@ -2638,9 +2638,8 @@ def _seq_output_schema( ) else: field = prev_output_schema.model_fields[last.keys] - return create_model( # type: ignore[call-overload] - "RunnableSequenceOutput", - __root__=(field.annotation, field.default), + return create_model_v2( # type: ignore[call-overload] + "RunnableSequenceOutput", root=(field.annotation, field.default) ) return last.get_output_schema(config) @@ -3582,9 +3581,9 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): for s in self.steps__.values() ): # This is correct, but pydantic typings/mypy don't think so. - return create_model( # type: ignore[call-overload] + return create_model_v2( # type: ignore[call-overload] self.get_name("Input"), - **{ + field_definitions={ k: (v.annotation, v.default) for step in self.steps__.values() for k, v in step.get_input_schema(config).model_fields.items() @@ -3606,7 +3605,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): The output schema of the Runnable. """ fields = {k: (v.OutputType, ...) for k, v in self.steps__.items()} - return create_model(self.get_name("Output"), **fields) + return create_model_v2(self.get_name("Output"), field_definitions=fields) @property def config_specs(self) -> List[ConfigurableFieldSpec]: @@ -4076,13 +4075,13 @@ class RunnableGenerator(Runnable[Input, Output]): if inspect.isclass(root_type) and issubclass(root_type, BaseModel): return root_type - return create_model( + return create_model_v2( self.get_name("Input"), - __root__=root_type, + root=root_type, # To create the schema, we need to provide the module # where the underlying function is defined. # This allows pydantic to resolve type annotations appropriately. - __module_name=module, + module_name=module, ) @property @@ -4111,13 +4110,13 @@ class RunnableGenerator(Runnable[Input, Output]): if inspect.isclass(root_type) and issubclass(root_type, BaseModel): return root_type - return create_model( + return create_model_v2( self.get_name("Output"), - __root__=root_type, + root=root_type, # To create the schema, we need to provide the module # where the underlying function is defined. # This allows pydantic to resolve type annotations appropriately. - __module_name=module, + module_name=module, ) def __eq__(self, other: Any) -> bool: @@ -4366,25 +4365,25 @@ class RunnableLambda(Runnable[Input, Output]): ): fields = {item[1:-1]: (Any, ...) for item in items} # It's a dict, lol - return create_model(self.get_name("Input"), **fields) + return create_model_v2(self.get_name("Input"), field_definitions=fields) else: module = getattr(func, "__module__", None) - return create_model( + return create_model_v2( self.get_name("Input"), - __root__=List[Any], + root=List[Any], # To create the schema, we need to provide the module # where the underlying function is defined. # This allows pydantic to resolve type annotations appropriately. - __module_name=module, + module_name=module, ) if self.InputType != Any: return super().get_input_schema(config) if dict_keys := get_function_first_arg_dict_keys(func): - return create_model( + return create_model_v2( self.get_name("Input"), - **{key: (Any, ...) for key in dict_keys}, # type: ignore + field_definitions={key: (Any, ...) for key in dict_keys}, ) return super().get_input_schema(config) @@ -4425,13 +4424,13 @@ class RunnableLambda(Runnable[Input, Output]): if inspect.isclass(root_type) and issubclass(root_type, BaseModel): return root_type - return create_model( + return create_model_v2( self.get_name("Output"), - __root__=root_type, + root=root_type, # To create the schema, we need to provide the module # where the underlying function is defined. # This allows pydantic to resolve type annotations appropriately. - __module_name=module, + module_name=module, ) @property @@ -4945,9 +4944,9 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> Type[BaseModel]: - return create_model( + return create_model_v2( self.get_name("Input"), - __root__=( + root=( List[self.bound.get_input_schema(config)], # type: ignore None, ), @@ -4958,7 +4957,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): # be resolved correctly as well. # self.__class__.__module__ handles the case when the Runnable is # being sub-classed in a different module. - __module_name=self.__class__.__module__, + module_name=self.__class__.__module__, ) @property @@ -4969,9 +4968,9 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): self, config: Optional[RunnableConfig] = None ) -> Type[BaseModel]: schema = self.bound.get_output_schema(config) - return create_model( + return create_model_v2( self.get_name("Output"), - __root__=List[schema], # type: ignore[valid-type] + root=List[schema], # type: ignore[valid-type] # create model needs access to appropriate type annotations to be # able to construct the pydantic model. # When we create the model, we pass information about the namespace @@ -4979,7 +4978,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): # be resolved correctly as well. # self.__class__.__module__ handles the case when the Runnable is # being sub-classed in a different module. - __module_name=self.__class__.__module__, + module_name=self.__class__.__module__, ) @property diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 472e2d95b6d..da297ec3bca 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -22,9 +22,9 @@ from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.utils import ( ConfigurableFieldSpec, Output, - create_model, get_unique_config_specs, ) +from langchain_core.utils.pydantic import create_model_v2 if TYPE_CHECKING: from langchain_core.language_models.base import LanguageModelLike @@ -386,10 +386,15 @@ class RunnableWithMessageHistory(RunnableBindingBase): elif self.input_messages_key: fields[self.input_messages_key] = (Sequence[BaseMessage], ...) else: - fields["__root__"] = (Sequence[BaseMessage], ...) - return create_model( # type: ignore[call-overload] + return create_model_v2( + "RunnableWithChatHistoryInput", + module_name=self.__class__.__module__, + root=(Sequence[BaseMessage], ...), + ) + return create_model_v2( # type: ignore[call-overload] "RunnableWithChatHistoryInput", - **fields, + field_definitions=fields, + module_name=self.__class__.__module__, ) @property @@ -419,10 +424,10 @@ class RunnableWithMessageHistory(RunnableBindingBase): if inspect.isclass(root_type) and issubclass(root_type, BaseModel): return root_type - return create_model( + return create_model_v2( "RunnableWithChatHistoryOutput", - __root__=root_type, - __module_name=self.__class__.__module__, + root=root_type, + module_name=self.__class__.__module__, ) def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool: diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index bdec3b7205e..a613533674d 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -41,10 +41,10 @@ from langchain_core.runnables.graph import Graph from langchain_core.runnables.utils import ( AddableDict, ConfigurableFieldSpec, - create_model, ) from langchain_core.utils.aiter import atee, py_anext from langchain_core.utils.iter import safetee +from langchain_core.utils.pydantic import create_model_v2 if TYPE_CHECKING: from langchain_core.callbacks.manager import ( @@ -442,9 +442,8 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): for name, field_info in map_output_schema.model_fields.items(): fields[name] = (field_info.annotation, field_info.default) - return create_model( # type: ignore[call-overload] - "RunnableAssignOutput", - **fields, + return create_model_v2( # type: ignore[call-overload] + "RunnableAssignOutput", field_definitions=fields ) elif not issubclass(map_output_schema, RootModel): # ie. only map output is a dict diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index eb03d11346f..20d82be3fd6 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -6,7 +6,6 @@ import ast import asyncio import inspect import textwrap -import warnings from functools import lru_cache from inspect import signature from itertools import groupby @@ -26,24 +25,17 @@ from typing import ( Protocol, Sequence, Set, - Type, TypeVar, Union, - cast, ) -from pydantic import BaseModel, ConfigDict, PydanticDeprecationWarning, RootModel -from pydantic import create_model as _create_model_base # pydantic :ignore -from pydantic.json_schema import ( - DEFAULT_REF_TEMPLATE, - GenerateJsonSchema, - JsonSchemaMode, -) -from pydantic.v1 import BaseModel as BaseModelV1 from typing_extensions import TypeGuard from langchain_core.runnables.schema import StreamEvent +# Re-export create-model for backwards compatibility +from langchain_core.utils.pydantic import create_model as create_model + Input = TypeVar("Input", contravariant=True) # Output type should implement __concat__, as eg str, list, dict do Output = TypeVar("Output", covariant=True) @@ -707,141 +699,6 @@ class _RootEventFilter: return include -_SchemaConfig = ConfigDict(arbitrary_types_allowed=True, frozen=True) - -NO_DEFAULT = object() - - -def _create_root_model( - name: str, - type_: Any, - module_name: Optional[str] = None, - default_: object = NO_DEFAULT, -) -> Type[BaseModel]: - """Create a base class.""" - - def schema( - cls: Type[BaseModel], - by_alias: bool = True, - ref_template: str = DEFAULT_REF_TEMPLATE, - ) -> Dict[str, Any]: - # Complains about schema not being defined in superclass - schema_ = super(cls, cls).schema( # type: ignore[misc] - by_alias=by_alias, ref_template=ref_template - ) - schema_["title"] = name - return schema_ - - def model_json_schema( - cls: Type[BaseModel], - by_alias: bool = True, - ref_template: str = DEFAULT_REF_TEMPLATE, - schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema, - mode: JsonSchemaMode = "validation", - ) -> Dict[str, Any]: - # Complains about model_json_schema not being defined in superclass - schema_ = super(cls, cls).model_json_schema( # type: ignore[misc] - by_alias=by_alias, - ref_template=ref_template, - schema_generator=schema_generator, - mode=mode, - ) - schema_["title"] = name - return schema_ - - base_class_attributes = { - "__annotations__": {"root": type_}, - "model_config": ConfigDict(arbitrary_types_allowed=True), - "schema": classmethod(schema), - "model_json_schema": classmethod(model_json_schema), - "__module__": module_name or "langchain_core.runnables.utils", - } - - if default_ is not NO_DEFAULT: - base_class_attributes["root"] = default_ - with warnings.catch_warnings(): - if isinstance(type_, type) and issubclass(type_, BaseModelV1): - warnings.filterwarnings( - action="ignore", category=PydanticDeprecationWarning - ) - custom_root_type = type(name, (RootModel,), base_class_attributes) - return cast(Type[BaseModel], custom_root_type) - - -@lru_cache(maxsize=256) -def _create_root_model_cached( - __model_name: str, - type_: Any, - default_: object = NO_DEFAULT, - module_name: Optional[str] = None, -) -> Type[BaseModel]: - return _create_root_model( - __model_name, type_, default_=default_, module_name=module_name - ) - - -def create_model( - __model_name: str, - __module_name: Optional[str] = None, - **field_definitions: Any, -) -> Type[BaseModel]: - """Create a pydantic model with the given field definitions. - - Args: - __model_name: The name of the model. - __module_name: The name of the module where the model is defined. - This is used by Pydantic to resolve any forward references. - **field_definitions: The field definitions for the model. - - Returns: - Type[BaseModel]: The created model. - """ - - # Move this to caching path - if "__root__" in field_definitions: - if len(field_definitions) > 1: - raise NotImplementedError( - "When specifying __root__ no other " - f"fields should be provided. Got {field_definitions}" - ) - - arg = field_definitions["__root__"] - if isinstance(arg, tuple): - kwargs = {"type_": arg[0], "default_": arg[1]} - else: - kwargs = {"type_": arg} - - try: - named_root_model = _create_root_model_cached( - __model_name, module_name=__module_name, **kwargs - ) - except TypeError: - # something in the arguments into _create_root_model_cached is not hashable - named_root_model = _create_root_model( - __model_name, - module_name=__module_name, - **kwargs, - ) - return named_root_model - try: - return _create_model_cached(__model_name, **field_definitions) - except TypeError: - # something in field definitions is not hashable - return _create_model_base( - __model_name, __config__=_SchemaConfig, **field_definitions - ) - - -@lru_cache(maxsize=256) -def _create_model_cached( - __model_name: str, - **field_definitions: Any, -) -> Type[BaseModel]: - return _create_model_base( - __model_name, __config__=_SchemaConfig, **field_definitions - ) - - def is_async_generator( func: Any, ) -> TypeGuard[Callable[..., AsyncIterator]]: diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index e739532b42b..9ec0cfaa11e 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -5,12 +5,38 @@ from __future__ import annotations import inspect import textwrap import warnings -from functools import wraps -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union, overload +from contextlib import nullcontext +from functools import lru_cache, wraps +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Type, + TypeVar, + Union, + cast, + overload, +) import pydantic -from pydantic import BaseModel, PydanticDeprecationWarning, root_validator -from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue +from pydantic import ( + BaseModel, + ConfigDict, + PydanticDeprecationWarning, + RootModel, + root_validator, +) +from pydantic import ( + create_model as _create_model_base, +) +from pydantic.json_schema import ( + DEFAULT_REF_TEMPLATE, + GenerateJsonSchema, + JsonSchemaMode, + JsonSchemaValue, +) from pydantic_core import core_schema @@ -355,3 +381,249 @@ elif PYDANTIC_MAJOR_VERSION == 1: return model.__fields__ # type: ignore else: raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}") + +_SchemaConfig = ConfigDict( + arbitrary_types_allowed=True, frozen=True, protected_namespaces=() +) + +NO_DEFAULT = object() + + +def _create_root_model( + name: str, + type_: Any, + module_name: Optional[str] = None, + default_: object = NO_DEFAULT, +) -> Type[BaseModel]: + """Create a base class.""" + + def schema( + cls: Type[BaseModel], + by_alias: bool = True, + ref_template: str = DEFAULT_REF_TEMPLATE, + ) -> Dict[str, Any]: + # Complains about schema not being defined in superclass + schema_ = super(cls, cls).schema( # type: ignore[misc] + by_alias=by_alias, ref_template=ref_template + ) + schema_["title"] = name + return schema_ + + def model_json_schema( + cls: Type[BaseModel], + by_alias: bool = True, + ref_template: str = DEFAULT_REF_TEMPLATE, + schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema, + mode: JsonSchemaMode = "validation", + ) -> Dict[str, Any]: + # Complains about model_json_schema not being defined in superclass + schema_ = super(cls, cls).model_json_schema( # type: ignore[misc] + by_alias=by_alias, + ref_template=ref_template, + schema_generator=schema_generator, + mode=mode, + ) + schema_["title"] = name + return schema_ + + base_class_attributes = { + "__annotations__": {"root": type_}, + "model_config": ConfigDict(arbitrary_types_allowed=True), + "schema": classmethod(schema), + "model_json_schema": classmethod(model_json_schema), + "__module__": module_name or "langchain_core.runnables.utils", + } + + if default_ is not NO_DEFAULT: + base_class_attributes["root"] = default_ + with warnings.catch_warnings(): + if isinstance(type_, type) and issubclass(type_, BaseModelV1): + warnings.filterwarnings( + action="ignore", category=PydanticDeprecationWarning + ) + custom_root_type = type(name, (RootModel,), base_class_attributes) + return cast(Type[BaseModel], custom_root_type) + + +@lru_cache(maxsize=256) +def _create_root_model_cached( + model_name: str, + type_: Any, + *, + module_name: Optional[str] = None, + default_: object = NO_DEFAULT, +) -> Type[BaseModel]: + return _create_root_model( + model_name, type_, default_=default_, module_name=module_name + ) + + +@lru_cache(maxsize=256) +def _create_model_cached( + __model_name: str, + **field_definitions: Any, +) -> Type[BaseModel]: + return _create_model_base( + __model_name, __config__=_SchemaConfig, **field_definitions + ) + + +def create_model( + __model_name: str, + __module_name: Optional[str] = None, + **field_definitions: Any, +) -> Type[BaseModel]: + """Create a pydantic model with the given field definitions. + + Please use create_model_v2 instead of this function. + + Args: + __model_name: The name of the model. + __module_name: The name of the module where the model is defined. + This is used by Pydantic to resolve any forward references. + **field_definitions: The field definitions for the model. + + Returns: + Type[BaseModel]: The created model. + """ + kwargs = {} + if "__root__" in field_definitions: + kwargs["root"] = field_definitions.pop("__root__") + + return create_model_v2( + __model_name, + module_name=__module_name, + field_definitions=field_definitions, + **kwargs, + ) + + +# Deprecated and not used by the code base. +_OK_TO_OVERWRITE = { + "construct", + "copy", + "dict", + "from_orm", + "json", + "parse_file", + "parse_obj", + "parse_raw", + "schema", + "schema_json", + "update_forward_refs", + "validate", +} + +# These are reserved by pydantic. +_RESERVED_NAMES = { + "model_computed_fields", + "model_config", + "model_construct", + "model_copy", + "model_dump", + "model_dump_json", + "model_extra", + "model_fields", + "model_fields_set", + "model_json_schema", + "model_parametrized_name", + "model_post_init", + "model_rebuild", + "model_validate", + "model_validate_json", + "model_validate_strings", +} + + +def create_model_v2( + model_name: str, + *, + module_name: Optional[str] = None, + field_definitions: Optional[Dict[str, Any]] = None, + root: Optional[Any] = None, +) -> Type[BaseModel]: + """Create a pydantic model with the given field definitions. + + Attention: + Please do not use outside of langchain packages. This API + is subject to change at any time. + + Args: + model_name: The name of the model. + module_name: The name of the module where the model is defined. + This is used by Pydantic to resolve any forward references. + field_definitions: The field definitions for the model. + root: Type for a root model (RootModel) + + Returns: + Type[BaseModel]: The created model. + """ + field_definitions = cast(Dict[str, Any], field_definitions or {}) # type: ignore[no-redef] + + if root: + if field_definitions: + raise NotImplementedError( + "When specifying __root__ no other " + f"fields should be provided. Got {field_definitions}" + ) + + if isinstance(root, tuple): + kwargs = {"type_": root[0], "default_": root[1]} + else: + kwargs = {"type_": root} + + try: + named_root_model = _create_root_model_cached( + model_name, module_name=module_name, **kwargs + ) + except TypeError: + # something in the arguments into _create_root_model_cached is not hashable + named_root_model = _create_root_model( + model_name, + module_name=module_name, + **kwargs, + ) + return named_root_model + + # No root, just field definitions + names = set(field_definitions.keys()) + + if _RESERVED_NAMES & names: + raise ValueError( + f"The following names are reserved by Pydantic: {_RESERVED_NAMES & names} " + f"and cannot be used as a field name. Try to use a different name." + ) + + # Likely common names that Pydantic will throw a run time warning about, + # but these names should be safe to override. + if _OK_TO_OVERWRITE & names: + # Capture warnings + capture_warnings = True + else: + capture_warnings = False + + for name in names: + # Also if any non-reserved name is used (e.g., model_id or model_name) + if name.startswith("model"): + capture_warnings = True + + if name.startswith("_"): # Private attribute + # Pydantic 2 treats fields starting with `_` as private attributes. + # For now, we will raise an error if a field name starts with `_`. + # We will try to remove this restriction in the future. + raise ValueError( + f"Unable to use the field name {name} as " + f"it is prefixed with a `_` attribute. " + f"Please remove the `_` prefix." + ) + + with warnings.catch_warnings() if capture_warnings else nullcontext(): # type: ignore[attr-defined] + if capture_warnings: + warnings.filterwarnings(action="ignore") + try: + return _create_model_cached(model_name, **field_definitions) + except TypeError: + # something in field definitions is not hashable + return _create_model_base( + model_name, __config__=_SchemaConfig, **field_definitions + ) diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index a5145fbfaf9..d3562fdf31b 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -864,3 +864,34 @@ async def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None: ) assert dumpd(template) == snapshot() assert load(dumpd(template)) == template + + +def test_chat_prompt_template_variable_names() -> None: + """This test was written for an edge case that triggers a warning from Pydantic. + + Verify that no run time warnings are raised. + """ + with pytest.warns(None) as record: # type: ignore + prompt = ChatPromptTemplate([("system", "{schema}")]) + prompt.get_input_schema() + + if record: + error_msg = [] + for warning in record: + error_msg.append( + f"Warning type: {warning.category.__name__}, " + f"Warning message: {warning.message}, " + f"Warning location: {warning.filename}:{warning.lineno}" + ) + msg = "\n".join(error_msg) + else: + msg = "" + + assert list(record) == [], msg + + # Verify value errors raised from illegal names + with pytest.raises(ValueError): + ChatPromptTemplate([("system", "{_private}")]).get_input_schema() + + with pytest.raises(ValueError): + ChatPromptTemplate([("system", "{model_json_schema}")]).get_input_schema() diff --git a/libs/core/tests/unit_tests/runnables/test_imports.py b/libs/core/tests/unit_tests/runnables/test_imports.py index 12b1a80d1bf..e0a6ac47bb2 100644 --- a/libs/core/tests/unit_tests/runnables/test_imports.py +++ b/libs/core/tests/unit_tests/runnables/test_imports.py @@ -35,3 +35,9 @@ EXPECTED_ALL = [ def test_all_imports() -> None: assert set(__all__) == set(EXPECTED_ALL) + + +def test_imports_for_specific_funcs() -> None: + """Test that a few specific imports in more internal namespaces.""" + # create_model implementation has been moved to langchain_core.utils.pydantic + from langchain_core.runnables.utils import create_model # noqa diff --git a/libs/core/tests/unit_tests/utils/test_pydantic.py b/libs/core/tests/unit_tests/utils/test_pydantic.py index c49c5050a7b..465273950bf 100644 --- a/libs/core/tests/unit_tests/utils/test_pydantic.py +++ b/libs/core/tests/unit_tests/utils/test_pydantic.py @@ -8,6 +8,7 @@ from pydantic import ConfigDict from langchain_core.utils.pydantic import ( PYDANTIC_MAJOR_VERSION, _create_subset_model_v2, + create_model_v2, get_fields, is_basemodel_instance, is_basemodel_subclass, @@ -194,3 +195,44 @@ def test_fields_pydantic_v1_from_2() -> None: fields = get_fields(Foo) assert fields == {"x": Foo.__fields__["x"]} + + +def test_create_model_v2() -> None: + """Test that create model v2 works as expected.""" + + with pytest.warns(None) as record: # type: ignore + foo = create_model_v2("Foo", field_definitions={"a": (int, None)}) + foo.model_json_schema() + + assert list(record) == [] + + # schema is used by pydantic, but OK to re-use + with pytest.warns(None) as record: # type: ignore + foo = create_model_v2("Foo", field_definitions={"schema": (int, None)}) + foo.model_json_schema() + + assert list(record) == [] + + # From protected namespaces, but definitely OK to use. + with pytest.warns(None) as record: # type: ignore + foo = create_model_v2("Foo", field_definitions={"model_id": (int, None)}) + foo.model_json_schema() + + assert list(record) == [] + + # Used by pydantic, not OK to re-use + with pytest.raises(ValueError): + create_model_v2("Foo", field_definitions={"model_json_schema": (int, None)}) + + # Private attributes raise an error for now since pydantic 2 considers them + # to be private attributes. + with pytest.raises(ValueError): + create_model_v2("Foo", field_definitions={"_a": (int, None)}) + + with pytest.warns(None) as record: # type: ignore + # Verify that we can use non-English characters + field_name = "もしもし" + foo = create_model_v2("Foo", field_definitions={field_name: (int, None)}) + foo.model_json_schema() + + assert list(record) == []