mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 10:43:36 +00:00
Improve Runnable type inference for input_schemas (#12630)
- Prefer lambda type annotations over inferred dict schema - For sequences that start with RunnableAssign infer seq input type as "input type of 2nd item in sequence - output type of runnable assign"
This commit is contained in:
parent
2f563cee20
commit
3143324984
@ -1109,6 +1109,23 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
def get_input_schema(
|
def get_input_schema(
|
||||||
self, config: Optional[RunnableConfig] = None
|
self, config: Optional[RunnableConfig] = None
|
||||||
) -> Type[BaseModel]:
|
) -> Type[BaseModel]:
|
||||||
|
from langchain.schema.runnable.passthrough import RunnableAssign
|
||||||
|
|
||||||
|
if isinstance(self.first, RunnableAssign):
|
||||||
|
first = cast(RunnableAssign, self.first)
|
||||||
|
next_ = self.middle[0] if self.middle else self.last
|
||||||
|
next_input_schema = next_.get_input_schema(config)
|
||||||
|
if not next_input_schema.__custom_root_type__:
|
||||||
|
# it's a dict as expected
|
||||||
|
return create_model( # type: ignore[call-overload]
|
||||||
|
"RunnableSequenceInput",
|
||||||
|
**{
|
||||||
|
k: (v.annotation, v.default)
|
||||||
|
for k, v in next_input_schema.__fields__.items()
|
||||||
|
if k not in first.mapper.steps
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
return self.first.get_input_schema(config)
|
return self.first.get_input_schema(config)
|
||||||
|
|
||||||
def get_output_schema(
|
def get_output_schema(
|
||||||
@ -2152,6 +2169,9 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
else:
|
else:
|
||||||
return create_model("RunnableLambdaInput", __root__=(List[Any], None))
|
return create_model("RunnableLambdaInput", __root__=(List[Any], None))
|
||||||
|
|
||||||
|
if self.InputType != Any:
|
||||||
|
return super().get_input_schema(config)
|
||||||
|
|
||||||
if dict_keys := get_function_first_arg_dict_keys(func):
|
if dict_keys := get_function_first_arg_dict_keys(func):
|
||||||
return create_model(
|
return create_model(
|
||||||
"RunnableLambdaInput",
|
"RunnableLambdaInput",
|
||||||
|
@ -326,6 +326,10 @@ class RunnableAssign(RunnableSerializable[Dict[str, Any], Dict[str, Any]]):
|
|||||||
for k, v in s.__fields__.items()
|
for k, v in s.__fields__.items()
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
elif not map_output_schema.__custom_root_type__:
|
||||||
|
# ie. only map output is a dict
|
||||||
|
# ie. input type is either unknown or inferred incorrectly
|
||||||
|
return map_output_schema
|
||||||
|
|
||||||
return super().get_output_schema(config)
|
return super().get_output_schema(config)
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ import pytest
|
|||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
from syrupy import SnapshotAssertion
|
from syrupy import SnapshotAssertion
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
Callbacks,
|
Callbacks,
|
||||||
@ -508,6 +509,41 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_passthrough_assign_schema() -> None:
|
||||||
|
retriever = FakeRetriever() # str -> List[Document]
|
||||||
|
prompt = PromptTemplate.from_template("{context} {question}")
|
||||||
|
fake_llm = FakeListLLM(responses=["a"]) # str -> List[List[str]]
|
||||||
|
|
||||||
|
seq_w_assign: Runnable = (
|
||||||
|
RunnablePassthrough.assign(context=itemgetter("question") | retriever)
|
||||||
|
| prompt
|
||||||
|
| fake_llm
|
||||||
|
)
|
||||||
|
|
||||||
|
assert seq_w_assign.input_schema.schema() == {
|
||||||
|
"properties": {"question": {"title": "Question", "type": "string"}},
|
||||||
|
"title": "RunnableSequenceInput",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
assert seq_w_assign.output_schema.schema() == {
|
||||||
|
"title": "FakeListLLMOutput",
|
||||||
|
"type": "string",
|
||||||
|
}
|
||||||
|
|
||||||
|
invalid_seq_w_assign: Runnable = (
|
||||||
|
RunnablePassthrough.assign(context=itemgetter("question") | retriever)
|
||||||
|
| fake_llm
|
||||||
|
)
|
||||||
|
|
||||||
|
# fallback to RunnableAssign.input_schema if next runnable doesn't have
|
||||||
|
# expected dict input_schema
|
||||||
|
assert invalid_seq_w_assign.input_schema.schema() == {
|
||||||
|
"properties": {"question": {"title": "Question"}},
|
||||||
|
"title": "RunnableParallelInput",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
|
||||||
)
|
)
|
||||||
@ -565,6 +601,55 @@ def test_lambda_schemas() -> None:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class InputType(TypedDict):
|
||||||
|
variable_name: str
|
||||||
|
yo: int
|
||||||
|
|
||||||
|
class OutputType(TypedDict):
|
||||||
|
hello: str
|
||||||
|
bye: str
|
||||||
|
byebye: int
|
||||||
|
|
||||||
|
async def aget_values_typed(input: InputType) -> OutputType:
|
||||||
|
return {
|
||||||
|
"hello": input["variable_name"],
|
||||||
|
"bye": input["variable_name"],
|
||||||
|
"byebye": input["yo"],
|
||||||
|
}
|
||||||
|
|
||||||
|
assert RunnableLambda(aget_values_typed).input_schema.schema() == { # type: ignore[arg-type]
|
||||||
|
"title": "RunnableLambdaInput",
|
||||||
|
"$ref": "#/definitions/InputType",
|
||||||
|
"definitions": {
|
||||||
|
"InputType": {
|
||||||
|
"properties": {
|
||||||
|
"variable_name": {"title": "Variable " "Name", "type": "string"},
|
||||||
|
"yo": {"title": "Yo", "type": "integer"},
|
||||||
|
},
|
||||||
|
"required": ["variable_name", "yo"],
|
||||||
|
"title": "InputType",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert RunnableLambda(aget_values_typed).output_schema.schema() == { # type: ignore[arg-type]
|
||||||
|
"title": "RunnableLambdaOutput",
|
||||||
|
"$ref": "#/definitions/OutputType",
|
||||||
|
"definitions": {
|
||||||
|
"OutputType": {
|
||||||
|
"properties": {
|
||||||
|
"bye": {"title": "Bye", "type": "string"},
|
||||||
|
"byebye": {"title": "Byebye", "type": "integer"},
|
||||||
|
"hello": {"title": "Hello", "type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["hello", "bye", "byebye"],
|
||||||
|
"title": "OutputType",
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_with_types_with_type_generics() -> None:
|
def test_with_types_with_type_generics() -> None:
|
||||||
"""Verify that with_types works if we use things like List[int]"""
|
"""Verify that with_types works if we use things like List[int]"""
|
||||||
|
Loading…
Reference in New Issue
Block a user