change _agenerate in to use gather

This commit is contained in:
Ankush Gola
2023-06-10 16:41:43 -07:00
parent b3475d6b50
commit 0083ff9c7d

View File

@@ -494,13 +494,17 @@ class LLM(BaseLLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
generations = []
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
for prompt in prompts:
text = (
async def get_completion(prompt):
return (
await self._acall(prompt, stop=stop, run_manager=run_manager)
if new_arg_supported
else await self._acall(prompt, stop=stop)
)
generations.append([Generation(text=text)])
return LLMResult(generations=generations)
generations = await asyncio.gather(
*(get_completion(prompt) for prompt in prompts)
)
return LLMResult(generations=[[Generation(text=text)] for text in generations])