Implement nicer runnable seq constructor, Propagate name through Runn… (#15226)

…ableBinding

<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
This commit is contained in:
Nuno Campos 2023-12-27 11:24:32 -08:00 committed by GitHub
parent f36ef0739d
commit 0252a24471
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 243 additions and 126 deletions

View File

@ -217,6 +217,20 @@ class Runnable(Generic[Input, Output], ABC):
For a UI (and much more) checkout LangSmith: https://docs.smith.langchain.com/
"""
name: Optional[str] = None
"""The name of the runnable. Used for debugging and tracing."""
def get_name(self, suffix: Optional[str] = None) -> str:
"""Get the name of the runnable."""
name = self.name or self.__class__.__name__
if suffix:
if name[0].isupper():
return name + suffix.title()
else:
return name + "_" + suffix.lower()
else:
return name
@property
def InputType(self) -> Type[Input]:
"""The type of input this runnable accepts specified as a type annotation."""
@ -226,7 +240,7 @@ class Runnable(Generic[Input, Output], ABC):
return type_args[0]
raise TypeError(
f"Runnable {self.__class__.__name__} doesn't have an inferable InputType. "
f"Runnable {self.get_name()} doesn't have an inferable InputType. "
"Override the InputType property to specify the input type."
)
@ -239,7 +253,7 @@ class Runnable(Generic[Input, Output], ABC):
return type_args[1]
raise TypeError(
f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. "
f"Runnable {self.get_name()} doesn't have an inferable OutputType. "
"Override the OutputType property to specify the output type."
)
@ -271,7 +285,7 @@ class Runnable(Generic[Input, Output], ABC):
return root_type
return create_model(
self.__class__.__name__ + "Input",
self.get_name("Input"),
__root__=(root_type, None),
__config__=_SchemaConfig,
)
@ -304,7 +318,7 @@ class Runnable(Generic[Input, Output], ABC):
return root_type
return create_model(
self.__class__.__name__ + "Output",
self.get_name("Output"),
__root__=(root_type, None),
__config__=_SchemaConfig,
)
@ -350,7 +364,7 @@ class Runnable(Generic[Input, Output], ABC):
)
return create_model( # type: ignore[call-overload]
self.__class__.__name__ + "Config",
self.get_name("Config"),
__config__=_SchemaConfig,
**({"configurable": (configurable, None)} if configurable else {}),
**{
@ -382,7 +396,7 @@ class Runnable(Generic[Input, Output], ABC):
],
) -> RunnableSerializable[Input, Other]:
"""Compose this runnable with another object to create a RunnableSequence."""
return RunnableSequence(first=self, last=coerce_to_runnable(other))
return RunnableSequence(self, coerce_to_runnable(other))
def __ror__(
self,
@ -394,7 +408,7 @@ class Runnable(Generic[Input, Output], ABC):
],
) -> RunnableSerializable[Other, Output]:
"""Compose this runnable with another object to create a RunnableSequence."""
return RunnableSequence(first=coerce_to_runnable(other), last=self)
return RunnableSequence(coerce_to_runnable(other), self)
""" --- Public API --- """
@ -900,7 +914,7 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self),
input,
run_type=run_type,
name=config.get("run_name"),
name=config.get("run_name") or self.get_name(),
)
try:
output = call_func_with_variable_args(
@ -936,7 +950,7 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self),
input,
run_type=run_type,
name=config.get("run_name"),
name=config.get("run_name") or self.get_name(),
)
try:
output = await acall_func_with_variable_args(
@ -981,7 +995,7 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self),
input,
run_type=run_type,
name=config.get("run_name"),
name=config.get("run_name") or self.get_name(),
)
for callback_manager, input, config in zip(
callback_managers, input, configs
@ -1053,7 +1067,7 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self),
input,
run_type=run_type,
name=config.get("run_name"),
name=config.get("run_name") or self.get_name(),
)
for callback_manager, input, config in zip(
callback_managers, input, configs
@ -1128,7 +1142,7 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self),
{"input": ""},
run_type=run_type,
name=config.get("run_name"),
name=config.get("run_name") or self.get_name(),
)
try:
if accepts_config(transformer):
@ -1204,7 +1218,7 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self),
{"input": ""},
run_type=run_type,
name=config.get("run_name"),
name=config.get("run_name") or self.get_name(),
)
try:
if accepts_config(transformer):
@ -1245,6 +1259,9 @@ class Runnable(Generic[Input, Output], ABC):
class RunnableSerializable(Serializable, Runnable[Input, Output]):
"""A Runnable that can be serialized to JSON."""
name: Optional[str] = None
"""The name of the runnable. Used for debugging and tracing."""
def configurable_fields(
self, **kwargs: AnyConfigurableField
) -> RunnableSerializable[Input, Output]:
@ -1448,6 +1465,39 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
last: Runnable[Any, Output]
"""The last runnable in the sequence."""
def __init__(
self,
*steps: RunnableLike,
name: Optional[str] = None,
first: Optional[Runnable[Any, Any]] = None,
middle: Optional[List[Runnable[Any, Any]]] = None,
last: Optional[Runnable[Any, Any]] = None,
) -> None:
"""Create a new RunnableSequence.
Args:
steps: The steps to include in the sequence.
"""
steps_flat: List[Runnable] = []
if not steps:
if first is not None and last is not None:
steps_flat = [first] + (middle or []) + [last]
for step in steps:
if isinstance(step, RunnableSequence):
steps_flat.extend(step.steps)
else:
steps_flat.append(coerce_to_runnable(step))
if len(steps_flat) < 2:
raise ValueError(
f"RunnableSequence must have at least 2 steps, got {len(steps_flat)}"
)
super().__init__(
first=steps_flat[0],
middle=list(steps_flat[1:-1]),
last=steps_flat[-1],
name=name,
)
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
@ -1566,15 +1616,21 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
) -> RunnableSerializable[Input, Other]:
if isinstance(other, RunnableSequence):
return RunnableSequence(
first=self.first,
middle=self.middle + [self.last] + [other.first] + other.middle,
last=other.last,
self.first,
*self.middle,
self.last,
other.first,
*other.middle,
other.last,
name=self.name or other.name,
)
else:
return RunnableSequence(
first=self.first,
middle=self.middle + [self.last],
last=coerce_to_runnable(other),
self.first,
*self.middle,
self.last,
coerce_to_runnable(other),
name=self.name,
)
def __ror__(
@ -1588,15 +1644,21 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
) -> RunnableSerializable[Other, Output]:
if isinstance(other, RunnableSequence):
return RunnableSequence(
first=other.first,
middle=other.middle + [other.last] + [self.first] + self.middle,
last=self.last,
other.first,
*other.middle,
other.last,
self.first,
*self.middle,
self.last,
name=other.name or self.name,
)
else:
return RunnableSequence(
first=coerce_to_runnable(other),
middle=[self.first] + self.middle,
last=self.last,
coerce_to_runnable(other),
self.first,
*self.middle,
self.last,
name=self.name,
)
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
@ -1607,7 +1669,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
dumpd(self), input, name=config.get("run_name") or self.name
)
# invoke all steps in sequence
@ -1641,7 +1703,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, name=config.get("run_name")
dumpd(self), input, name=config.get("run_name") or self.name
)
# invoke all steps in sequence
@ -1698,7 +1760,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
cm.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
name=config.get("run_name") or self.name,
)
for cm, input, config in zip(callback_managers, inputs, configs)
]
@ -1822,7 +1884,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
cm.on_chain_start(
dumpd(self),
input,
name=config.get("run_name"),
name=config.get("run_name") or self.name,
)
for cm, input, config in zip(callback_managers, inputs, configs)
)
@ -1972,7 +2034,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
**kwargs: Optional[Any],
) -> Iterator[Output]:
yield from self._transform_stream_with_config(
input, self._transform, config, **kwargs
input,
self._transform,
patch_config(config, run_name=(config or {}).get("run_name") or self.name),
**kwargs,
)
def stream(
@ -1990,7 +2055,10 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
**kwargs: Optional[Any],
) -> AsyncIterator[Output]:
async for chunk in self._atransform_stream_with_config(
input, self._atransform, config, **kwargs
input,
self._atransform,
patch_config(config, run_name=(config or {}).get("run_name") or self.name),
**kwargs,
):
yield chunk
@ -2068,7 +2136,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
):
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"RunnableMapInput",
self.get_name("Input"),
**{
k: (v.annotation, v.default)
for step in self.steps.values()
@ -2085,7 +2153,7 @@ class RunnableParallel(RunnableSerializable[Input, Dict[str, Any]]):
) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"RunnableMapOutput",
self.get_name("Output"),
**{k: (v.OutputType, None) for k, v in self.steps.items()},
__config__=_SchemaConfig,
)
@ -2576,6 +2644,7 @@ class RunnableLambda(Runnable[Input, Output]):
"""
if afunc is not None:
self.afunc = afunc
func_for_name: Callable = afunc
if inspect.iscoroutinefunction(func):
if afunc is not None:
@ -2585,14 +2654,22 @@ class RunnableLambda(Runnable[Input, Output]):
"function to avoid ambiguity."
)
self.afunc = func
func_for_name = func
elif callable(func):
self.func = cast(Callable[[Input], Output], func)
func_for_name = func
else:
raise TypeError(
"Expected a callable type for `func`."
f"Instead got an unsupported type: {type(func)}"
)
try:
if func_for_name.__name__ != "<lambda>":
self.name = func_for_name.__name__
except AttributeError:
pass
@property
def InputType(self) -> Any:
"""The type of the input to this runnable."""
@ -2622,13 +2699,13 @@ class RunnableLambda(Runnable[Input, Output]):
):
# It's a dict, lol
return create_model(
"RunnableLambdaInput",
self.get_name("Input"),
**{item[1:-1]: (Any, None) for item in items}, # type: ignore
__config__=_SchemaConfig,
)
else:
return create_model(
"RunnableLambdaInput",
self.get_name("Input"),
__root__=(List[Any], None),
__config__=_SchemaConfig,
)
@ -2638,7 +2715,7 @@ class RunnableLambda(Runnable[Input, Output]):
if dict_keys := get_function_first_arg_dict_keys(func):
return create_model(
"RunnableLambdaInput",
self.get_name("Input"),
**{key: (Any, None) for key in dict_keys}, # type: ignore
__config__=_SchemaConfig,
)
@ -3012,7 +3089,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return create_model(
"RunnableEachInput",
self.get_name("Input"),
__root__=(
List[self.bound.get_input_schema(config)], # type: ignore
None,
@ -3029,7 +3106,7 @@ class RunnableEachBase(RunnableSerializable[List[Input], List[Output]]):
) -> Type[BaseModel]:
schema = self.bound.get_output_schema(config)
return create_model(
"RunnableEachOutput",
self.get_name("Output"),
__root__=(
List[schema], # type: ignore
None,
@ -3221,6 +3298,9 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
**other_kwargs,
)
def get_name(self, suffix: Optional[str] = None) -> str:
return self.bound.get_name(suffix)
@property
def InputType(self) -> Type[Input]:
return (

View File

@ -32,51 +32,51 @@
# ---
# name: test_graph_sequence_map
'''
+-------------+
| PromptInput |
+-------------+
*
*
*
+----------------+
| PromptTemplate |
+----------------+
*
*
*
+-------------+
| FakeListLLM |
+-------------+
*
*
*
+---------------+
| ParallelInput |
+---------------+*****
*** ******
*** *****
** *****
+-------------+ ***
| LambdaInput | *
+-------------+ *
** ** *
*** *** *
** ** *
+-----------------+ +-----------------+ *
| StrOutputParser | | XMLOutputParser | *
+-----------------+ +-----------------+ *
** ** *
*** *** *
** ** *
+--------------+ +--------------------------------+
| LambdaOutput | | CommaSeparatedListOutputParser |
+--------------+ +--------------------------------+
*** ******
*** *****
** ***
+-----------+
| MapOutput |
+-----------+
+-------------+
| PromptInput |
+-------------+
*
*
*
+----------------+
| PromptTemplate |
+----------------+
*
*
*
+-------------+
| FakeListLLM |
+-------------+
*
*
*
+---------------+
| ParallelInput |
+---------------+******
***** ******
*** ******
*** ******
+------------------------------+ ***
| conditional_str_parser_input | *
+------------------------------+ *
*** *** *
*** *** *
** ** *
+-----------------+ +-----------------+ *
| StrOutputParser | | XMLOutputParser | *
+-----------------+ +-----------------+ *
*** *** *
*** *** *
** ** *
+-------------------------------+ +--------------------------------+
| conditional_str_parser_output | | CommaSeparatedListOutputParser |
+-------------------------------+ +--------------------------------+
***** ******
*** ******
*** ***
+----------------+
| ParallelOutput |
+----------------+
'''
# ---
# name: test_graph_single_runnable

File diff suppressed because one or more lines are too long

View File

@ -200,11 +200,11 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
typed_lambda = RunnableLambda(typed_lambda_impl) # str -> int
assert typed_lambda.input_schema.schema() == {
"title": "RunnableLambdaInput",
"title": "typed_lambda_impl_input",
"type": "string",
}
assert typed_lambda.output_schema.schema() == {
"title": "RunnableLambdaOutput",
"title": "typed_lambda_impl_output",
"type": "integer",
}
@ -214,11 +214,11 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
typed_async_lambda: Runnable = RunnableLambda(typed_async_lambda_impl) # str -> int
assert typed_async_lambda.input_schema.schema() == {
"title": "RunnableLambdaInput",
"title": "typed_async_lambda_impl_input",
"type": "string",
}
assert typed_async_lambda.output_schema.schema() == {
"title": "RunnableLambdaOutput",
"title": "typed_async_lambda_impl_output",
"type": "integer",
}
@ -571,7 +571,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"properties": {"name": {"title": "Name", "type": "string"}},
}
assert seq_w_map.output_schema.schema() == {
"title": "RunnableMapOutput",
"title": "RunnableParallelOutput",
"type": "object",
"properties": {
"original": {"title": "Original", "type": "string"},
@ -615,7 +615,7 @@ def test_passthrough_assign_schema() -> None:
# expected dict input_schema
assert invalid_seq_w_assign.input_schema.schema() == {
"properties": {"question": {"title": "Question"}},
"title": "RunnableMapInput",
"title": "RunnableParallelInput",
"type": "object",
}
@ -645,7 +645,7 @@ def test_lambda_schemas() -> None:
return input["variable_name"]
assert RunnableLambda(get_value).input_schema.schema() == {
"title": "RunnableLambdaInput",
"title": "get_value_input",
"type": "object",
"properties": {"variable_name": {"title": "Variable Name"}},
}
@ -654,7 +654,7 @@ def test_lambda_schemas() -> None:
return (input["variable_name"], input.get("another"))
assert RunnableLambda(aget_value).input_schema.schema() == {
"title": "RunnableLambdaInput",
"title": "aget_value_input",
"type": "object",
"properties": {
"another": {"title": "Another"},
@ -670,7 +670,7 @@ def test_lambda_schemas() -> None:
}
assert RunnableLambda(aget_values).input_schema.schema() == {
"title": "RunnableLambdaInput",
"title": "aget_values_input",
"type": "object",
"properties": {
"variable_name": {"title": "Variable Name"},
@ -697,7 +697,7 @@ def test_lambda_schemas() -> None:
assert (
RunnableLambda(aget_values_typed).input_schema.schema() # type: ignore[arg-type]
== {
"title": "RunnableLambdaInput",
"title": "aget_values_typed_input",
"$ref": "#/definitions/InputType",
"definitions": {
"InputType": {
@ -717,7 +717,7 @@ def test_lambda_schemas() -> None:
)
assert RunnableLambda(aget_values_typed).output_schema.schema() == { # type: ignore[arg-type]
"title": "RunnableLambdaOutput",
"title": "aget_values_typed_output",
"$ref": "#/definitions/OutputType",
"definitions": {
"OutputType": {
@ -760,7 +760,11 @@ def test_schema_complex_seq() -> None:
model = FakeListChatModel(responses=[""])
chain1 = prompt1 | model | StrOutputParser()
chain1: Runnable = RunnableSequence(
prompt1, model, StrOutputParser(), name="city_chain"
)
assert chain1.name == "city_chain"
chain2: Runnable = (
{"city": chain1, "language": itemgetter("language")}
@ -770,7 +774,7 @@ def test_schema_complex_seq() -> None:
)
assert chain2.input_schema.schema() == {
"title": "RunnableMapInput",
"title": "RunnableParallelInput",
"type": "object",
"properties": {
"person": {"title": "Person", "type": "string"},
@ -784,7 +788,7 @@ def test_schema_complex_seq() -> None:
}
assert chain2.with_types(input_type=str).input_schema.schema() == {
"title": "RunnableBindingInput",
"title": "RunnableSequenceInput",
"type": "string",
}