From 31433249840147af1c3af2f935a4cbe93e57d51b Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 31 Oct 2023 13:22:54 +0000 Subject: [PATCH] Improve Runnable type inference for input_schemas (#12630) - Prefer lambda type annotations over inferred dict schema - For sequences that start with RunnableAssign infer seq input type as "input type of 2nd item in sequence - output type of runnable assign" --- .../langchain/schema/runnable/base.py | 20 +++++ .../langchain/schema/runnable/passthrough.py | 4 + .../schema/runnable/test_runnable.py | 85 +++++++++++++++++++ 3 files changed, 109 insertions(+) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index e1abeef5e9d..6971bfd68d6 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -1109,6 +1109,23 @@ class RunnableSequence(RunnableSerializable[Input, Output]): def get_input_schema( self, config: Optional[RunnableConfig] = None ) -> Type[BaseModel]: + from langchain.schema.runnable.passthrough import RunnableAssign + + if isinstance(self.first, RunnableAssign): + first = cast(RunnableAssign, self.first) + next_ = self.middle[0] if self.middle else self.last + next_input_schema = next_.get_input_schema(config) + if not next_input_schema.__custom_root_type__: + # it's a dict as expected + return create_model( # type: ignore[call-overload] + "RunnableSequenceInput", + **{ + k: (v.annotation, v.default) + for k, v in next_input_schema.__fields__.items() + if k not in first.mapper.steps + }, + ) + return self.first.get_input_schema(config) def get_output_schema( @@ -2152,6 +2169,9 @@ class RunnableLambda(Runnable[Input, Output]): else: return create_model("RunnableLambdaInput", __root__=(List[Any], None)) + if self.InputType != Any: + return super().get_input_schema(config) + if dict_keys := get_function_first_arg_dict_keys(func): return create_model( "RunnableLambdaInput", diff --git a/libs/langchain/langchain/schema/runnable/passthrough.py b/libs/langchain/langchain/schema/runnable/passthrough.py index 117edb90740..abe211651d0 100644 --- a/libs/langchain/langchain/schema/runnable/passthrough.py +++ b/libs/langchain/langchain/schema/runnable/passthrough.py @@ -326,6 +326,10 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]): for k, v in s.__fields__.items() }, ) + elif not map_output_schema.__custom_root_type__: + # ie. only map output is a dict + # ie. input type is either unknown or inferred incorrectly + return map_output_schema return super().get_output_schema(config) diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 065af5b6056..583059ccb1f 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -18,6 +18,7 @@ import pytest from freezegun import freeze_time from pytest_mock import MockerFixture from syrupy import SnapshotAssertion +from typing_extensions import TypedDict from langchain.callbacks.manager import ( Callbacks, @@ -508,6 +509,41 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: } +def test_passthrough_assign_schema() -> None: + retriever = FakeRetriever() # str -> List[Document] + prompt = PromptTemplate.from_template("{context} {question}") + fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]] + + seq_w_assign: Runnable = ( + RunnablePassthrough.assign(context=itemgetter("question") | retriever) + | prompt + | fake_llm + ) + + assert seq_w_assign.input_schema.schema() == { + "properties": {"question": {"title": "Question", "type": "string"}}, + "title": "RunnableSequenceInput", + "type": "object", + } + assert seq_w_assign.output_schema.schema() == { + "title": "FakeListLLMOutput", + "type": "string", + } + + invalid_seq_w_assign: Runnable = ( + RunnablePassthrough.assign(context=itemgetter("question") | retriever) + | fake_llm + ) + + # fallback to RunnableAssign.input_schema if next runnable doesn't have + # expected dict input_schema + assert invalid_seq_w_assign.input_schema.schema() == { + "properties": {"question": {"title": "Question"}}, + "title": "RunnableParallelInput", + "type": "object", + } + + @pytest.mark.skipif( sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run." ) @@ -565,6 +601,55 @@ def test_lambda_schemas() -> None: }, } + class InputType(TypedDict): + variable_name: str + yo: int + + class OutputType(TypedDict): + hello: str + bye: str + byebye: int + + async def aget_values_typed(input: InputType) -> OutputType: + return { + "hello": input["variable_name"], + "bye": input["variable_name"], + "byebye": input["yo"], + } + + assert RunnableLambda(aget_values_typed).input_schema.schema() == { # type: ignore[arg-type] + "title": "RunnableLambdaInput", + "$ref": "#/definitions/InputType", + "definitions": { + "InputType": { + "properties": { + "variable_name": {"title": "Variable " "Name", "type": "string"}, + "yo": {"title": "Yo", "type": "integer"}, + }, + "required": ["variable_name", "yo"], + "title": "InputType", + "type": "object", + } + }, + } + + assert RunnableLambda(aget_values_typed).output_schema.schema() == { # type: ignore[arg-type] + "title": "RunnableLambdaOutput", + "$ref": "#/definitions/OutputType", + "definitions": { + "OutputType": { + "properties": { + "bye": {"title": "Bye", "type": "string"}, + "byebye": {"title": "Byebye", "type": "integer"}, + "hello": {"title": "Hello", "type": "string"}, + }, + "required": ["hello", "bye", "byebye"], + "title": "OutputType", + "type": "object", + } + }, + } + def test_with_types_with_type_generics() -> None: """Verify that with_types works if we use things like List[int]"""