Replace __fields__ with model_fields

This commit is contained in:
Eugene Yurtsev
2024-08-07 15:44:31 -04:00
parent 22957311fe
commit 76b6ee290d

View File

@@ -35,6 +35,7 @@ from typing import (
overload,
)
from pydantic import BaseModel, ConfigDict, Field, RootModel
from typing_extensions import Literal, get_args, get_type_hints
from langchain_core._api import beta_decorator
@@ -44,7 +45,6 @@ from langchain_core.load.serializable import (
SerializedConstructor,
SerializedNotImplemented,
)
from langchain_core.pydantic_v1 import BaseModel, Field, RootModel
from langchain_core.runnables.config import (
RunnableConfig,
_set_config_context,
@@ -83,7 +83,6 @@ from langchain_core.runnables.utils import (
)
from langchain_core.utils.aiter import aclosing, atee, py_anext
from langchain_core.utils.iter import safetee
from langchain_core.utils.pydantic import is_basemodel_subclass
if TYPE_CHECKING:
from langchain_core.callbacks.manager import (
@@ -2408,10 +2407,10 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
from langchain_core.runnables.configurable import RunnableConfigurableFields
for key in kwargs:
if key not in self.__fields__:
if key not in self.model_fields:
raise ValueError(
f"Configuration key {key} not found in {self}: "
f"available keys are {self.__fields__.keys()}"
f"available keys are {self.model_fields.keys()}"
)
return RunnableConfigurableFields(default=self, fields=kwargs)
@@ -2492,7 +2491,7 @@ def _seq_input_schema(
"RunnableSequenceInput",
**{
k: (v.annotation, v.default)
for k, v in next_input_schema.__fields__.items()
for k, v in next_input_schema.model_fields.items()
if k not in first.mapper.steps__
},
)
@@ -2520,11 +2519,11 @@ def _seq_output_schema(
**{
**{
k: (v.annotation, v.default)
for k, v in prev_output_schema.__fields__.items()
for k, v in prev_output_schema.model_fields.items()
},
**{
k: (v.annotation, v.default)
for k, v in mapper_output_schema.__fields__.items()
for k, v in mapper_output_schema.model_fields.items()
},
},
)
@@ -2537,12 +2536,12 @@ def _seq_output_schema(
"RunnableSequenceOutput",
**{
k: (v.annotation, v.default)
for k, v in prev_output_schema.__fields__.items()
for k, v in prev_output_schema.model_fields.items()
if k in last.keys
},
)
else:
field = prev_output_schema.__fields__[last.keys]
field = prev_output_schema.model_fields[last.keys]
return create_model( # type: ignore[call-overload]
"RunnableSequenceOutput",
__root__=(field.annotation, field.default),
@@ -2704,8 +2703,9 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
"""
return True
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@property
def InputType(self) -> Type[Input]:
@@ -3442,8 +3442,9 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
@@ -3490,7 +3491,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
**{
k: (v.annotation, v.default)
for step in self.steps__.values()
for k, v in step.get_input_schema(config).__fields__.items()
for k, v in step.get_input_schema(config).model_fields.items()
if k != "__root__"
},
)
@@ -4764,8 +4765,9 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
bound: Runnable[Input, Output]
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
@property
def InputType(self) -> Any:
@@ -5009,8 +5011,9 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
The type can be a pydantic model, or a type annotation (e.g., `List[str]`).
"""
class Config:
arbitrary_types_allowed = True
model_config = ConfigDict(
arbitrary_types_allowed=True,
)
def __init__(
self,
@@ -5346,7 +5349,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
yield item
RunnableBindingBase.update_forward_refs()
RunnableBindingBase.model_rebuild()
class RunnableBinding(RunnableBindingBase[Input, Output]):