From 76b6ee290da66053b649cf19ae07778fda854b3f Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 7 Aug 2024 15:44:31 -0400 Subject: [PATCH] Replace __fields__ with model_fields --- libs/core/langchain_core/runnables/base.py | 41 ++++++++++++---------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index c3de85026be..223852cc15e 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -35,6 +35,7 @@ from typing import ( overload, ) +from pydantic import BaseModel, ConfigDict, Field, RootModel from typing_extensions import Literal, get_args, get_type_hints from langchain_core._api import beta_decorator @@ -44,7 +45,6 @@ from langchain_core.load.serializable import ( SerializedConstructor, SerializedNotImplemented, ) -from langchain_core.pydantic_v1 import BaseModel, Field, RootModel from langchain_core.runnables.config import ( RunnableConfig, _set_config_context, @@ -83,7 +83,6 @@ from langchain_core.runnables.utils import ( ) from langchain_core.utils.aiter import aclosing, atee, py_anext from langchain_core.utils.iter import safetee -from langchain_core.utils.pydantic import is_basemodel_subclass if TYPE_CHECKING: from langchain_core.callbacks.manager import ( @@ -2408,10 +2407,10 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]): from langchain_core.runnables.configurable import RunnableConfigurableFields for key in kwargs: - if key not in self.__fields__: + if key not in self.model_fields: raise ValueError( f"Configuration key {key} not found in {self}: " - f"available keys are {self.__fields__.keys()}" + f"available keys are {self.model_fields.keys()}" ) return RunnableConfigurableFields(default=self, fields=kwargs) @@ -2492,7 +2491,7 @@ def _seq_input_schema( "RunnableSequenceInput", **{ k: (v.annotation, v.default) - for k, v in next_input_schema.__fields__.items() + for k, v in next_input_schema.model_fields.items() if k not in first.mapper.steps__ }, ) @@ -2520,11 +2519,11 @@ def _seq_output_schema( **{ **{ k: (v.annotation, v.default) - for k, v in prev_output_schema.__fields__.items() + for k, v in prev_output_schema.model_fields.items() }, **{ k: (v.annotation, v.default) - for k, v in mapper_output_schema.__fields__.items() + for k, v in mapper_output_schema.model_fields.items() }, }, ) @@ -2537,12 +2536,12 @@ def _seq_output_schema( "RunnableSequenceOutput", **{ k: (v.annotation, v.default) - for k, v in prev_output_schema.__fields__.items() + for k, v in prev_output_schema.model_fields.items() if k in last.keys }, ) else: - field = prev_output_schema.__fields__[last.keys] + field = prev_output_schema.model_fields[last.keys] return create_model( # type: ignore[call-overload] "RunnableSequenceOutput", __root__=(field.annotation, field.default), @@ -2704,8 +2703,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]): """ return True - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) @property def InputType(self) -> Type[Input]: @@ -3442,8 +3442,9 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): """Get the namespace of the langchain object.""" return ["langchain", "schema", "runnable"] - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def get_name( self, suffix: Optional[str] = None, *, name: Optional[str] = None @@ -3490,7 +3491,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]): **{ k: (v.annotation, v.default) for step in self.steps__.values() - for k, v in step.get_input_schema(config).__fields__.items() + for k, v in step.get_input_schema(config).model_fields.items() if k != "__root__" }, ) @@ -4764,8 +4765,9 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]): bound: Runnable[Input, Output] - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) @property def InputType(self) -> Any: @@ -5009,8 +5011,9 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): The type can be a pydantic model, or a type annotation (e.g., `List[str]`). """ - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict( + arbitrary_types_allowed=True, + ) def __init__( self, @@ -5346,7 +5349,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): yield item -RunnableBindingBase.update_forward_refs() +RunnableBindingBase.model_rebuild() class RunnableBinding(RunnableBindingBase[Input, Output]):