From 207a7b7bbd8ca37bd5c3255c41ab73044460083e Mon Sep 17 00:00:00 2001 From: Ankush Gola Date: Tue, 6 Jun 2023 19:54:43 -0700 Subject: [PATCH] split up batch llm calls into separate runs --- langchain/callbacks/manager.py | 86 ++++++++++++++++++++------------- langchain/llms/base.py | 87 ++++++++++++++++++++++++---------- langchain/schema.py | 12 ++++- 3 files changed, 128 insertions(+), 57 deletions(-) diff --git a/langchain/callbacks/manager.py b/langchain/callbacks/manager.py index 19e957d7346..31a0a232984 100644 --- a/langchain/callbacks/manager.py +++ b/langchain/callbacks/manager.py @@ -641,25 +641,32 @@ class CallbackManager(BaseCallbackManager): prompts: List[str], run_id: Optional[UUID] = None, **kwargs: Any, - ) -> CallbackManagerForLLMRun: + ) -> List[CallbackManagerForLLMRun]: """Run when LLM starts running.""" - if run_id is None: - run_id = uuid4() + managers = [] + for prompt in prompts: + run_id_ = uuid4() + _handle_event( + self.handlers, + "on_llm_start", + "ignore_llm", + serialized, + [prompt], + run_id=run_id_, + parent_run_id=self.parent_run_id, + **kwargs, + ) - _handle_event( - self.handlers, - "on_llm_start", - "ignore_llm", - serialized, - prompts, - run_id=run_id, - parent_run_id=self.parent_run_id, - **kwargs, - ) + managers.append( + CallbackManagerForLLMRun( + run_id_, + self.handlers, + self.inheritable_handlers, + self.parent_run_id, + ) + ) - return CallbackManagerForLLMRun( - run_id, self.handlers, self.inheritable_handlers, self.parent_run_id - ) + return managers def on_chat_model_start( self, @@ -766,25 +773,40 @@ class AsyncCallbackManager(BaseCallbackManager): prompts: List[str], run_id: Optional[UUID] = None, **kwargs: Any, - ) -> AsyncCallbackManagerForLLMRun: + ) -> List[AsyncCallbackManagerForLLMRun]: """Run when LLM starts running.""" - if run_id is None: - run_id = uuid4() - await _ahandle_event( - self.handlers, - "on_llm_start", - "ignore_llm", - serialized, - prompts, - run_id=run_id, - parent_run_id=self.parent_run_id, - **kwargs, - ) + tasks = [] + managers = [] - return AsyncCallbackManagerForLLMRun( - run_id, self.handlers, self.inheritable_handlers, self.parent_run_id - ) + for prompt in prompts: + run_id_ = uuid4() + + tasks.append( + _ahandle_event( + self.handlers, + "on_llm_start", + "ignore_llm", + serialized, + [prompt], + run_id=run_id_, + parent_run_id=self.parent_run_id, + **kwargs, + ) + ) + + managers.append( + AsyncCallbackManagerForLLMRun( + run_id_, + self.handlers, + self.inheritable_handlers, + self.parent_run_id, + ) + ) + + await asyncio.gather(*tasks) + + return managers async def on_chat_model_start( self, diff --git a/langchain/llms/base.py b/langchain/llms/base.py index 84ba2c5c86d..95a1775cbe1 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -1,4 +1,5 @@ """Base interface for large language models to expose.""" +import asyncio import inspect import json import warnings @@ -178,44 +179,56 @@ class BaseLLM(BaseLanguageModel, ABC): raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." 
) - run_manager = callback_manager.on_llm_start( + run_managers = callback_manager.on_llm_start( {"name": self.__class__.__name__}, prompts, invocation_params=params ) try: output = ( - self._generate(prompts, stop=stop, run_manager=run_manager) + self._generate(prompts, stop=stop, run_manager=run_managers[0]) if new_arg_supported else self._generate(prompts, stop=stop) ) except (KeyboardInterrupt, Exception) as e: - run_manager.on_llm_error(e) + for run_manager in run_managers: + run_manager.on_llm_error(e) raise e - run_manager.on_llm_end(output) - if run_manager: - output.run = RunInfo(run_id=run_manager.run_id) + flattened_outputs = output.flatten() + for manager, flattened_output in zip(run_managers, flattened_outputs): + manager.on_llm_end(flattened_output) + if run_managers: + output.run = [ + RunInfo(run_id=run_manager.run_id) for run_manager in run_managers + ] return output if len(missing_prompts) > 0: - run_manager = callback_manager.on_llm_start( + run_managers = callback_manager.on_llm_start( {"name": self.__class__.__name__}, missing_prompts, invocation_params=params, ) try: new_results = ( - self._generate(missing_prompts, stop=stop, run_manager=run_manager) + self._generate( + missing_prompts, stop=stop, run_manager=run_managers[0] + ) if new_arg_supported else self._generate(missing_prompts, stop=stop) ) except (KeyboardInterrupt, Exception) as e: - run_manager.on_llm_error(e) + for run_manager in run_managers: + run_manager.on_llm_error(e) raise e - run_manager.on_llm_end(new_results) + flattened_outputs = new_results.flatten() + for manager, flattened_output in zip(run_managers, flattened_outputs): + manager.on_llm_end(flattened_output) llm_output = update_cache( existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts ) run_info = None - if run_manager: - run_info = RunInfo(run_id=run_manager.run_id) + if run_managers: + run_info = [ + RunInfo(run_id=run_manager.run_id) for run_manager in run_managers + ] else: llm_output = {} run_info = None @@ -250,24 +263,38 @@ class BaseLLM(BaseLanguageModel, ABC): raise ValueError( "Asked to cache, but no cache found at `langchain.cache`." 
) - run_manager = await callback_manager.on_llm_start( + run_managers = await callback_manager.on_llm_start( {"name": self.__class__.__name__}, prompts, invocation_params=params ) try: output = ( - await self._agenerate(prompts, stop=stop, run_manager=run_manager) + await self._agenerate( + prompts, stop=stop, run_manager=run_managers[0] + ) if new_arg_supported else await self._agenerate(prompts, stop=stop) ) except (KeyboardInterrupt, Exception) as e: - await run_manager.on_llm_error(e, verbose=self.verbose) + await asyncio.gather( + *[run_manager.on_llm_error(e) for run_manager in run_managers] + ) raise e - await run_manager.on_llm_end(output, verbose=self.verbose) - if run_manager: - output.run = RunInfo(run_id=run_manager.run_id) + flattened_outputs = output.flatten() + await asyncio.gather( + *[ + run_manager.on_llm_end(flattened_output) + for run_manager, flattened_output in zip( + run_managers, flattened_outputs + ) + ] + ) + if run_managers: + output.run = [ + RunInfo(run_id=run_manager.run_id) for run_manager in run_managers + ] return output if len(missing_prompts) > 0: - run_manager = await callback_manager.on_llm_start( + run_managers = await callback_manager.on_llm_start( {"name": self.__class__.__name__}, missing_prompts, invocation_params=params, @@ -275,21 +302,33 @@ class BaseLLM(BaseLanguageModel, ABC): try: new_results = ( await self._agenerate( - missing_prompts, stop=stop, run_manager=run_manager + missing_prompts, stop=stop, run_manager=run_managers[0] ) if new_arg_supported else await self._agenerate(missing_prompts, stop=stop) ) except (KeyboardInterrupt, Exception) as e: - await run_manager.on_llm_error(e) + await asyncio.gather( + *[run_manager.on_llm_error(e) for run_manager in run_managers] + ) raise e - await run_manager.on_llm_end(new_results) + flattened_outputs = new_results.flatten() + await asyncio.gather( + *[ + run_manager.on_llm_end(flattened_output) + for run_manager, flattened_output in zip( + run_managers, flattened_outputs + ) + ] + ) llm_output = update_cache( existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts ) run_info = None - if run_manager: - run_info = RunInfo(run_id=run_manager.run_id) + if run_managers: + run_info = [ + RunInfo(run_id=run_manager.run_id) for run_manager in run_managers + ] else: llm_output = {} run_info = None diff --git a/langchain/schema.py b/langchain/schema.py index b74b40a7c5e..005b96e6b8e 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -182,9 +182,19 @@ class LLMResult(BaseModel): each input could have multiple generations.""" llm_output: Optional[dict] = None """For arbitrary LLM provider specific output.""" - run: Optional[RunInfo] = None + run: Optional[List[RunInfo]] = None """Run metadata.""" + def flatten(self) -> List[LLMResult]: + """Flatten generations into a single list.""" + return [ + LLMResult( + generations=[gen], + llm_output=self.llm_output, + ) + for gen in self.generations + ] + def __eq__(self, other: object) -> bool: if not isinstance(other, LLMResult): return NotImplemented
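
Usage note (not part of the patch): a minimal sketch of the behavior this change introduces, assuming the schema shown in the diff above. LLMResult.flatten() splits a batched result into one LLMResult per input prompt (each of which is what that prompt's on_llm_end callback receives), and LLMResult.run is now a list with one RunInfo per prompt, so callers that previously read output.run.run_id would now use output.run[0].run_id. The model name and completion texts below are hypothetical placeholders.

    from langchain.schema import Generation, LLMResult

    # A batched result covering two input prompts, as returned by BaseLLM.generate.
    batched = LLMResult(
        generations=[
            [Generation(text="completion for prompt 1")],
            [Generation(text="completion for prompt 2")],
        ],
        llm_output={"model_name": "example-model"},  # hypothetical provider output
    )

    # flatten() yields one LLMResult per prompt; each carries that prompt's
    # generations plus the shared llm_output, matching what on_llm_end is
    # called with for the corresponding per-prompt run.
    for per_prompt in batched.flatten():
        print(per_prompt.generations, per_prompt.llm_output)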