Merge branch 'v0.3rc' into bagatur/09-10/v0.3_merge_master

This commit is contained in:
Bagatur
2024-09-10 12:44:49 -07:00
committed by GitHub
4 changed files with 194 additions and 5 deletions

View File

@@ -348,6 +348,14 @@ class Runnable(Generic[Input, Output], ABC):
return create_model(
self.get_name("Input"),
__root__=root_type,
# create model needs access to appropriate type annotations to be
# able to construct the pydantic model.
# When we create the model, we pass information about the namespace
# where the model is being created, so the type annotations can
# be resolved correctly as well.
# self.__class__.__module__ handles the case when the Runnable is
# being sub-classed in a different module.
__module_name=self.__class__.__module__,
)
def get_input_jsonschema(
@@ -408,6 +416,14 @@ class Runnable(Generic[Input, Output], ABC):
return create_model(
self.get_name("Output"),
__root__=root_type,
# create model needs access to appropriate type annotations to be
# able to construct the pydantic model.
# When we create the model, we pass information about the namespace
# where the model is being created, so the type annotations can
# be resolved correctly as well.
# self.__class__.__module__ handles the case when the Runnable is
# being sub-classed in a different module.
__module_name=self.__class__.__module__,
)
def get_output_jsonschema(
@@ -4046,6 +4062,29 @@ class RunnableGenerator(Runnable[Input, Output]):
except ValueError:
return Any
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# Override the default implementation.
# For a runnable generator, we need to bring to provide the
# module of the underlying function when creating the model.
root_type = self.InputType
func = getattr(self, "_transform", None) or self._atransform
module = getattr(func, "__module__", None)
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
self.get_name("Input"),
__root__=root_type,
# To create the schema, we need to provide the module
# where the underlying function is defined.
# This allows pydantic to resolve type annotations appropriately.
__module_name=module,
)
@property
def OutputType(self) -> Any:
func = getattr(self, "_transform", None) or self._atransform
@@ -4059,6 +4098,28 @@ class RunnableGenerator(Runnable[Input, Output]):
except ValueError:
return Any
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# Override the default implementation.
# For a runnable generator, we need to bring to provide the
# module of the underlying function when creating the model.
root_type = self.OutputType
func = getattr(self, "_transform", None) or self._atransform
module = getattr(func, "__module__", None)
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
self.get_name("Output"),
__root__=root_type,
# To create the schema, we need to provide the module
# where the underlying function is defined.
# This allows pydantic to resolve type annotations appropriately.
__module_name=module,
)
def __eq__(self, other: Any) -> bool:
if isinstance(other, RunnableGenerator):
if hasattr(self, "_transform") and hasattr(other, "_transform"):
@@ -4307,9 +4368,14 @@ class RunnableLambda(Runnable[Input, Output]):
# It's a dict, lol
return create_model(self.get_name("Input"), **fields)
else:
module = getattr(func, "__module__", None)
return create_model(
self.get_name("Input"),
__root__=List[Any],
# To create the schema, we need to provide the module
# where the underlying function is defined.
# This allows pydantic to resolve type annotations appropriately.
__module_name=module,
)
if self.InputType != Any:
@@ -4346,6 +4412,28 @@ class RunnableLambda(Runnable[Input, Output]):
except ValueError:
return Any
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
# Override the default implementation.
# For a runnable lambda, we need to bring to provide the
# module of the underlying function when creating the model.
root_type = self.OutputType
func = getattr(self, "func", None) or self.afunc
module = getattr(func, "__module__", None)
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
self.get_name("Output"),
__root__=root_type,
# To create the schema, we need to provide the module
# where the underlying function is defined.
# This allows pydantic to resolve type annotations appropriately.
__module_name=module,
)
@property
def deps(self) -> List[Runnable]:
"""The dependencies of this Runnable.
@@ -4863,6 +4951,14 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
List[self.bound.get_input_schema(config)], # type: ignore
None,
),
# create model needs access to appropriate type annotations to be
# able to construct the pydantic model.
# When we create the model, we pass information about the namespace
# where the model is being created, so the type annotations can
# be resolved correctly as well.
# self.__class__.__module__ handles the case when the Runnable is
# being sub-classed in a different module.
__module_name=self.__class__.__module__,
)
@property
@@ -4876,6 +4972,14 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
return create_model(
self.get_name("Output"),
__root__=List[schema], # type: ignore[valid-type]
# create model needs access to appropriate type annotations to be
# able to construct the pydantic model.
# When we create the model, we pass information about the namespace
# where the model is being created, so the type annotations can
# be resolved correctly as well.
# self.__class__.__module__ handles the case when the Runnable is
# being sub-classed in a different module.
__module_name=self.__class__.__module__,
)
@property

View File

@@ -21,6 +21,7 @@ from langchain_core.runnables.base import Runnable, RunnableBindingBase, Runnabl
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.runnables.utils import (
ConfigurableFieldSpec,
Output,
create_model,
get_unique_config_specs,
)
@@ -362,6 +363,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
history_factory_config=_config_specs,
**kwargs,
)
self._history_chain = history_chain
@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
@@ -393,6 +395,39 @@ class RunnableWithMessageHistory(RunnableBindingBase):
**fields,
)
@property
def OutputType(self) -> Type[Output]:
output_type = self._history_chain.OutputType
return output_type
def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
"""Get a pydantic model that can be used to validate output to the Runnable.
Runnables that leverage the configurable_fields and configurable_alternatives
methods will have a dynamic output schema that depends on which
configuration the Runnable is invoked with.
This method allows to get an output schema for a specific configuration.
Args:
config: A config to use when generating the schema.
Returns:
A pydantic model that can be used to validate output.
"""
root_type = self.OutputType
if inspect.isclass(root_type) and issubclass(root_type, BaseModel):
return root_type
return create_model(
"RunnableWithChatHistoryOutput",
__root__=root_type,
__module_name=self.__class__.__module__,
)
def _is_not_async(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> bool:
return False

View File

@@ -713,7 +713,10 @@ NO_DEFAULT = object()
def _create_root_model(
name: str, type_: Any, default_: object = NO_DEFAULT
name: str,
type_: Any,
module_name: Optional[str] = None,
default_: object = NO_DEFAULT,
) -> Type[BaseModel]:
"""Create a base class."""
@@ -751,7 +754,7 @@ def _create_root_model(
"model_config": ConfigDict(arbitrary_types_allowed=True),
"schema": classmethod(schema),
"model_json_schema": classmethod(model_json_schema),
"__module__": "langchain_core.runnables.utils",
"__module__": module_name or "langchain_core.runnables.utils",
}
if default_ is not NO_DEFAULT:
@@ -770,18 +773,24 @@ def _create_root_model_cached(
__model_name: str,
type_: Any,
default_: object = NO_DEFAULT,
module_name: Optional[str] = None,
) -> Type[BaseModel]:
return _create_root_model(__model_name, type_, default_)
return _create_root_model(
__model_name, type_, default_=default_, module_name=module_name
)
def create_model(
__model_name: str,
__module_name: Optional[str] = None,
**field_definitions: Any,
) -> Type[BaseModel]:
"""Create a pydantic model with the given field definitions.
Args:
__model_name: The name of the model.
__module_name: The name of the module where the model is defined.
This is used by Pydantic to resolve any forward references.
**field_definitions: The field definitions for the model.
Returns:
@@ -803,10 +812,16 @@ def create_model(
kwargs = {"type_": arg}
try:
named_root_model = _create_root_model_cached(__model_name, **kwargs)
named_root_model = _create_root_model_cached(
__model_name, module_name=__module_name, **kwargs
)
except TypeError:
# something in the arguments into _create_root_model_cached is not hashable
named_root_model = _create_root_model(__model_name, **kwargs)
named_root_model = _create_root_model(
__model_name,
module_name=__module_name,
**kwargs,
)
return named_root_model
try:
return _create_model_cached(__model_name, **field_definitions)

View File

@@ -454,6 +454,41 @@ def test_get_input_schema_input_dict() -> None:
)
def test_get_output_schema() -> None:
"""Test get output schema."""
runnable = RunnableLambda(
lambda input: {
"output": [
AIMessage(
content="you said: "
+ "\n".join(
[
str(m.content)
for m in input["history"]
if isinstance(m, HumanMessage)
]
+ [input["input"]]
)
)
]
}
)
get_session_history = _get_get_session_history()
with_history = RunnableWithMessageHistory(
runnable,
get_session_history,
input_messages_key="input",
history_messages_key="history",
output_messages_key="output",
)
output_type = with_history.get_output_schema()
assert _schema(output_type) == {
"title": "RunnableWithChatHistoryOutput",
"type": "object",
}
def test_get_input_schema_input_messages() -> None:
from pydantic import RootModel