This commit is contained in:
Erick Friis
2023-12-18 13:50:39 -08:00
committed by GitHub
parent 251c81b6a9
commit ed1bb3d8d0

View File

@@ -428,7 +428,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
batch_size=len(messages),
)
results = await asyncio.gather(
results_and_exceptions = await asyncio.gather(
*[
self._agenerate_with_cache(
m,
@@ -440,47 +440,29 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
],
return_exceptions=True,
)
exceptions = []
for i, res in enumerate(results):
if isinstance(res, BaseException):
if run_managers:
await run_managers[i].on_llm_error(
res, response=LLMResult(generations=[])
)
exceptions.append(res)
if exceptions:
if run_managers:
await asyncio.gather(
*[
run_manager.on_llm_end(
LLMResult(
generations=[
cast(ChatResult, res).generations,
],
llm_output=cast(ChatResult, res).llm_output,
)
)
for run_manager, res in zip(run_managers, results)
if not isinstance(res, Exception)
]
)
raise exceptions[0]
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output)
for res in cast(List[ChatResult], results)
]
llm_output = self._combine_llm_outputs(
[res.llm_output for res in cast(List[ChatResult], results)]
)
generations = [res.generations for res in cast(List[ChatResult], results)]
output = LLMResult(generations=generations, llm_output=llm_output)
await asyncio.gather(
*[
run_manager.on_llm_end(flattened_output)
for run_manager, flattened_output in zip(
run_managers, flattened_outputs
# report results and errors
if run_managers:
jobs = [
run_manager.on_llm_error(res, response=LLMResult(generations=[]))
if isinstance(res, BaseException)
else run_manager.on_llm_end(
LLMResult(generations=[res.generations], llm_output=res.llm_output)
)
for run_manager, res in zip(run_managers, results_and_exceptions)
]
await asyncio.gather(*jobs)
# raise first exception, if any
for res in results_and_exceptions:
if isinstance(res, BaseException):
raise res
# compute return value
results = cast(List[ChatResult], results_and_exceptions)
output = LLMResult(
generations=[res.generations for res in results],
llm_output=self._combine_llm_outputs([res.llm_output for res in results]),
)
if run_managers:
output.run = [