split up batch llm calls into separate runs

Ankush Gola
2023-06-06 19:54:43 -07:00
parent b177a29d3f
commit 207a7b7bbd
3 changed files with 128 additions and 57 deletions


@@ -641,25 +641,32 @@ class CallbackManager(BaseCallbackManager):
         prompts: List[str],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> CallbackManagerForLLMRun:
+    ) -> List[CallbackManagerForLLMRun]:
         """Run when LLM starts running."""
-        if run_id is None:
-            run_id = uuid4()
+        managers = []
+        for prompt in prompts:
+            run_id_ = uuid4()
+            _handle_event(
+                self.handlers,
+                "on_llm_start",
+                "ignore_llm",
+                serialized,
+                [prompt],
+                run_id=run_id_,
+                parent_run_id=self.parent_run_id,
+                **kwargs,
+            )
-        _handle_event(
-            self.handlers,
-            "on_llm_start",
-            "ignore_llm",
-            serialized,
-            prompts,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            **kwargs,
-        )
+            managers.append(
+                CallbackManagerForLLMRun(
+                    run_id_,
+                    self.handlers,
+                    self.inheritable_handlers,
+                    self.parent_run_id,
+                )
+            )
-        return CallbackManagerForLLMRun(
-            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
-        )
+        return managers

     def on_chat_model_start(
         self,
@@ -766,25 +773,40 @@ class AsyncCallbackManager(BaseCallbackManager):
         prompts: List[str],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> AsyncCallbackManagerForLLMRun:
+    ) -> List[AsyncCallbackManagerForLLMRun]:
         """Run when LLM starts running."""
-        if run_id is None:
-            run_id = uuid4()
-        await _ahandle_event(
-            self.handlers,
-            "on_llm_start",
-            "ignore_llm",
-            serialized,
-            prompts,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            **kwargs,
-        )
+        tasks = []
+        managers = []
-        return AsyncCallbackManagerForLLMRun(
-            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
-        )
+        for prompt in prompts:
+            run_id_ = uuid4()
+            tasks.append(
+                _ahandle_event(
+                    self.handlers,
+                    "on_llm_start",
+                    "ignore_llm",
+                    serialized,
+                    [prompt],
+                    run_id=run_id_,
+                    parent_run_id=self.parent_run_id,
+                    **kwargs,
+                )
+            )
+            managers.append(
+                AsyncCallbackManagerForLLMRun(
+                    run_id_,
+                    self.handlers,
+                    self.inheritable_handlers,
+                    self.parent_run_id,
+                )
+            )
+        await asyncio.gather(*tasks)
+        return managers

     async def on_chat_model_start(
         self,

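With this change, on_llm_start on both managers fires one callback per prompt and returns one run manager per prompt instead of a single manager for the whole batch. A minimal sketch of the new call shape, assuming a StdOutCallbackHandler as the only handler (the serialized name and prompt strings below are made up):

from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.stdout import StdOutCallbackHandler

# Build a manager with a single stdout handler; any BaseCallbackHandler works.
manager = CallbackManager(handlers=[StdOutCallbackHandler()])

# After this commit, on_llm_start returns a list: one run manager per prompt,
# each created with its own freshly generated run_id.
run_managers = manager.on_llm_start(
    {"name": "MyLLM"},
    ["first prompt", "second prompt"],
)
assert len(run_managers) == 2
assert run_managers[0].run_id != run_managers[1].run_id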

@@ -1,4 +1,5 @@
 """Base interface for large language models to expose."""
+import asyncio
 import inspect
 import json
 import warnings
@@ -178,44 +179,56 @@ class BaseLLM(BaseLanguageModel, ABC):
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            run_manager = callback_manager.on_llm_start(
+            run_managers = callback_manager.on_llm_start(
                 {"name": self.__class__.__name__}, prompts, invocation_params=params
             )
             try:
                 output = (
-                    self._generate(prompts, stop=stop, run_manager=run_manager)
+                    self._generate(prompts, stop=stop, run_manager=run_managers[0])
                     if new_arg_supported
                     else self._generate(prompts, stop=stop)
                 )
             except (KeyboardInterrupt, Exception) as e:
-                run_manager.on_llm_error(e)
+                for run_manager in run_managers:
+                    run_manager.on_llm_error(e)
                 raise e
-            run_manager.on_llm_end(output)
-            if run_manager:
-                output.run = RunInfo(run_id=run_manager.run_id)
+            flattened_outputs = output.flatten()
+            for manager, flattened_output in zip(run_managers, flattened_outputs):
+                manager.on_llm_end(flattened_output)
+            if run_managers:
+                output.run = [
+                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+                ]
             return output
         if len(missing_prompts) > 0:
-            run_manager = callback_manager.on_llm_start(
+            run_managers = callback_manager.on_llm_start(
                 {"name": self.__class__.__name__},
                 missing_prompts,
                 invocation_params=params,
             )
             try:
                 new_results = (
-                    self._generate(missing_prompts, stop=stop, run_manager=run_manager)
+                    self._generate(
+                        missing_prompts, stop=stop, run_manager=run_managers[0]
+                    )
                     if new_arg_supported
                     else self._generate(missing_prompts, stop=stop)
                 )
             except (KeyboardInterrupt, Exception) as e:
-                run_manager.on_llm_error(e)
+                for run_manager in run_managers:
+                    run_manager.on_llm_error(e)
                 raise e
-            run_manager.on_llm_end(new_results)
+            flattened_outputs = new_results.flatten()
+            for manager, flattened_output in zip(run_managers, flattened_outputs):
+                manager.on_llm_end(flattened_output)
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
             run_info = None
-            if run_manager:
-                run_info = RunInfo(run_id=run_manager.run_id)
+            if run_managers:
+                run_info = [
+                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+                ]
         else:
             llm_output = {}
             run_info = None
@@ -250,24 +263,38 @@ class BaseLLM(BaseLanguageModel, ABC):
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            run_manager = await callback_manager.on_llm_start(
+            run_managers = await callback_manager.on_llm_start(
                 {"name": self.__class__.__name__}, prompts, invocation_params=params
             )
             try:
                 output = (
-                    await self._agenerate(prompts, stop=stop, run_manager=run_manager)
+                    await self._agenerate(
+                        prompts, stop=stop, run_manager=run_managers[0]
+                    )
                     if new_arg_supported
                     else await self._agenerate(prompts, stop=stop)
                 )
             except (KeyboardInterrupt, Exception) as e:
-                await run_manager.on_llm_error(e, verbose=self.verbose)
+                await asyncio.gather(
+                    *[run_manager.on_llm_error(e) for run_manager in run_managers]
+                )
                 raise e
-            await run_manager.on_llm_end(output, verbose=self.verbose)
-            if run_manager:
-                output.run = RunInfo(run_id=run_manager.run_id)
+            flattened_outputs = output.flatten()
+            await asyncio.gather(
+                *[
+                    run_manager.on_llm_end(flattened_output)
+                    for run_manager, flattened_output in zip(
+                        run_managers, flattened_outputs
+                    )
+                ]
+            )
+            if run_managers:
+                output.run = [
+                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+                ]
             return output
         if len(missing_prompts) > 0:
-            run_manager = await callback_manager.on_llm_start(
+            run_managers = await callback_manager.on_llm_start(
                 {"name": self.__class__.__name__},
                 missing_prompts,
                 invocation_params=params,
@@ -275,21 +302,33 @@ class BaseLLM(BaseLanguageModel, ABC):
             try:
                 new_results = (
                     await self._agenerate(
-                        missing_prompts, stop=stop, run_manager=run_manager
+                        missing_prompts, stop=stop, run_manager=run_managers[0]
                     )
                     if new_arg_supported
                     else await self._agenerate(missing_prompts, stop=stop)
                 )
             except (KeyboardInterrupt, Exception) as e:
-                await run_manager.on_llm_error(e)
+                await asyncio.gather(
+                    *[run_manager.on_llm_error(e) for run_manager in run_managers]
+                )
                 raise e
-            await run_manager.on_llm_end(new_results)
+            flattened_outputs = new_results.flatten()
+            await asyncio.gather(
+                *[
+                    run_manager.on_llm_end(flattened_output)
+                    for run_manager, flattened_output in zip(
+                        run_managers, flattened_outputs
+                    )
+                ]
+            )
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
             run_info = None
-            if run_manager:
-                run_info = RunInfo(run_id=run_manager.run_id)
+            if run_managers:
+                run_info = [
+                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+                ]
         else:
             llm_output = {}
             run_info = None

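The net effect in BaseLLM.generate and agenerate is that a batched call now yields one traced run per prompt: errors are fanned out to every run manager, on_llm_end receives a flattened single-prompt result per run, and LLMResult.run becomes a list of RunInfo entries. A rough sketch of what a caller now observes, using FakeListLLM as an assumed stand-in model:

from langchain.llms.fake import FakeListLLM

# FakeListLLM replays canned responses; it stands in here for a real LLM.
llm = FakeListLLM(responses=["answer one", "answer two"])

result = llm.generate(["first prompt", "second prompt"])

# One RunInfo per prompt instead of a single RunInfo for the whole batch.
assert result.run is not None
assert len(result.run) == 2
assert result.run[0].run_id != result.run[1].run_id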

@@ -182,9 +182,19 @@ class LLMResult(BaseModel):
     each input could have multiple generations."""
     llm_output: Optional[dict] = None
     """For arbitrary LLM provider specific output."""
-    run: Optional[RunInfo] = None
+    run: Optional[List[RunInfo]] = None
     """Run metadata."""

+    def flatten(self) -> List[LLMResult]:
+        """Flatten generations into a single list."""
+        return [
+            LLMResult(
+                generations=[gen],
+                llm_output=self.llm_output,
+            )
+            for gen in self.generations
+        ]
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, LLMResult):
             return NotImplemented
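
The new flatten() helper is what lets generate hand each run manager a single-prompt result: a batched LLMResult over N prompts splits into N one-prompt results that carry over the same llm_output. A small illustration with made-up texts and provider output:

from langchain.schema import Generation, LLMResult

# A batched result covering two prompts.
batched = LLMResult(
    generations=[[Generation(text="answer one")], [Generation(text="answer two")]],
    llm_output={"model_name": "some-model"},
)

flat = batched.flatten()

# One single-prompt LLMResult per input, each keeping that prompt's generations
# and carrying over the provider-specific llm_output.
assert len(flat) == 2
assert flat[0].generations == [[Generation(text="answer one")]]
assert flat[0].llm_output == {"model_name": "some-model"}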