mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-03 10:12:33 +00:00
Use call directly for chain (#8655)
for run_on_dataset since the `run()` method requires a single output
This commit is contained in:
parent
368aa4ede7
commit
7ea2b08d1f
@ -604,11 +604,8 @@ async def _arun_chain(
|
|||||||
inputs_, callbacks=callbacks, tags=tags
|
inputs_, callbacks=callbacks, tags=tags
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if len(inputs) == 1:
|
inputs_ = next(iter(inputs.values())) if len(inputs) == 1 else inputs
|
||||||
inputs_ = next(iter(inputs.values()))
|
output = await chain.acall(inputs_, callbacks=callbacks, tags=tags)
|
||||||
output = await chain.arun(inputs_, callbacks=callbacks, tags=tags)
|
|
||||||
else:
|
|
||||||
output = await chain.acall(inputs, callbacks=callbacks, tags=tags)
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@ -926,11 +923,8 @@ def _run_chain(
|
|||||||
inputs_ = input_mapper(inputs)
|
inputs_ = input_mapper(inputs)
|
||||||
output: Union[dict, str] = chain(inputs_, callbacks=callbacks, tags=tags)
|
output: Union[dict, str] = chain(inputs_, callbacks=callbacks, tags=tags)
|
||||||
else:
|
else:
|
||||||
if len(inputs) == 1:
|
inputs_ = next(iter(inputs.values())) if len(inputs) == 1 else inputs
|
||||||
inputs_ = next(iter(inputs.values()))
|
output = chain(inputs_, callbacks=callbacks, tags=tags)
|
||||||
output = chain.run(inputs_, callbacks=callbacks, tags=tags)
|
|
||||||
else:
|
|
||||||
output = chain(inputs, callbacks=callbacks, tags=tags)
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user