Fix runnable branch callbacks (#11091)

We aren't calling on_chain_end here unless we use the default option
This commit is contained in:
William FH 2023-09-27 03:38:56 -07:00 committed by GitHub
parent 5514ebe859
commit 75b3893daf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -867,14 +867,15 @@ class RunnableBranch(Serializable, Runnable[Input, Output]):
) )
if expression_value: if expression_value:
return runnable.invoke( output = runnable.invoke(
input, input,
config=patch_config( config=patch_config(
config, config,
callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"), callbacks=run_manager.get_child(tag=f"branch:{idx + 1}"),
), ),
) )
break
else:
output = self.default.invoke( output = self.default.invoke(
input, input,
config=patch_config( config=patch_config(
@ -911,7 +912,7 @@ class RunnableBranch(Serializable, Runnable[Input, Output]):
) )
if expression_value: if expression_value:
return await runnable.ainvoke( output = await runnable.ainvoke(
input, input,
config=patch_config( config=patch_config(
config, config,
@ -919,7 +920,8 @@ class RunnableBranch(Serializable, Runnable[Input, Output]):
), ),
**kwargs, **kwargs,
) )
break
else:
output = await self.default.ainvoke( output = await self.default.ainvoke(
input, input,
config=patch_config( config=patch_config(