Nc/runnables seqmap tags (#9753)

This commit is contained in:
Nuno Campos 2023-09-01 18:53:10 +01:00 committed by GitHub
parent b927277809
commit 427f696fb0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 81 additions and 31 deletions

View File

@ -974,11 +974,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# invoke all steps in sequence # invoke all steps in sequence
try: try:
for step in self.steps: for i, step in enumerate(self.steps):
input = step.invoke( input = step.invoke(
input, input,
# mark each step as a child run # mark each step as a child run
patch_config(config, callbacks=run_manager.get_child()), patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
),
) )
# finish the root run # finish the root run
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
@ -1004,11 +1006,13 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# invoke all steps in sequence # invoke all steps in sequence
try: try:
for step in self.steps: for i, step in enumerate(self.steps):
input = await step.ainvoke( input = await step.ainvoke(
input, input,
# mark each step as a child run # mark each step as a child run
patch_config(config, callbacks=run_manager.get_child()), patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
),
) )
# finish the root run # finish the root run
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
@ -1059,7 +1063,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# If an input has failed it will be present in this map, # If an input has failed it will be present in this map,
# and the value will be the exception that was raised. # and the value will be the exception that was raised.
failed_inputs_map: Dict[int, Exception] = {} failed_inputs_map: Dict[int, Exception] = {}
for step in self.steps: for stepidx, step in enumerate(self.steps):
# Assemble the original indexes of the remaining inputs # Assemble the original indexes of the remaining inputs
# (i.e. the ones that haven't failed yet) # (i.e. the ones that haven't failed yet)
remaining_idxs = [ remaining_idxs = [
@ -1074,7 +1078,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
], ],
[ [
# each step a child run of the corresponding root run # each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child()) patch_config(
config, callbacks=rm.get_child(f"seq:step:{stepidx+1}")
)
for i, (rm, config) in enumerate(zip(run_managers, configs)) for i, (rm, config) in enumerate(zip(run_managers, configs))
if i not in failed_inputs_map if i not in failed_inputs_map
], ],
@ -1099,12 +1105,14 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
else: else:
inputs.append(inputs_copy.pop(0)) inputs.append(inputs_copy.pop(0))
else: else:
for step in self.steps: for i, step in enumerate(self.steps):
inputs = step.batch( inputs = step.batch(
inputs, inputs,
[ [
# each step a child run of the corresponding root run # each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child()) patch_config(
config, callbacks=rm.get_child(f"seq:step:{i+1}")
)
for rm, config in zip(run_managers, configs) for rm, config in zip(run_managers, configs)
], ],
) )
@ -1176,7 +1184,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
# If an input has failed it will be present in this map, # If an input has failed it will be present in this map,
# and the value will be the exception that was raised. # and the value will be the exception that was raised.
failed_inputs_map: Dict[int, Exception] = {} failed_inputs_map: Dict[int, Exception] = {}
for step in self.steps: for stepidx, step in enumerate(self.steps):
# Assemble the original indexes of the remaining inputs # Assemble the original indexes of the remaining inputs
# (i.e. the ones that haven't failed yet) # (i.e. the ones that haven't failed yet)
remaining_idxs = [ remaining_idxs = [
@ -1191,7 +1199,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
], ],
[ [
# each step a child run of the corresponding root run # each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child()) patch_config(
config, callbacks=rm.get_child(f"seq:step:{stepidx+1}")
)
for i, (rm, config) in enumerate(zip(run_managers, configs)) for i, (rm, config) in enumerate(zip(run_managers, configs))
if i not in failed_inputs_map if i not in failed_inputs_map
], ],
@ -1216,12 +1226,14 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
else: else:
inputs.append(inputs_copy.pop(0)) inputs.append(inputs_copy.pop(0))
else: else:
for step in self.steps: for i, step in enumerate(self.steps):
inputs = await step.abatch( inputs = await step.abatch(
inputs, inputs,
[ [
# each step a child run of the corresponding root run # each step a child run of the corresponding root run
patch_config(config, callbacks=rm.get_child()) patch_config(
config, callbacks=rm.get_child(f"seq:step:{i+1}")
)
for rm, config in zip(run_managers, configs) for rm, config in zip(run_managers, configs)
], ],
) )
@ -1276,7 +1288,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = step.invoke( input = step.invoke(
input, input,
# mark each step as a child run # mark each step as a child run
patch_config(config, callbacks=run_manager.get_child()), patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{steps.index(step)+1}"
),
),
) )
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e) run_manager.on_chain_error(e)
@ -1288,13 +1305,24 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
try: try:
# stream the first of the last steps with non-streaming input # stream the first of the last steps with non-streaming input
final_pipeline = steps[streaming_start_index].stream( final_pipeline = steps[streaming_start_index].stream(
input, patch_config(config, callbacks=run_manager.get_child()) input,
patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{streaming_start_index+1}"
),
),
) )
# stream the rest of the last steps with streaming input # stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]: for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.transform( final_pipeline = step.transform(
final_pipeline, final_pipeline,
patch_config(config, callbacks=run_manager.get_child()), patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{steps.index(step)+1}"
),
),
) )
for output in final_pipeline: for output in final_pipeline:
yield output yield output
@ -1345,7 +1373,12 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
input = await step.ainvoke( input = await step.ainvoke(
input, input,
# mark each step as a child run # mark each step as a child run
patch_config(config, callbacks=run_manager.get_child()), patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{steps.index(step)+1}"
),
),
) )
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e) await run_manager.on_chain_error(e)
@ -1357,13 +1390,24 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
try: try:
# stream the first of the last steps with non-streaming input # stream the first of the last steps with non-streaming input
final_pipeline = steps[streaming_start_index].astream( final_pipeline = steps[streaming_start_index].astream(
input, patch_config(config, callbacks=run_manager.get_child()) input,
patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{streaming_start_index+1}"
),
),
) )
# stream the rest of the last steps with streaming input # stream the rest of the last steps with streaming input
for step in steps[streaming_start_index + 1 :]: for step in steps[streaming_start_index + 1 :]:
final_pipeline = step.atransform( final_pipeline = step.atransform(
final_pipeline, final_pipeline,
patch_config(config, callbacks=run_manager.get_child()), patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{steps.index(step)+1}"
),
),
) )
async for output in final_pipeline: async for output in final_pipeline:
yield output yield output
@ -1476,10 +1520,10 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
patch_config( patch_config(
config, config,
deep_copy_locals=True, deep_copy_locals=True,
callbacks=run_manager.get_child(), callbacks=run_manager.get_child(f"map:key:{key}"),
), ),
) )
for step in steps.values() for key, step in steps.items()
] ]
output = {key: future.result() for key, future in zip(steps, futures)} output = {key: future.result() for key, future in zip(steps, futures)}
# finish the root run # finish the root run
@ -1513,9 +1557,11 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
step.ainvoke( step.ainvoke(
input, input,
# mark each step as a child run # mark each step as a child run
patch_config(config, callbacks=run_manager.get_child()), patch_config(
config, callbacks=run_manager.get_child(f"map:key:{key}")
),
) )
for step in steps.values() for key, step in steps.items()
) )
) )
output = {key: value for key, value in zip(steps, results)} output = {key: value for key, value in zip(steps, results)}
@ -1545,7 +1591,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
name, name,
step.transform( step.transform(
input_copies.pop(), input_copies.pop(),
patch_config(config, callbacks=run_manager.get_child()), patch_config(
config, callbacks=run_manager.get_child(f"map:key:{name}")
),
), ),
) )
for name, step in steps.items() for name, step in steps.items()
@ -1607,7 +1655,9 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
name, name,
step.atransform( step.atransform(
input_copies.pop(), input_copies.pop(),
patch_config(config, callbacks=run_manager.get_child()), patch_config(
config, callbacks=run_manager.get_child(f"map:key:{name}")
),
), ),
) )
for name, step in steps.items() for name, step in steps.items()

File diff suppressed because one or more lines are too long