Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-26 13:59:49 +00:00
core[patch]: Propagate module name to create model (#26267)
* This allows pydantic to correctly resolve the annotations needed to build pydantic models dynamically.
* Makes a small fix for RunnableWithMessageHistory, which was fetching the OutputType from a RunnableLambda that was yielding another RunnableLambda; this did not fully propagate the output of the RunnableAssign (i.e., with concrete type information etc.).

Resolves issue: https://github.com/langchain-ai/langchain/issues/26250
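For context, a minimal sketch (not part of the commit, assuming pydantic v2) of the behaviour the change relies on: a string annotation, i.e. a forward reference, is resolved against the namespace of the module recorded on the generated model, which is what the module argument to create_model controls.

# Illustrative sketch only; "Point" and "Shape" are made-up names.
from pydantic import BaseModel, create_model


class Point(BaseModel):
    x: int
    y: int


# "Point" only exists in this module, so the dynamically created model must
# carry this module's name for the string annotation to resolve.
# (On some pydantic versions an explicit Shape.model_rebuild() may be needed.)
Shape = create_model("Shape", origin=("Point", ...), __module__=__name__)

print(Shape(origin={"x": 1, "y": 2}))  # origin is validated as a Point
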
@@ -348,6 +348,14 @@ class Runnable(Generic[Input, Output], ABC):
        return create_model(
            self.get_name("Input"),
            __root__=root_type,
            # create model needs access to appropriate type annotations to be
            # able to construct the pydantic model.
            # When we create the model, we pass information about the namespace
            # where the model is being created, so the type annotations can
            # be resolved correctly as well.
            # self.__class__.__module__ handles the case when the Runnable is
            # being sub-classed in a different module.
            __module_name=self.__class__.__module__,
        )

    def get_input_jsonschema(
@@ -408,6 +416,14 @@ class Runnable(Generic[Input, Output], ABC):
        return create_model(
            self.get_name("Output"),
            __root__=root_type,
            # create model needs access to appropriate type annotations to be
            # able to construct the pydantic model.
            # When we create the model, we pass information about the namespace
            # where the model is being created, so the type annotations can
            # be resolved correctly as well.
            # self.__class__.__module__ handles the case when the Runnable is
            # being sub-classed in a different module.
            __module_name=self.__class__.__module__,
        )

    def get_output_jsonschema(
@@ -4046,6 +4062,29 @@ class RunnableGenerator(Runnable[Input, Output]):
        except ValueError:
            return Any

    def get_input_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        # Override the default implementation.
        # For a runnable generator, we need to provide the
        # module of the underlying function when creating the model.
        root_type = self.InputType

        func = getattr(self, "_transform", None) or self._atransform
        module = getattr(func, "__module__", None)

        if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
            return root_type

        return create_model(
            self.get_name("Input"),
            __root__=root_type,
            # To create the schema, we need to provide the module
            # where the underlying function is defined.
            # This allows pydantic to resolve type annotations appropriately.
            __module_name=module,
        )

    @property
    def OutputType(self) -> Any:
        func = getattr(self, "_transform", None) or self._atransform
@@ -4059,6 +4098,28 @@ class RunnableGenerator(Runnable[Input, Output]):
        except ValueError:
            return Any

    def get_output_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        # Override the default implementation.
        # For a runnable generator, we need to provide the
        # module of the underlying function when creating the model.
        root_type = self.OutputType
        func = getattr(self, "_transform", None) or self._atransform
        module = getattr(func, "__module__", None)

        if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
            return root_type

        return create_model(
            self.get_name("Output"),
            __root__=root_type,
            # To create the schema, we need to provide the module
            # where the underlying function is defined.
            # This allows pydantic to resolve type annotations appropriately.
            __module_name=module,
        )

    def __eq__(self, other: Any) -> bool:
        if isinstance(other, RunnableGenerator):
            if hasattr(self, "_transform") and hasattr(other, "_transform"):
@@ -4307,9 +4368,14 @@ class RunnableLambda(Runnable[Input, Output]):
                # It's a dict, lol
                return create_model(self.get_name("Input"), **fields)
            else:
                module = getattr(func, "__module__", None)
                return create_model(
                    self.get_name("Input"),
                    __root__=List[Any],
                    # To create the schema, we need to provide the module
                    # where the underlying function is defined.
                    # This allows pydantic to resolve type annotations appropriately.
                    __module_name=module,
                )

        if self.InputType != Any:
@@ -4346,6 +4412,28 @@ class RunnableLambda(Runnable[Input, Output]):
        except ValueError:
            return Any

    def get_output_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        # Override the default implementation.
        # For a runnable lambda, we need to provide the
        # module of the underlying function when creating the model.
        root_type = self.OutputType
        func = getattr(self, "func", None) or self.afunc
        module = getattr(func, "__module__", None)

        if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
            return root_type

        return create_model(
            self.get_name("Output"),
            __root__=root_type,
            # To create the schema, we need to provide the module
            # where the underlying function is defined.
            # This allows pydantic to resolve type annotations appropriately.
            __module_name=module,
        )

    @property
    def deps(self) -> List[Runnable]:
        """The dependencies of this Runnable.
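To illustrate the effect on RunnableLambda and RunnableGenerator, a hedged sketch (not part of the diff, assuming langchain_core is installed; get_output_jsonschema appears as context in the hunks above): the output schema is now built in the module of the wrapped function, so annotations that refer to names defined there can be resolved.

# Illustrative sketch only; the function name "shout" is made up.
from langchain_core.runnables import RunnableLambda


def shout(text: str) -> str:
    return text.upper()


runnable = RunnableLambda(shout)
# The dynamically created output model now carries shout's module name,
# so pydantic can resolve annotations against that module's namespace.
print(runnable.get_output_jsonschema())
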
@@ -4863,6 +4951,14 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
                List[self.bound.get_input_schema(config)],  # type: ignore
                None,
            ),
            # create model needs access to appropriate type annotations to be
            # able to construct the pydantic model.
            # When we create the model, we pass information about the namespace
            # where the model is being created, so the type annotations can
            # be resolved correctly as well.
            # self.__class__.__module__ handles the case when the Runnable is
            # being sub-classed in a different module.
            __module_name=self.__class__.__module__,
        )

    @property
@@ -4876,6 +4972,14 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
        return create_model(
            self.get_name("Output"),
            __root__=List[schema],  # type: ignore[valid-type]
            # create model needs access to appropriate type annotations to be
            # able to construct the pydantic model.
            # When we create the model, we pass information about the namespace
            # where the model is being created, so the type annotations can
            # be resolved correctly as well.
            # self.__class__.__module__ handles the case when the Runnable is
            # being sub-classed in a different module.
            __module_name=self.__class__.__module__,
        )

    @property
@@ -21,6 +21,7 @@ from langchain_core.runnables.base import Runnable, RunnableBindingBase, Runnabl
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import (
    ConfigurableFieldSpec,
    Output,
    create_model,
    get_unique_config_specs,
)
@@ -362,6 +363,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
            history_factory_config=_config_specs,
            **kwargs,
        )
        self._history_chain = history_chain

    @property
    def config_specs(self) -> List[ConfigurableFieldSpec]:
@@ -393,6 +395,39 @@ class RunnableWithMessageHistory(RunnableBindingBase):
            **fields,
        )

    @property
    def OutputType(self) -> Type[Output]:
        output_type = self._history_chain.OutputType
        return output_type

    def get_output_schema(
        self, config: Optional[RunnableConfig] = None
    ) -> Type[BaseModel]:
        """Get a pydantic model that can be used to validate output to the Runnable.

        Runnables that leverage the configurable_fields and configurable_alternatives
        methods will have a dynamic output schema that depends on which
        configuration the Runnable is invoked with.

        This method allows you to get an output schema for a specific configuration.

        Args:
            config: A config to use when generating the schema.

        Returns:
            A pydantic model that can be used to validate output.
        """
        root_type = self.OutputType

        if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
            return root_type

        return create_model(
            "RunnableWithChatHistoryOutput",
            __root__=root_type,
            __module_name=self.__class__.__module__,
        )

    def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
        return False
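For the history.py change, a self-contained usage sketch (not part of the diff; the chain and session store here are made up) showing that get_output_schema now reflects the wrapped chain's output rather than the internal RunnableLambda wrapper:

from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory

store: dict = {}


def get_session_history(session_id: str) -> InMemoryChatMessageHistory:
    # Keep one in-memory history per session id (illustrative only).
    return store.setdefault(session_id, InMemoryChatMessageHistory())


chain = RunnableLambda(lambda inputs: {"output": [AIMessage(content="hi")]})
with_history = RunnableWithMessageHistory(
    chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="history",
    output_messages_key="output",
)
print(with_history.get_output_schema())
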
@@ -713,7 +713,10 @@ NO_DEFAULT = object()


def _create_root_model(
-    name: str, type_: Any, default_: object = NO_DEFAULT
+    name: str,
+    type_: Any,
+    module_name: Optional[str] = None,
+    default_: object = NO_DEFAULT,
) -> Type[BaseModel]:
    """Create a base class."""

@@ -751,7 +754,7 @@ def _create_root_model(
        "model_config": ConfigDict(arbitrary_types_allowed=True),
        "schema": classmethod(schema),
        "model_json_schema": classmethod(model_json_schema),
-        "__module__": "langchain_core.runnables.utils",
+        "__module__": module_name or "langchain_core.runnables.utils",
    }

    if default_ is not NO_DEFAULT:
@@ -770,18 +773,24 @@ def _create_root_model_cached(
    __model_name: str,
    type_: Any,
    default_: object = NO_DEFAULT,
+    module_name: Optional[str] = None,
) -> Type[BaseModel]:
-    return _create_root_model(__model_name, type_, default_)
+    return _create_root_model(
+        __model_name, type_, default_=default_, module_name=module_name
+    )


def create_model(
    __model_name: str,
+    __module_name: Optional[str] = None,
    **field_definitions: Any,
) -> Type[BaseModel]:
    """Create a pydantic model with the given field definitions.

    Args:
        __model_name: The name of the model.
+        __module_name: The name of the module where the model is defined.
+            This is used by Pydantic to resolve any forward references.
        **field_definitions: The field definitions for the model.

    Returns:
@@ -803,10 +812,16 @@
        kwargs = {"type_": arg}

        try:
-            named_root_model = _create_root_model_cached(__model_name, **kwargs)
+            named_root_model = _create_root_model_cached(
+                __model_name, module_name=__module_name, **kwargs
+            )
        except TypeError:
            # something in the arguments passed to _create_root_model_cached is not hashable
-            named_root_model = _create_root_model(__model_name, **kwargs)
+            named_root_model = _create_root_model(
+                __model_name,
+                module_name=__module_name,
+                **kwargs,
+            )
        return named_root_model
    try:
        return _create_model_cached(__model_name, **field_definitions)
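A short usage sketch of the create_model helper whose signature is shown above (illustrative, not part of the diff; it assumes langchain_core is installed and that this internal helper keeps the signature from the hunk):

from typing import List

from langchain_core.runnables.utils import create_model

# __module_name tells pydantic which module's namespace to use when it
# resolves annotations for the dynamically created root model.
MyOutput = create_model("MyOutput", __root__=List[str], __module_name=__name__)
print(MyOutput.model_json_schema())
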
@@ -454,6 +454,41 @@ def test_get_input_schema_input_dict() -> None:
    )


def test_get_output_schema() -> None:
    """Test get output schema."""
    runnable = RunnableLambda(
        lambda input: {
            "output": [
                AIMessage(
                    content="you said: "
                    + "\n".join(
                        [
                            str(m.content)
                            for m in input["history"]
                            if isinstance(m, HumanMessage)
                        ]
                        + [input["input"]]
                    )
                )
            ]
        }
    )
    get_session_history = _get_get_session_history()
    with_history = RunnableWithMessageHistory(
        runnable,
        get_session_history,
        input_messages_key="input",
        history_messages_key="history",
        output_messages_key="output",
    )
    output_type = with_history.get_output_schema()

    assert _schema(output_type) == {
        "title": "RunnableWithChatHistoryOutput",
        "type": "object",
    }


def test_get_input_schema_input_messages() -> None:
    from pydantic import RootModel