mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 00:49:25 +00:00
Relax type annotation for custom input/output types (#12300)
This is needed to be able to do stuff like: ```python runnable.with_types(input_type=List[str]) ```
This commit is contained in:
parent
988f6d9912
commit
5a71b81609
@ -2367,9 +2367,10 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
|
||||
config: RunnableConfig = Field(default_factory=dict)
|
||||
|
||||
custom_input_type: Optional[Union[Type[Input], BaseModel]] = None
|
||||
|
||||
custom_output_type: Optional[Union[Type[Output], BaseModel]] = None
|
||||
# Union[Type[Input], BaseModel] + things like List[str]
|
||||
custom_input_type: Optional[Any] = None
|
||||
# Union[Type[Output], BaseModel] + things like List[str]
|
||||
custom_output_type: Optional[Any] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
@ -557,6 +557,22 @@ def test_lambda_schemas() -> None:
|
||||
}
|
||||
|
||||
|
||||
def test_with_types_with_type_generics() -> None:
|
||||
"""Verify that with_types works if we use things like List[int]"""
|
||||
|
||||
def foo(x: int) -> None:
|
||||
"""Add one to the input."""
|
||||
raise NotImplementedError()
|
||||
|
||||
# Try specifying some
|
||||
RunnableLambda(foo).with_types(
|
||||
output_type=List[int], input_type=List[int] # type: ignore
|
||||
)
|
||||
RunnableLambda(foo).with_types(
|
||||
output_type=Sequence[int], input_type=Sequence[int] # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
def test_schema_complex_seq() -> None:
|
||||
prompt1 = ChatPromptTemplate.from_template("what is the city {person} is from?")
|
||||
prompt2 = ChatPromptTemplate.from_template(
|
||||
|
Loading…
Reference in New Issue
Block a user