mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-25 12:33:39 +00:00
Implement nicer runnable seq constructor, Propagate name through Runn… (#15226)
…ableBinding <!-- Thank you for contributing to LangChain! Please title your PR "<package>: <description>", where <package> is whichever of langchain, community, core, experimental, etc. is being modified. Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes if applicable, - **Dependencies:** any dependencies required for this change, - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` from the root of the package you've modified to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://python.langchain.com/docs/contributing/ If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. -->
This commit is contained in:
parent
f36ef0739d
commit
0252a24471
@ -217,6 +217,20 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
For a UI (and much more) checkout LangSmith: https://docs.smith.langchain.com/
|
||||
"""
|
||||
|
||||
name: Optional[str] = None
|
||||
"""The name of the runnable. Used for debugging and tracing."""
|
||||
|
||||
def get_name(self, suffix: Optional[str] = None) -> str:
|
||||
"""Get the name of the runnable."""
|
||||
name = self.name or self.__class__.__name__
|
||||
if suffix:
|
||||
if name[0].isupper():
|
||||
return name + suffix.title()
|
||||
else:
|
||||
return name + "_" + suffix.lower()
|
||||
else:
|
||||
return name
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
"""The type of input this runnable accepts specified as a type annotation."""
|
||||
@ -226,7 +240,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
return type_args[0]
|
||||
|
||||
raise TypeError(
|
||||
f"Runnable {self.__class__.__name__} doesn't have an inferable InputType. "
|
||||
f"Runnable {self.get_name()} doesn't have an inferable InputType. "
|
||||
"Override the InputType property to specify the input type."
|
||||
)
|
||||
|
||||
@ -239,7 +253,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
return type_args[1]
|
||||
|
||||
raise TypeError(
|
||||
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
|
||||
f"Runnable {self.get_name()} doesn't have an inferable OutputType. "
|
||||
"Override the OutputType property to specify the output type."
|
||||
)
|
||||
|
||||
@ -271,7 +285,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
return root_type
|
||||
|
||||
return create_model(
|
||||
self.__class__.__name__ + "Input",
|
||||
self.get_name("Input"),
|
||||
__root__=(root_type, None),
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
@ -304,7 +318,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
return root_type
|
||||
|
||||
return create_model(
|
||||
self.__class__.__name__ + "Output",
|
||||
self.get_name("Output"),
|
||||
__root__=(root_type, None),
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
@ -350,7 +364,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
)
|
||||
|
||||
return create_model( # type: ignore[call-overload]
|
||||
self.__class__.__name__ + "Config",
|
||||
self.get_name("Config"),
|
||||
__config__=_SchemaConfig,
|
||||
**({"configurable": (configurable, None)} if configurable else {}),
|
||||
**{
|
||||
@ -382,7 +396,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
],
|
||||
) -> RunnableSerializable[Input, Other]:
|
||||
"""Compose this runnable with another object to create a RunnableSequence."""
|
||||
return RunnableSequence(first=self, last=coerce_to_runnable(other))
|
||||
return RunnableSequence(self, coerce_to_runnable(other))
|
||||
|
||||
def __ror__(
|
||||
self,
|
||||
@ -394,7 +408,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
],
|
||||
) -> RunnableSerializable[Other, Output]:
|
||||
"""Compose this runnable with another object to create a RunnableSequence."""
|
||||
return RunnableSequence(first=coerce_to_runnable(other), last=self)
|
||||
return RunnableSequence(coerce_to_runnable(other), self)
|
||||
|
||||
""" --- Public API --- """
|
||||
|
||||
@ -900,7 +914,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
dumpd(self),
|
||||
input,
|
||||
run_type=run_type,
|
||||
name=config.get("run_name"),
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
)
|
||||
try:
|
||||
output = call_func_with_variable_args(
|
||||
@ -936,7 +950,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
dumpd(self),
|
||||
input,
|
||||
run_type=run_type,
|
||||
name=config.get("run_name"),
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
)
|
||||
try:
|
||||
output = await acall_func_with_variable_args(
|
||||
@ -981,7 +995,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
dumpd(self),
|
||||
input,
|
||||
run_type=run_type,
|
||||
name=config.get("run_name"),
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
)
|
||||
for callback_manager, input, config in zip(
|
||||
callback_managers, input, configs
|
||||
@ -1053,7 +1067,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
dumpd(self),
|
||||
input,
|
||||
run_type=run_type,
|
||||
name=config.get("run_name"),
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
)
|
||||
for callback_manager, input, config in zip(
|
||||
callback_managers, input, configs
|
||||
@ -1128,7 +1142,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
dumpd(self),
|
||||
{"input": ""},
|
||||
run_type=run_type,
|
||||
name=config.get("run_name"),
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
)
|
||||
try:
|
||||
if accepts_config(transformer):
|
||||
@ -1204,7 +1218,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
dumpd(self),
|
||||
{"input": ""},
|
||||
run_type=run_type,
|
||||
name=config.get("run_name"),
|
||||
name=config.get("run_name") or self.get_name(),
|
||||
)
|
||||
try:
|
||||
if accepts_config(transformer):
|
||||
@ -1245,6 +1259,9 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
class RunnableSerializable(Serializable, Runnable[Input, Output]):
|
||||
"""A Runnable that can be serialized to JSON."""
|
||||
|
||||
name: Optional[str] = None
|
||||
"""The name of the runnable. Used for debugging and tracing."""
|
||||
|
||||
def configurable_fields(
|
||||
self, **kwargs: AnyConfigurableField
|
||||
) -> RunnableSerializable[Input, Output]:
|
||||
@ -1448,6 +1465,39 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
last: Runnable[Any, Output]
|
||||
"""The last runnable in the sequence."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*steps: RunnableLike,
|
||||
name: Optional[str] = None,
|
||||
first: Optional[Runnable[Any, Any]] = None,
|
||||
middle: Optional[List[Runnable[Any, Any]]] = None,
|
||||
last: Optional[Runnable[Any, Any]] = None,
|
||||
) -> None:
|
||||
"""Create a new RunnableSequence.
|
||||
|
||||
Args:
|
||||
steps: The steps to include in the sequence.
|
||||
"""
|
||||
steps_flat: List[Runnable] = []
|
||||
if not steps:
|
||||
if first is not None and last is not None:
|
||||
steps_flat = [first] + (middle or []) + [last]
|
||||
for step in steps:
|
||||
if isinstance(step, RunnableSequence):
|
||||
steps_flat.extend(step.steps)
|
||||
else:
|
||||
steps_flat.append(coerce_to_runnable(step))
|
||||
if len(steps_flat) < 2:
|
||||
raise ValueError(
|
||||
f"RunnableSequence must have at least 2 steps, got {len(steps_flat)}"
|
||||
)
|
||||
super().__init__(
|
||||
first=steps_flat[0],
|
||||
middle=list(steps_flat[1:-1]),
|
||||
last=steps_flat[-1],
|
||||
name=name,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
"""Get the namespace of the langchain object."""
|
||||
@ -1566,15 +1616,21 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
) -> RunnableSerializable[Input, Other]:
|
||||
if isinstance(other, RunnableSequence):
|
||||
return RunnableSequence(
|
||||
first=self.first,
|
||||
middle=self.middle + [self.last] + [other.first] + other.middle,
|
||||
last=other.last,
|
||||
self.first,
|
||||
*self.middle,
|
||||
self.last,
|
||||
other.first,
|
||||
*other.middle,
|
||||
other.last,
|
||||
name=self.name or other.name,
|
||||
)
|
||||
else:
|
||||
return RunnableSequence(
|
||||
first=self.first,
|
||||
middle=self.middle + [self.last],
|
||||
last=coerce_to_runnable(other),
|
||||
self.first,
|
||||
*self.middle,
|
||||
self.last,
|
||||
coerce_to_runnable(other),
|
||||
name=self.name,
|
||||
)
|
||||
|
||||
def __ror__(
|
||||
@ -1588,15 +1644,21 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
) -> RunnableSerializable[Other, Output]:
|
||||
if isinstance(other, RunnableSequence):
|
||||
return RunnableSequence(
|
||||
first=other.first,
|
||||
middle=other.middle + [other.last] + [self.first] + self.middle,
|
||||
last=self.last,
|
||||
other.first,
|
||||
*other.middle,
|
||||
other.last,
|
||||
self.first,
|
||||
*self.middle,
|
||||
self.last,
|
||||
name=other.name or self.name,
|
||||
)
|
||||
else:
|
||||
return RunnableSequence(
|
||||
first=coerce_to_runnable(other),
|
||||
middle=[self.first] + self.middle,
|
||||
last=self.last,
|
||||
coerce_to_runnable(other),
|
||||
self.first,
|
||||
*self.middle,
|
||||
self.last,
|
||||
name=self.name,
|
||||
)
|
||||
|
||||
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
|
||||
@ -1607,7 +1669,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
callback_manager = get_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
dumpd(self), input, name=config.get("run_name") or self.name
|
||||
)
|
||||
|
||||
# invoke all steps in sequence
|
||||
@ -1641,7 +1703,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
callback_manager = get_async_callback_manager_for_config(config)
|
||||
# start the root run
|
||||
run_manager = await callback_manager.on_chain_start(
|
||||
dumpd(self), input, name=config.get("run_name")
|
||||
dumpd(self), input, name=config.get("run_name") or self.name
|
||||
)
|
||||
|
||||
# invoke all steps in sequence
|
||||
@ -1698,7 +1760,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
cm.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
name=config.get("run_name") or self.name,
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
]
|
||||
@ -1822,7 +1884,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
cm.on_chain_start(
|
||||
dumpd(self),
|
||||
input,
|
||||
name=config.get("run_name"),
|
||||
name=config.get("run_name") or self.name,
|
||||
)
|
||||
for cm, input, config in zip(callback_managers, inputs, configs)
|
||||
)
|
||||
@ -1972,7 +2034,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
**kwargs: Optional[Any],
|
||||
) -> Iterator[Output]:
|
||||
yield from self._transform_stream_with_config(
|
||||
input, self._transform, config, **kwargs
|
||||
input,
|
||||
self._transform,
|
||||
patch_config(config, run_name=(config or {}).get("run_name") or self.name),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def stream(
|
||||
@ -1990,7 +2055,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
**kwargs: Optional[Any],
|
||||
) -> AsyncIterator[Output]:
|
||||
async for chunk in self._atransform_stream_with_config(
|
||||
input, self._atransform, config, **kwargs
|
||||
input,
|
||||
self._atransform,
|
||||
patch_config(config, run_name=(config or {}).get("run_name") or self.name),
|
||||
**kwargs,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
@ -2068,7 +2136,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
):
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableMapInput",
|
||||
self.get_name("Input"),
|
||||
**{
|
||||
k: (v.annotation, v.default)
|
||||
for step in self.steps.values()
|
||||
@ -2085,7 +2153,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
|
||||
) -> Type[BaseModel]:
|
||||
# This is correct, but pydantic typings/mypy don't think so.
|
||||
return create_model( # type: ignore[call-overload]
|
||||
"RunnableMapOutput",
|
||||
self.get_name("Output"),
|
||||
**{k: (v.OutputType, None) for k, v in self.steps.items()},
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
@ -2576,6 +2644,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
"""
|
||||
if afunc is not None:
|
||||
self.afunc = afunc
|
||||
func_for_name: Callable = afunc
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
if afunc is not None:
|
||||
@ -2585,14 +2654,22 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
"function to avoid ambiguity."
|
||||
)
|
||||
self.afunc = func
|
||||
func_for_name = func
|
||||
elif callable(func):
|
||||
self.func = cast(Callable[[Input], Output], func)
|
||||
func_for_name = func
|
||||
else:
|
||||
raise TypeError(
|
||||
"Expected a callable type for `func`."
|
||||
f"Instead got an unsupported type: {type(func)}"
|
||||
)
|
||||
|
||||
try:
|
||||
if func_for_name.__name__ != "<lambda>":
|
||||
self.name = func_for_name.__name__
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@property
|
||||
def InputType(self) -> Any:
|
||||
"""The type of the input to this runnable."""
|
||||
@ -2622,13 +2699,13 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
):
|
||||
# It's a dict, lol
|
||||
return create_model(
|
||||
"RunnableLambdaInput",
|
||||
self.get_name("Input"),
|
||||
**{item[1:-1]: (Any, None) for item in items}, # type: ignore
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
else:
|
||||
return create_model(
|
||||
"RunnableLambdaInput",
|
||||
self.get_name("Input"),
|
||||
__root__=(List[Any], None),
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
@ -2638,7 +2715,7 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
|
||||
if dict_keys := get_function_first_arg_dict_keys(func):
|
||||
return create_model(
|
||||
"RunnableLambdaInput",
|
||||
self.get_name("Input"),
|
||||
**{key: (Any, None) for key in dict_keys}, # type: ignore
|
||||
__config__=_SchemaConfig,
|
||||
)
|
||||
@ -3012,7 +3089,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
self, config: Optional[RunnableConfig] = None
|
||||
) -> Type[BaseModel]:
|
||||
return create_model(
|
||||
"RunnableEachInput",
|
||||
self.get_name("Input"),
|
||||
__root__=(
|
||||
List[self.bound.get_input_schema(config)], # type: ignore
|
||||
None,
|
||||
@ -3029,7 +3106,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
|
||||
) -> Type[BaseModel]:
|
||||
schema = self.bound.get_output_schema(config)
|
||||
return create_model(
|
||||
"RunnableEachOutput",
|
||||
self.get_name("Output"),
|
||||
__root__=(
|
||||
List[schema], # type: ignore
|
||||
None,
|
||||
@ -3221,6 +3298,9 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
**other_kwargs,
|
||||
)
|
||||
|
||||
def get_name(self, suffix: Optional[str] = None) -> str:
|
||||
return self.bound.get_name(suffix)
|
||||
|
||||
@property
|
||||
def InputType(self) -> Type[Input]:
|
||||
return (
|
||||
|
@ -32,51 +32,51 @@
|
||||
# ---
|
||||
# name: test_graph_sequence_map
|
||||
'''
|
||||
+-------------+
|
||||
| PromptInput |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+----------------+
|
||||
| PromptTemplate |
|
||||
+----------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-------------+
|
||||
| FakeListLLM |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+---------------+
|
||||
| ParallelInput |
|
||||
+---------------+*****
|
||||
*** ******
|
||||
*** *****
|
||||
** *****
|
||||
+-------------+ ***
|
||||
| LambdaInput | *
|
||||
+-------------+ *
|
||||
** ** *
|
||||
*** *** *
|
||||
** ** *
|
||||
+-----------------+ +-----------------+ *
|
||||
| StrOutputParser | | XMLOutputParser | *
|
||||
+-----------------+ +-----------------+ *
|
||||
** ** *
|
||||
*** *** *
|
||||
** ** *
|
||||
+--------------+ +--------------------------------+
|
||||
| LambdaOutput | | CommaSeparatedListOutputParser |
|
||||
+--------------+ +--------------------------------+
|
||||
*** ******
|
||||
*** *****
|
||||
** ***
|
||||
+-----------+
|
||||
| MapOutput |
|
||||
+-----------+
|
||||
+-------------+
|
||||
| PromptInput |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+----------------+
|
||||
| PromptTemplate |
|
||||
+----------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+-------------+
|
||||
| FakeListLLM |
|
||||
+-------------+
|
||||
*
|
||||
*
|
||||
*
|
||||
+---------------+
|
||||
| ParallelInput |
|
||||
+---------------+******
|
||||
***** ******
|
||||
*** ******
|
||||
*** ******
|
||||
+------------------------------+ ***
|
||||
| conditional_str_parser_input | *
|
||||
+------------------------------+ *
|
||||
*** *** *
|
||||
*** *** *
|
||||
** ** *
|
||||
+-----------------+ +-----------------+ *
|
||||
| StrOutputParser | | XMLOutputParser | *
|
||||
+-----------------+ +-----------------+ *
|
||||
*** *** *
|
||||
*** *** *
|
||||
** ** *
|
||||
+-------------------------------+ +--------------------------------+
|
||||
| conditional_str_parser_output | | CommaSeparatedListOutputParser |
|
||||
+-------------------------------+ +--------------------------------+
|
||||
***** ******
|
||||
*** ******
|
||||
*** ***
|
||||
+----------------+
|
||||
| ParallelOutput |
|
||||
+----------------+
|
||||
'''
|
||||
# ---
|
||||
# name: test_graph_single_runnable
|
||||
|
File diff suppressed because one or more lines are too long
@ -200,11 +200,11 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
typed_lambda = RunnableLambda(typed_lambda_impl) # str -> int
|
||||
|
||||
assert typed_lambda.input_schema.schema() == {
|
||||
"title": "RunnableLambdaInput",
|
||||
"title": "typed_lambda_impl_input",
|
||||
"type": "string",
|
||||
}
|
||||
assert typed_lambda.output_schema.schema() == {
|
||||
"title": "RunnableLambdaOutput",
|
||||
"title": "typed_lambda_impl_output",
|
||||
"type": "integer",
|
||||
}
|
||||
|
||||
@ -214,11 +214,11 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
typed_async_lambda: Runnable = RunnableLambda(typed_async_lambda_impl) # str -> int
|
||||
|
||||
assert typed_async_lambda.input_schema.schema() == {
|
||||
"title": "RunnableLambdaInput",
|
||||
"title": "typed_async_lambda_impl_input",
|
||||
"type": "string",
|
||||
}
|
||||
assert typed_async_lambda.output_schema.schema() == {
|
||||
"title": "RunnableLambdaOutput",
|
||||
"title": "typed_async_lambda_impl_output",
|
||||
"type": "integer",
|
||||
}
|
||||
|
||||
@ -571,7 +571,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
|
||||
"properties": {"name": {"title": "Name", "type": "string"}},
|
||||
}
|
||||
assert seq_w_map.output_schema.schema() == {
|
||||
"title": "RunnableMapOutput",
|
||||
"title": "RunnableParallelOutput",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"original": {"title": "Original", "type": "string"},
|
||||
@ -615,7 +615,7 @@ def test_passthrough_assign_schema() -> None:
|
||||
# expected dict input_schema
|
||||
assert invalid_seq_w_assign.input_schema.schema() == {
|
||||
"properties": {"question": {"title": "Question"}},
|
||||
"title": "RunnableMapInput",
|
||||
"title": "RunnableParallelInput",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
@ -645,7 +645,7 @@ def test_lambda_schemas() -> None:
|
||||
return input["variable_name"]
|
||||
|
||||
assert RunnableLambda(get_value).input_schema.schema() == {
|
||||
"title": "RunnableLambdaInput",
|
||||
"title": "get_value_input",
|
||||
"type": "object",
|
||||
"properties": {"variable_name": {"title": "Variable Name"}},
|
||||
}
|
||||
@ -654,7 +654,7 @@ def test_lambda_schemas() -> None:
|
||||
return (input["variable_name"], input.get("another"))
|
||||
|
||||
assert RunnableLambda(aget_value).input_schema.schema() == {
|
||||
"title": "RunnableLambdaInput",
|
||||
"title": "aget_value_input",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"another": {"title": "Another"},
|
||||
@ -670,7 +670,7 @@ def test_lambda_schemas() -> None:
|
||||
}
|
||||
|
||||
assert RunnableLambda(aget_values).input_schema.schema() == {
|
||||
"title": "RunnableLambdaInput",
|
||||
"title": "aget_values_input",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"variable_name": {"title": "Variable Name"},
|
||||
@ -697,7 +697,7 @@ def test_lambda_schemas() -> None:
|
||||
assert (
|
||||
RunnableLambda(aget_values_typed).input_schema.schema() # type: ignore[arg-type]
|
||||
== {
|
||||
"title": "RunnableLambdaInput",
|
||||
"title": "aget_values_typed_input",
|
||||
"$ref": "#/definitions/InputType",
|
||||
"definitions": {
|
||||
"InputType": {
|
||||
@ -717,7 +717,7 @@ def test_lambda_schemas() -> None:
|
||||
)
|
||||
|
||||
assert RunnableLambda(aget_values_typed).output_schema.schema() == { # type: ignore[arg-type]
|
||||
"title": "RunnableLambdaOutput",
|
||||
"title": "aget_values_typed_output",
|
||||
"$ref": "#/definitions/OutputType",
|
||||
"definitions": {
|
||||
"OutputType": {
|
||||
@ -760,7 +760,11 @@ def test_schema_complex_seq() -> None:
|
||||
|
||||
model = FakeListChatModel(responses=[""])
|
||||
|
||||
chain1 = prompt1 | model | StrOutputParser()
|
||||
chain1: Runnable = RunnableSequence(
|
||||
prompt1, model, StrOutputParser(), name="city_chain"
|
||||
)
|
||||
|
||||
assert chain1.name == "city_chain"
|
||||
|
||||
chain2: Runnable = (
|
||||
{"city": chain1, "language": itemgetter("language")}
|
||||
@ -770,7 +774,7 @@ def test_schema_complex_seq() -> None:
|
||||
)
|
||||
|
||||
assert chain2.input_schema.schema() == {
|
||||
"title": "RunnableMapInput",
|
||||
"title": "RunnableParallelInput",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"person": {"title": "Person", "type": "string"},
|
||||
@ -784,7 +788,7 @@ def test_schema_complex_seq() -> None:
|
||||
}
|
||||
|
||||
assert chain2.with_types(input_type=str).input_schema.schema() == {
|
||||
"title": "RunnableBindingInput",
|
||||
"title": "RunnableSequenceInput",
|
||||
"type": "string",
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user