diff --git a/libs/langchain/langchain/chains/base.py b/libs/langchain/langchain/chains/base.py index aab2891b67e..8d4f4e708ea 100644 --- a/libs/langchain/langchain/chains/base.py +++ b/libs/langchain/langchain/chains/base.py @@ -204,7 +204,7 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): if new_arg_supported else await self._acall(inputs) ) - final_outputs: Dict[str, Any] = self.prep_outputs( + final_outputs: Dict[str, Any] = await self.aprep_outputs( inputs, outputs, return_only_outputs ) except BaseException as e: @@ -458,6 +458,32 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC): else: return {**inputs, **outputs} + async def aprep_outputs( + self, + inputs: Dict[str, str], + outputs: Dict[str, str], + return_only_outputs: bool = False, + ) -> Dict[str, str]: + """Validate and prepare chain outputs, and save info about this run to memory. + + Args: + inputs: Dictionary of chain inputs, including any inputs added by chain + memory. + outputs: Dictionary of initial chain outputs. + return_only_outputs: Whether to only return the chain outputs. If False, + inputs are also added to the final outputs. + + Returns: + A dict of the final chain outputs. + """ + self._validate_outputs(outputs) + if self.memory is not None: + await self.memory.asave_context(inputs, outputs) + if return_only_outputs: + return outputs + else: + return {**inputs, **outputs} + def prep_inputs(self, inputs: Union[Dict[str, Any], Any]) -> Dict[str, str]: """Prepare chain inputs, including adding inputs from memory.