core[patch]: Propagate module name to create model (#26267)

* This allows pydantic to correctly resolve annotations necessary for
building pydantic models dynamically.
* Makes a small fix for RunnableWithMessageHistory, which was fetching
the OutputType from a RunnableLambda that yielded another
RunnableLambda. This still does not fully propagate the output of the
RunnableAssign (i.e., with concrete type information, etc.)

Resolves issue: https://github.com/langchain-ai/langchain/issues/26250
This commit is contained in:
Eugene Yurtsev
2024-09-10 15:22:56 -04:00
committed by GitHub
parent 622cb7d2cf
commit 7975c1f0ca
4 changed files with 194 additions and 5 deletions

View File

@@ -348,6 +348,14 @@ class Runnable(Generic[Input, Output], ABC):
return create_model(
self.get_name("Input"),
__root__=root_type,
# create_model needs access to the appropriate type annotations to be
# able to construct the pydantic model.
# When we create the model, we pass information about the namespace
# where the model is being created, so the type annotations can
# be resolved correctly as well.
# self.__class__.__module__ handles the case when the Runnable is
# being sub-classed in a different module.
__module_name=self.__class__.__module__,
)
def get_input_jsonschema(
@@ -408,6 +416,14 @@ class Runnable(Generic[Input, Output], ABC):
return create_model(
self.get_name("Output"),
__root__=root_type,
# create_model needs access to the appropriate type annotations to be
# able to construct the pydantic model.
# When we create the model, we pass information about the namespace
# where the model is being created, so the type annotations can
# be resolved correctly as well.
# self.__class__.__module__ handles the case when the Runnable is
# being sub-classed in a different module.
__module_name=self.__class__.__module__,
)
def get_output_jsonschema(
@@ -4046,6 +4062,29 @@ class RunnableGenerator(Runnable[Input, Output]):
except ValueError:
return Any
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# Override the default implementation.
# For a runnable generator, we need to provide the module
# of the underlying function when creating the model.
root_type = self.InputType
func = getattr(self, "_transform", None) or self._atransform
module = getattr(func, "__module__", None)
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
self.get_name("Input"),
__root__=root_type,
# To create the schema, we need to provide the module
# where the underlying function is defined.
# This allows pydantic to resolve type annotations appropriately.
__module_name=module,
)
@property
def OutputType(self) -> Any:
func = getattr(self, "_transform", None) or self._atransform
@@ -4059,6 +4098,28 @@ class RunnableGenerator(Runnable[Input, Output]):
except ValueError:
return Any
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# Override the default implementation.
# For a runnable generator, we need to provide the module
# of the underlying function when creating the model.
root_type = self.OutputType
func = getattr(self, "_transform", None) or self._atransform
module = getattr(func, "__module__", None)
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
self.get_name("Output"),
__root__=root_type,
# To create the schema, we need to provide the module
# where the underlying function is defined.
# This allows pydantic to resolve type annotations appropriately.
__module_name=module,
)
def __eq__(self, other: Any) -> bool:
if isinstance(other, RunnableGenerator):
if hasattr(self, "_transform") and hasattr(other, "_transform"):
@@ -4307,9 +4368,14 @@ class RunnableLambda(Runnable[Input, Output]):
# It's a dict, lol
return create_model(self.get_name("Input"), **fields)
else:
module = getattr(func, "__module__", None)
return create_model(
self.get_name("Input"),
__root__=List[Any],
# To create the schema, we need to provide the module
# where the underlying function is defined.
# This allows pydantic to resolve type annotations appropriately.
__module_name=module,
)
if self.InputType != Any:
@@ -4346,6 +4412,28 @@ class RunnableLambda(Runnable[Input, Output]):
except ValueError:
return Any
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# Override the default implementation.
# For a runnable lambda, we need to provide the module
# of the underlying function when creating the model.
root_type = self.OutputType
func = getattr(self, "func", None) or self.afunc
module = getattr(func, "__module__", None)
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
self.get_name("Output"),
__root__=root_type,
# To create the schema, we need to provide the module
# where the underlying function is defined.
# This allows pydantic to resolve type annotations appropriately.
__module_name=module,
)
@property
def deps(self) -> List[Runnable]:
"""The dependencies of this Runnable.
@@ -4863,6 +4951,14 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
List[self.bound.get_input_schema(config)], # type: ignore
None,
),
# create_model needs access to the appropriate type annotations to be
# able to construct the pydantic model.
# When we create the model, we pass information about the namespace
# where the model is being created, so the type annotations can
# be resolved correctly as well.
# self.__class__.__module__ handles the case when the Runnable is
# being sub-classed in a different module.
__module_name=self.__class__.__module__,
)
@property
@@ -4876,6 +4972,14 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
return create_model(
self.get_name("Output"),
__root__=List[schema], # type: ignore[valid-type]
# create_model needs access to the appropriate type annotations to be
# able to construct the pydantic model.
# When we create the model, we pass information about the namespace
# where the model is being created, so the type annotations can
# be resolved correctly as well.
# self.__class__.__module__ handles the case when the Runnable is
# being sub-classed in a different module.
__module_name=self.__class__.__module__,
)
@property

View File

@@ -21,6 +21,7 @@ from langchain_core.runnables.base import Runnable, RunnableBindingBase, Runnabl
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
Output,
create_model,
get_unique_config_specs,
)
@@ -362,6 +363,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
history_factory_config=_config_specs,
**kwargs,
)
self._history_chain = history_chain
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
@@ -393,6 +395,39 @@ class RunnableWithMessageHistory(RunnableBindingBase):
**fields,
)
@property
def OutputType(self) -> Type[Output]:
output_type = self._history_chain.OutputType
return output_type
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
"""Get a pydantic model that can be used to validate output to the Runnable.
Runnables that leverage the configurable_fields and configurable_alternatives
methods will have a dynamic output schema that depends on which
configuration the Runnable is invoked with.
This method allows to get an output schema for a specific configuration.
Args:
config: A config to use when generating the schema.
Returns:
A pydantic model that can be used to validate output.
"""
root_type = self.OutputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
"RunnableWithChatHistoryOutput",
__root__=root_type,
__module_name=self.__class__.__module__,
)
def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
return False

View File

@@ -713,7 +713,10 @@ NO_DEFAULT = object()
def _create_root_model(
name: str, type_: Any, default_: object = NO_DEFAULT
name: str,
type_: Any,
module_name: Optional[str] = None,
default_: object = NO_DEFAULT,
) -> Type[BaseModel]:
"""Create a base class."""
@@ -751,7 +754,7 @@ def _create_root_model(
"model_config": ConfigDict(arbitrary_types_allowed=True),
"schema": classmethod(schema),
"model_json_schema": classmethod(model_json_schema),
"__module__": "langchain_core.runnables.utils",
"__module__": module_name or "langchain_core.runnables.utils",
}
if default_ is not NO_DEFAULT:
@@ -770,18 +773,24 @@ def _create_root_model_cached(
__model_name: str,
type_: Any,
default_: object = NO_DEFAULT,
module_name: Optional[str] = None,
) -> Type[BaseModel]:
return _create_root_model(__model_name, type_, default_)
return _create_root_model(
__model_name, type_, default_=default_, module_name=module_name
)
def create_model(
__model_name: str,
__module_name: Optional[str] = None,
**field_definitions: Any,
) -> Type[BaseModel]:
"""Create a pydantic model with the given field definitions.
Args:
__model_name: The name of the model.
__module_name: The name of the module where the model is defined.
This is used by Pydantic to resolve any forward references.
**field_definitions: The field definitions for the model.
Returns:
@@ -803,10 +812,16 @@ def create_model(
kwargs = {"type_": arg}
try:
named_root_model = _create_root_model_cached(__model_name, **kwargs)
named_root_model = _create_root_model_cached(
__model_name, module_name=__module_name, **kwargs
)
except TypeError:
# something in the arguments into _create_root_model_cached is not hashable
named_root_model = _create_root_model(__model_name, **kwargs)
named_root_model = _create_root_model(
__model_name,
module_name=__module_name,
**kwargs,
)
return named_root_model
try:
return _create_model_cached(__model_name, **field_definitions)

View File

@@ -454,6 +454,41 @@ def test_get_input_schema_input_dict() -> None:
)
def test_get_output_schema() -> None:
"""Test get output schema."""
runnable = RunnableLambda(
lambda input: {
"output": [
AIMessage(
content="you said: "
+ "\n".join(
[
str(m.content)
for m in input["history"]
if isinstance(m, HumanMessage)
]
+ [input["input"]]
)
)
]
}
)
get_session_history = _get_get_session_history()
with_history = RunnableWithMessageHistory(
runnable,
get_session_history,
input_messages_key="input",
history_messages_key="history",
output_messages_key="output",
)
output_type = with_history.get_output_schema()
assert _schema(output_type) == {
"title": "RunnableWithChatHistoryOutput",
"type": "object",
}
def test_get_input_schema_input_messages() -> None:
from pydantic import RootModel