From d0ce374731863747df4c89addc0d1ec14f95c98c Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Sun, 22 Oct 2023 17:26:48 +0100 Subject: [PATCH] Allow specifying custom input/output schemas for runnables with .with_types() (#12083) --- .../langchain/schema/runnable/base.py | 69 ++++++++++++++++++- .../runnable/__snapshots__/test_runnable.ambr | 4 +- .../schema/runnable/test_runnable.py | 21 ++++++ 3 files changed, 90 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 818bcf3c38a..9968595677f 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -585,6 +585,22 @@ class Runnable(Generic[Input, Output], ABC): kwargs={}, ) + def with_types( + self, + *, + input_type: Optional[Type[Input]] = None, + output_type: Optional[Type[Output]] = None, + ) -> Runnable[Input, Output]: + """ + Bind input and output types to a Runnable, returning a new Runnable. + """ + return RunnableBinding( + bound=self, + custom_input_type=input_type, + custom_output_type=output_type, + kwargs={}, + ) + def with_retry( self, *, @@ -2277,6 +2293,11 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]): def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]: return RunnableEach(bound=self.bound.bind(**kwargs)) + def with_config( + self, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> RunnableEach[Input, Output]: + return RunnableEach(bound=self.bound.with_config(config, **kwargs)) + def _invoke( self, inputs: List[Input], @@ -2321,6 +2342,10 @@ class RunnableBinding(RunnableSerializable[Input, Output]): config: RunnableConfig = Field(default_factory=dict) + custom_input_type: Optional[Union[Type[Input], BaseModel]] = None + + custom_output_type: Optional[Union[Type[Output], BaseModel]] = None + class Config: arbitrary_types_allowed = True @@ -2330,6 +2355,8 @@ class RunnableBinding(RunnableSerializable[Input, Output]): bound: Runnable[Input, Output], kwargs: Mapping[str, Any], config: Optional[RunnableConfig] = None, + custom_input_type: Optional[Union[Type[Input], BaseModel]] = None, + custom_output_type: Optional[Union[Type[Output], BaseModel]] = None, **other_kwargs: Any, ) -> None: config = config or {} @@ -2342,24 +2369,43 @@ class RunnableBinding(RunnableSerializable[Input, Output]): f"Configurable key '{key}' not found in runnable with" f" config keys: {allowed_keys}" ) - super().__init__(bound=bound, kwargs=kwargs, config=config, **other_kwargs) + super().__init__( + bound=bound, + kwargs=kwargs, + config=config, + custom_input_type=custom_input_type, + custom_output_type=custom_output_type, + **other_kwargs, + ) @property def InputType(self) -> Type[Input]: - return self.bound.InputType + return ( + cast(Type[Input], self.custom_input_type) + if self.custom_input_type is not None + else self.bound.InputType + ) @property def OutputType(self) -> Type[Output]: - return self.bound.OutputType + return ( + cast(Type[Output], self.custom_output_type) + if self.custom_output_type is not None + else self.bound.OutputType + ) def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> Type[BaseModel]: + if self.custom_input_type is not None: + return super().get_input_schema(config) return self.bound.get_input_schema(merge_configs(self.config, config)) def get_output_schema( self, config: Optional[RunnableConfig] = None ) -> Type[BaseModel]: + if self.custom_output_type is not None: + return super().get_output_schema(config) return self.bound.get_output_schema(merge_configs(self.config, config)) @property @@ -2394,6 +2440,23 @@ class RunnableBinding(RunnableSerializable[Input, Output]): config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}), ) + def with_types( + self, + input_type: Optional[Union[Type[Input], BaseModel]] = None, + output_type: Optional[Union[Type[Output], BaseModel]] = None, + ) -> Runnable[Input, Output]: + return self.__class__( + bound=self.bound, + kwargs=self.kwargs, + config=self.config, + custom_input_type=input_type + if input_type is not None + else self.custom_input_type, + custom_output_type=output_type + if output_type is not None + else self.custom_output_type, + ) + def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]: return self.__class__( bound=self.bound.with_retry(**kwargs), diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index e78280380bf..c26a88e57ea 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -3747,7 +3747,9 @@ "Thought:" ] }, - "config": {} + "config": {}, + "custom_input_type": null, + "custom_output_type": null } }, "llm": { diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 5bc55cca31c..5632a453952 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -39,6 +39,7 @@ from langchain.prompts.chat import ( MessagesPlaceholder, SystemMessagePromptTemplate, ) +from langchain.pydantic_v1 import BaseModel from langchain.schema.document import Document from langchain.schema.messages import ( AIMessage, @@ -587,6 +588,26 @@ def test_schema_complex_seq() -> None: "type": "string", } + assert chain2.with_types(input_type=str).input_schema.schema() == { + "title": "RunnableBindingInput", + "type": "string", + } + + assert chain2.with_types(input_type=int).output_schema.schema() == { + "title": "StrOutputParserOutput", + "type": "string", + } + + class InputType(BaseModel): + person: str + + assert chain2.with_types(input_type=InputType).input_schema.schema() == { + "title": "InputType", + "type": "object", + "properties": {"person": {"title": "Person", "type": "string"}}, + "required": ["person"], + } + def test_schema_chains() -> None: model = FakeListChatModel(responses=[""])