mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 18:23:59 +00:00
Improve Runnable type inference for input_schemas (#12630)
- Prefer lambda type annotations over inferred dict schema - For sequences that start with RunnableAssign infer seq input type as "input type of 2nd item in sequence - output type of runnable assign"
This commit is contained in:
parent
2f563cee20
commit
3143324984
@ -1109,6 +1109,23 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
def get_input_schema(
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
from langchain.schema.runnable.passthrough import RunnableAssign
|
||||
|
||||
if isinstance(self.first, RunnableAssign):
|
||||
first = cast(RunnableAssign, self.first)
|
||||
next_ = self.middle[0] if self.middle else self.last
|
||||
next_input_schema = next_.get_input_schema(config)
|
||||
if not next_input_schema.__custom_root_type__:
|
||||
# it's a dict as expected
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableSequenceInput",
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for k, v in next_input_schema.__fields__.items()
|
||||
if k not in first.mapper.steps
|
||||
},
|
||||
)
|
||||
|
||||
return self.first.get_input_schema(config)
|
||||
|
||||
def get_output_schema(
|
||||
@ -2152,6 +2169,9 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
else:
|
||||
return create_model("RunnableLambdaInput", __root__=(List[Any], None))
|
||||
|
||||
if self.InputType != Any:
|
||||
return super().get_input_schema(config)
|
||||
|
||||
if dict_keys := get_function_first_arg_dict_keys(func):
|
||||
return create_model(
|
||||
"RunnableLambdaInput",
|
||||
|
@ -326,6 +326,10 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
||||
for k, v in s.__fields__.items()
|
||||
},
|
||||
)
|
||||
elif not map_output_schema.__custom_root_type__:
|
||||
# ie. only map output is a dict
|
||||
# ie. input type is either unknown or inferred incorrectly
|
||||
return map_output_schema
|
||||
|
||||
return super().get_output_schema(config)
|
||||
|
||||
|
@ -18,6 +18,7 @@ import pytest
|
||||
from freezegun import freeze_time
|
||||
from pytest_mock import MockerFixture
|
||||
from syrupy import SnapshotAssertion
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
Callbacks,
|
||||
@ -508,6 +509,41 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
}
|
||||
|
||||
|
||||
def test_passthrough_assign_schema() -> None:
|
||||
retriever = FakeRetriever() # str -> List[Document]
|
||||
prompt = PromptTemplate.from_template("{context} {question}")
|
||||
fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]]
|
||||
|
||||
seq_w_assign: Runnable = (
|
||||
RunnablePassthrough.assign(context=itemgetter("question") | retriever)
|
||||
| prompt
|
||||
| fake_llm
|
||||
)
|
||||
|
||||
assert seq_w_assign.input_schema.schema() == {
|
||||
"properties": {"question": {"title": "Question", "type": "string"}},
|
||||
"title": "RunnableSequenceInput",
|
||||
"type": "object",
|
||||
}
|
||||
assert seq_w_assign.output_schema.schema() == {
|
||||
"title": "FakeListLLMOutput",
|
||||
"type": "string",
|
||||
}
|
||||
|
||||
invalid_seq_w_assign: Runnable = (
|
||||
RunnablePassthrough.assign(context=itemgetter("question") | retriever)
|
||||
| fake_llm
|
||||
)
|
||||
|
||||
# fallback to RunnableAssign.input_schema if next runnable doesn't have
|
||||
# expected dict input_schema
|
||||
assert invalid_seq_w_assign.input_schema.schema() == {
|
||||
"properties": {"question": {"title": "Question"}},
|
||||
"title": "RunnableParallelInput",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||
)
|
||||
@ -565,6 +601,55 @@ def test_lambda_schemas() -> None:
|
||||
},
|
||||
}
|
||||
|
||||
class InputType(TypedDict):
|
||||
variable_name: str
|
||||
yo: int
|
||||
|
||||
class OutputType(TypedDict):
|
||||
hello: str
|
||||
bye: str
|
||||
byebye: int
|
||||
|
||||
async def aget_values_typed(input: InputType) -> OutputType:
|
||||
return {
|
||||
"hello": input["variable_name"],
|
||||
"bye": input["variable_name"],
|
||||
"byebye": input["yo"],
|
||||
}
|
||||
|
||||
assert RunnableLambda(aget_values_typed).input_schema.schema() == { # type: ignore[arg-type]
|
||||
"title": "RunnableLambdaInput",
|
||||
"$ref": "#/definitions/InputType",
|
||||
"definitions": {
|
||||
"InputType": {
|
||||
"properties": {
|
||||
"variable_name": {"title": "Variable " "Name", "type": "string"},
|
||||
"yo": {"title": "Yo", "type": "integer"},
|
||||
},
|
||||
"required": ["variable_name", "yo"],
|
||||
"title": "InputType",
|
||||
"type": "object",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
assert RunnableLambda(aget_values_typed).output_schema.schema() == { # type: ignore[arg-type]
|
||||
"title": "RunnableLambdaOutput",
|
||||
"$ref": "#/definitions/OutputType",
|
||||
"definitions": {
|
||||
"OutputType": {
|
||||
"properties": {
|
||||
"bye": {"title": "Bye", "type": "string"},
|
||||
"byebye": {"title": "Byebye", "type": "integer"},
|
||||
"hello": {"title": "Hello", "type": "string"},
|
||||
},
|
||||
"required": ["hello", "bye", "byebye"],
|
||||
"title": "OutputType",
|
||||
"type": "object",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_with_types_with_type_generics() -> None:
|
||||
"""Verify that with_types works if we use things like List[int]"""
|
||||
|
Loading…
Reference in New Issue
Block a user