Improve Runnable type inference for input_schemas (#12630)

- Prefer lambda type annotations over inferred dict schema
- For sequences that start with RunnableAssign infer seq input type as
"input type of 2nd item in sequence - output type of runnable assign"
This commit is contained in:
Nuno Campos 2023-10-31 13:22:54 +00:00 committed by GitHub
parent 2f563cee20
commit 3143324984
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 109 additions and 0 deletions

View File

@ -1109,6 +1109,23 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
from langchain.schema.runnable.passthrough import RunnableAssign
if isinstance(self.first, RunnableAssign):
first = cast(RunnableAssign, self.first)
next_ = self.middle[0] if self.middle else self.last
next_input_schema = next_.get_input_schema(config)
if not next_input_schema.__custom_root_type__:
# it's a dict as expected
return create_model( # type: ignore[call-overload]
"RunnableSequenceInput",
**{
k: (v.annotation, v.default)
for k, v in next_input_schema.__fields__.items()
if k not in first.mapper.steps
},
)
return self.first.get_input_schema(config)
def get_output_schema(
@ -2152,6 +2169,9 @@ class RunnableLambda(Runnable[Input, Output]):
else:
return create_model("RunnableLambdaInput", __root__=(List[Any], None))
if self.InputType != Any:
return super().get_input_schema(config)
if dict_keys := get_function_first_arg_dict_keys(func):
return create_model(
"RunnableLambdaInput",

View File

@ -326,6 +326,10 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
for k, v in s.__fields__.items()
},
)
elif not map_output_schema.__custom_root_type__:
# ie. only map output is a dict
# ie. input type is either unknown or inferred incorrectly
return map_output_schema
return super().get_output_schema(config)

View File

@ -18,6 +18,7 @@ import pytest
from freezegun import freeze_time
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from typing_extensions import TypedDict
from langchain.callbacks.manager import (
Callbacks,
@ -508,6 +509,41 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
}
def test_passthrough_assign_schema() -> None:
retriever = FakeRetriever() # str -> List[Document]
prompt = PromptTemplate.from_template("{context} {question}")
fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]]
seq_w_assign: Runnable = (
RunnablePassthrough.assign(context=itemgetter("question") | retriever)
| prompt
| fake_llm
)
assert seq_w_assign.input_schema.schema() == {
"properties": {"question": {"title": "Question", "type": "string"}},
"title": "RunnableSequenceInput",
"type": "object",
}
assert seq_w_assign.output_schema.schema() == {
"title": "FakeListLLMOutput",
"type": "string",
}
invalid_seq_w_assign: Runnable = (
RunnablePassthrough.assign(context=itemgetter("question") | retriever)
| fake_llm
)
# fallback to RunnableAssign.input_schema if next runnable doesn't have
# expected dict input_schema
assert invalid_seq_w_assign.input_schema.schema() == {
"properties": {"question": {"title": "Question"}},
"title": "RunnableParallelInput",
"type": "object",
}
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
@ -565,6 +601,55 @@ def test_lambda_schemas() -> None:
},
}
class InputType(TypedDict):
variable_name: str
yo: int
class OutputType(TypedDict):
hello: str
bye: str
byebye: int
async def aget_values_typed(input: InputType) -> OutputType:
return {
"hello": input["variable_name"],
"bye": input["variable_name"],
"byebye": input["yo"],
}
assert RunnableLambda(aget_values_typed).input_schema.schema() == { # type: ignore[arg-type]
"title": "RunnableLambdaInput",
"$ref": "#/definitions/InputType",
"definitions": {
"InputType": {
"properties": {
"variable_name": {"title": "Variable " "Name", "type": "string"},
"yo": {"title": "Yo", "type": "integer"},
},
"required": ["variable_name", "yo"],
"title": "InputType",
"type": "object",
}
},
}
assert RunnableLambda(aget_values_typed).output_schema.schema() == { # type: ignore[arg-type]
"title": "RunnableLambdaOutput",
"$ref": "#/definitions/OutputType",
"definitions": {
"OutputType": {
"properties": {
"bye": {"title": "Bye", "type": "string"},
"byebye": {"title": "Byebye", "type": "integer"},
"hello": {"title": "Hello", "type": "string"},
},
"required": ["hello", "bye", "byebye"],
"title": "OutputType",
"type": "object",
}
},
}
def test_with_types_with_type_generics() -> None:
"""Verify that with_types works if we use things like List[int]"""