From ed1bb3d8d0a333c1824ff24f7b84e8f124b9333c Mon Sep 17 00:00:00 2001
From: Erick Friis
Date: Mon, 18 Dec 2023 13:50:39 -0800
Subject: [PATCH] proposal (#14729)

---
 .../language_models/chat_models.py            | 62 +++++++------------
 1 file changed, 22 insertions(+), 40 deletions(-)

diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py
index f97ee7ba76a..f33c7782c65 100644
--- a/libs/core/langchain_core/language_models/chat_models.py
+++ b/libs/core/langchain_core/language_models/chat_models.py
@@ -428,7 +428,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             batch_size=len(messages),
         )
 
-        results = await asyncio.gather(
+        results_and_exceptions = await asyncio.gather(
             *[
                 self._agenerate_with_cache(
                     m,
@@ -440,47 +440,29 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             ],
             return_exceptions=True,
         )
-        exceptions = []
-        for i, res in enumerate(results):
-            if isinstance(res, BaseException):
-                if run_managers:
-                    await run_managers[i].on_llm_error(
-                        res, response=LLMResult(generations=[])
-                    )
-                exceptions.append(res)
-        if exceptions:
-            if run_managers:
-                await asyncio.gather(
-                    *[
-                        run_manager.on_llm_end(
-                            LLMResult(
-                                generations=[
-                                    cast(ChatResult, res).generations,
-                                ],
-                                llm_output=cast(ChatResult, res).llm_output,
-                            )
-                        )
-                        for run_manager, res in zip(run_managers, results)
-                        if not isinstance(res, Exception)
-                    ]
-                )
-            raise exceptions[0]
-        flattened_outputs = [
-            LLMResult(generations=[res.generations], llm_output=res.llm_output)
-            for res in cast(List[ChatResult], results)
-        ]
-        llm_output = self._combine_llm_outputs(
-            [res.llm_output for res in cast(List[ChatResult], results)]
-        )
-        generations = [res.generations for res in cast(List[ChatResult], results)]
-        output = LLMResult(generations=generations, llm_output=llm_output)
-        await asyncio.gather(
-            *[
-                run_manager.on_llm_end(flattened_output)
-                for run_manager, flattened_output in zip(
-                    run_managers, flattened_outputs
+        # report results and errors
+        if run_managers:
+            jobs = [
+                run_manager.on_llm_error(res, response=LLMResult(generations=[]))
+                if isinstance(res, BaseException)
+                else run_manager.on_llm_end(
+                    LLMResult(generations=[res.generations], llm_output=res.llm_output)
                 )
+                for run_manager, res in zip(run_managers, results_and_exceptions)
             ]
+            await asyncio.gather(*jobs)
+
+        # raise first exception, if any
+        for res in results_and_exceptions:
+            if isinstance(res, BaseException):
+                raise res
+
+        # compute return value
+        results = cast(List[ChatResult], results_and_exceptions)
+
+        output = LLMResult(
+            generations=[res.generations for res in results],
+            llm_output=self._combine_llm_outputs([res.llm_output for res in results]),
         )
         if run_managers:
             output.run = [
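
Note on the pattern (outside the patch): the rewrite hinges on asyncio.gather(..., return_exceptions=True), which returns failed coroutines' exceptions as values instead of letting the first failure propagate, so every run manager can be notified before any exception is raised. Below is a minimal, runnable sketch of that flow under stated assumptions: worker and the print-based reporting are hypothetical stand-ins for _agenerate_with_cache and the run managers' on_llm_error / on_llm_end callbacks.

    import asyncio

    async def worker(i: int) -> str:
        # hypothetical stand-in for _agenerate_with_cache
        if i == 2:
            raise ValueError(f"prompt {i} failed")
        return f"result {i}"

    async def main() -> None:
        # with return_exceptions=True, failures come back as values
        # in the results list rather than aborting the gather
        results_and_exceptions = await asyncio.gather(
            *[worker(i) for i in range(4)],
            return_exceptions=True,
        )
        # 1) report every run's outcome (the on_llm_error / on_llm_end split)
        for i, res in enumerate(results_and_exceptions):
            if isinstance(res, BaseException):
                print(f"run {i} errored: {res}")
            else:
                print(f"run {i} finished: {res}")
        # 2) only then surface the first failure, as the patch does
        for res in results_and_exceptions:
            if isinstance(res, BaseException):
                raise res

    try:
        asyncio.run(main())
    except ValueError as exc:
        print(f"raised after all runs were reported: {exc}")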