Mirror of https://github.com/hwchase17/langchain.git (synced 2026-03-18 19:18:48 +00:00)

split up batch llm calls into separate runs
@@ -641,25 +641,32 @@ class CallbackManager(BaseCallbackManager):
         prompts: List[str],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> CallbackManagerForLLMRun:
+    ) -> List[CallbackManagerForLLMRun]:
         """Run when LLM starts running."""
-        if run_id is None:
-            run_id = uuid4()
+        managers = []
+        for prompt in prompts:
+            run_id_ = uuid4()
+            _handle_event(
+                self.handlers,
+                "on_llm_start",
+                "ignore_llm",
+                serialized,
+                [prompt],
+                run_id=run_id_,
+                parent_run_id=self.parent_run_id,
+                **kwargs,
+            )
 
-        _handle_event(
-            self.handlers,
-            "on_llm_start",
-            "ignore_llm",
-            serialized,
-            prompts,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            **kwargs,
-        )
+            managers.append(
+                CallbackManagerForLLMRun(
+                    run_id_,
+                    self.handlers,
+                    self.inheritable_handlers,
+                    self.parent_run_id,
+                )
+            )
 
-        return CallbackManagerForLLMRun(
-            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
-        )
+        return managers
 
     def on_chat_model_start(
         self,
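This hunk is the heart of the change: instead of minting one run id for the whole batch, CallbackManager.on_llm_start now loops over the prompts, emits a separate on_llm_start event carrying a single-element prompt list each time, and returns one CallbackManagerForLLMRun per prompt. A minimal, self-contained sketch of that fan-out pattern follows; RunManager and emit_start are illustrative stand-ins, not LangChain's actual types:

from dataclasses import dataclass
from typing import Callable, List
from uuid import UUID, uuid4


@dataclass
class RunManager:  # illustrative stand-in for CallbackManagerForLLMRun
    run_id: UUID
    prompt: str


def on_llm_start(
    prompts: List[str], emit_start: Callable[[str, UUID], None]
) -> List[RunManager]:
    """Emit one start event and create one run manager per prompt."""
    managers = []
    for prompt in prompts:
        run_id_ = uuid4()  # each prompt now gets its own run id
        emit_start(prompt, run_id_)  # event carries [prompt], not the whole batch
        managers.append(RunManager(run_id=run_id_, prompt=prompt))
    return managers


managers = on_llm_start(
    ["first prompt", "second prompt"],
    emit_start=lambda p, rid: print(f"on_llm_start run={rid} prompt={p!r}"),
)
assert len(managers) == 2 and managers[0].run_id != managers[1].run_id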
@@ -766,25 +773,40 @@ class AsyncCallbackManager(BaseCallbackManager):
         prompts: List[str],
         run_id: Optional[UUID] = None,
         **kwargs: Any,
-    ) -> AsyncCallbackManagerForLLMRun:
+    ) -> List[AsyncCallbackManagerForLLMRun]:
         """Run when LLM starts running."""
-        if run_id is None:
-            run_id = uuid4()
-
-        await _ahandle_event(
-            self.handlers,
-            "on_llm_start",
-            "ignore_llm",
-            serialized,
-            prompts,
-            run_id=run_id,
-            parent_run_id=self.parent_run_id,
-            **kwargs,
-        )
+        tasks = []
+        managers = []
 
-        return AsyncCallbackManagerForLLMRun(
-            run_id, self.handlers, self.inheritable_handlers, self.parent_run_id
-        )
+        for prompt in prompts:
+            run_id_ = uuid4()
+
+            tasks.append(
+                _ahandle_event(
+                    self.handlers,
+                    "on_llm_start",
+                    "ignore_llm",
+                    serialized,
+                    [prompt],
+                    run_id=run_id_,
+                    parent_run_id=self.parent_run_id,
+                    **kwargs,
+                )
+            )
+
+            managers.append(
+                AsyncCallbackManagerForLLMRun(
+                    run_id_,
+                    self.handlers,
+                    self.inheritable_handlers,
+                    self.parent_run_id,
+                )
+            )
+
+        await asyncio.gather(*tasks)
+
+        return managers
 
     async def on_chat_model_start(
         self,
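The async manager performs the same per-prompt fan-out, but rather than awaiting each handler call in turn it collects the on_llm_start coroutines in tasks and awaits them together with asyncio.gather, so start events for all prompts fire concurrently. A sketch of that scheduling pattern, where the emit_start coroutine is an assumed stand-in for _ahandle_event:

import asyncio
from typing import List
from uuid import UUID, uuid4


async def emit_start(prompt: str, run_id: UUID) -> None:
    # stand-in for _ahandle_event(..., "on_llm_start", ..., [prompt], ...)
    await asyncio.sleep(0)
    print(f"on_llm_start run={run_id} prompt={prompt!r}")


async def on_llm_start(prompts: List[str]) -> List[UUID]:
    tasks = []
    run_ids = []
    for prompt in prompts:
        run_id_ = uuid4()
        tasks.append(emit_start(prompt, run_id_))  # collect, don't await yet
        run_ids.append(run_id_)
    await asyncio.gather(*tasks)  # all start events run concurrently
    return run_ids


asyncio.run(on_llm_start(["a", "b", "c"]))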
@@ -1,4 +1,5 @@
 """Base interface for large language models to expose."""
+import asyncio
 import inspect
 import json
 import warnings
@@ -178,44 +179,56 @@ class BaseLLM(BaseLanguageModel, ABC):
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            run_manager = callback_manager.on_llm_start(
+            run_managers = callback_manager.on_llm_start(
                 {"name": self.__class__.__name__}, prompts, invocation_params=params
             )
             try:
                 output = (
-                    self._generate(prompts, stop=stop, run_manager=run_manager)
+                    self._generate(prompts, stop=stop, run_manager=run_managers[0])
                     if new_arg_supported
                     else self._generate(prompts, stop=stop)
                 )
             except (KeyboardInterrupt, Exception) as e:
-                run_manager.on_llm_error(e)
+                for run_manager in run_managers:
+                    run_manager.on_llm_error(e)
                 raise e
-            run_manager.on_llm_end(output)
-            if run_manager:
-                output.run = RunInfo(run_id=run_manager.run_id)
+            flattened_outputs = output.flatten()
+            for manager, flattened_output in zip(run_managers, flattened_outputs):
+                manager.on_llm_end(flattened_output)
+            if run_managers:
+                output.run = [
+                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+                ]
             return output
         if len(missing_prompts) > 0:
-            run_manager = callback_manager.on_llm_start(
+            run_managers = callback_manager.on_llm_start(
                 {"name": self.__class__.__name__},
                 missing_prompts,
                 invocation_params=params,
             )
             try:
                 new_results = (
-                    self._generate(missing_prompts, stop=stop, run_manager=run_manager)
+                    self._generate(
+                        missing_prompts, stop=stop, run_manager=run_managers[0]
+                    )
                     if new_arg_supported
                     else self._generate(missing_prompts, stop=stop)
                 )
             except (KeyboardInterrupt, Exception) as e:
-                run_manager.on_llm_error(e)
+                for run_manager in run_managers:
+                    run_manager.on_llm_error(e)
                 raise e
-            run_manager.on_llm_end(new_results)
+            flattened_outputs = new_results.flatten()
+            for manager, flattened_output in zip(run_managers, flattened_outputs):
+                manager.on_llm_end(flattened_output)
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
             run_info = None
-            if run_manager:
-                run_info = RunInfo(run_id=run_manager.run_id)
+            if run_managers:
+                run_info = [
+                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+                ]
             else:
                 llm_output = {}
                 run_info = None
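On the consumer side, generate() still makes a single provider call for the whole batch (which is why only run_managers[0] is passed to _generate), but it then splits the batched LLMResult with flatten() so each per-prompt manager can close out its own run; errors fan out to every manager. A simplified sketch of that completion bookkeeping, using toy stand-ins for the callback and result types:

from dataclasses import dataclass
from typing import List
from uuid import UUID, uuid4


@dataclass
class RunInfo:  # mirrors the RunInfo carried on LLMResult.run
    run_id: UUID


@dataclass
class RunManager:  # toy stand-in for CallbackManagerForLLMRun
    run_id: UUID

    def on_llm_end(self, flattened_output: List[str]) -> None:
        print(f"run {self.run_id} ended with {flattened_output}")


def finish_runs(
    run_managers: List[RunManager], per_prompt_outputs: List[List[str]]
) -> List[RunInfo]:
    # zip pairs each prompt's slice of the batched output with its manager
    for manager, flattened_output in zip(run_managers, per_prompt_outputs):
        manager.on_llm_end(flattened_output)
    # run metadata is now a list: one RunInfo per prompt
    return [RunInfo(run_id=m.run_id) for m in run_managers]


managers = [RunManager(uuid4()), RunManager(uuid4())]
run_info = finish_runs(managers, [["output a"], ["output b"]])
assert len(run_info) == len(managers)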
@@ -250,24 +263,38 @@ class BaseLLM(BaseLanguageModel, ABC):
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            run_manager = await callback_manager.on_llm_start(
+            run_managers = await callback_manager.on_llm_start(
                 {"name": self.__class__.__name__}, prompts, invocation_params=params
             )
             try:
                 output = (
-                    await self._agenerate(prompts, stop=stop, run_manager=run_manager)
+                    await self._agenerate(
+                        prompts, stop=stop, run_manager=run_managers[0]
+                    )
                     if new_arg_supported
                     else await self._agenerate(prompts, stop=stop)
                 )
             except (KeyboardInterrupt, Exception) as e:
-                await run_manager.on_llm_error(e, verbose=self.verbose)
+                await asyncio.gather(
+                    *[run_manager.on_llm_error(e) for run_manager in run_managers]
+                )
                 raise e
-            await run_manager.on_llm_end(output, verbose=self.verbose)
-            if run_manager:
-                output.run = RunInfo(run_id=run_manager.run_id)
+            flattened_outputs = output.flatten()
+            await asyncio.gather(
+                *[
+                    run_manager.on_llm_end(flattened_output)
+                    for run_manager, flattened_output in zip(
+                        run_managers, flattened_outputs
+                    )
+                ]
+            )
+            if run_managers:
+                output.run = [
+                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+                ]
             return output
         if len(missing_prompts) > 0:
-            run_manager = await callback_manager.on_llm_start(
+            run_managers = await callback_manager.on_llm_start(
                 {"name": self.__class__.__name__},
                 missing_prompts,
                 invocation_params=params,
@@ -275,21 +302,33 @@ class BaseLLM(BaseLanguageModel, ABC):
             try:
                 new_results = (
                     await self._agenerate(
-                        missing_prompts, stop=stop, run_manager=run_manager
+                        missing_prompts, stop=stop, run_manager=run_managers[0]
                     )
                     if new_arg_supported
                     else await self._agenerate(missing_prompts, stop=stop)
                 )
             except (KeyboardInterrupt, Exception) as e:
-                await run_manager.on_llm_error(e)
+                await asyncio.gather(
+                    *[run_manager.on_llm_error(e) for run_manager in run_managers]
+                )
                 raise e
-            await run_manager.on_llm_end(new_results)
+            flattened_outputs = new_results.flatten()
+            await asyncio.gather(
+                *[
+                    run_manager.on_llm_end(flattened_output)
+                    for run_manager, flattened_output in zip(
+                        run_managers, flattened_outputs
+                    )
+                ]
+            )
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
             run_info = None
-            if run_manager:
-                run_info = RunInfo(run_id=run_manager.run_id)
+            if run_managers:
+                run_info = [
+                    RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
+                ]
             else:
                 llm_output = {}
                 run_info = None
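The async path mirrors the sync bookkeeping, except the per-run on_llm_end coroutines (and, on failure, the on_llm_error ones) are awaited together via asyncio.gather rather than in a loop. A sketch of that concurrent completion, again with an illustrative manager class rather than LangChain's:

import asyncio
from typing import List
from uuid import UUID, uuid4


class AsyncRunManager:  # illustrative stand-in for AsyncCallbackManagerForLLMRun
    def __init__(self, run_id: UUID) -> None:
        self.run_id = run_id

    async def on_llm_end(self, flattened_output: List[str]) -> None:
        await asyncio.sleep(0)
        print(f"run {self.run_id} ended with {flattened_output}")


async def finish_runs(
    run_managers: List[AsyncRunManager], per_prompt_outputs: List[List[str]]
) -> None:
    # gather one on_llm_end coroutine per (manager, output) pair
    await asyncio.gather(
        *[
            manager.on_llm_end(flattened_output)
            for manager, flattened_output in zip(run_managers, per_prompt_outputs)
        ]
    )


managers = [AsyncRunManager(uuid4()), AsyncRunManager(uuid4())]
asyncio.run(finish_runs(managers, [["output a"], ["output b"]]))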
@@ -182,9 +182,19 @@ class LLMResult(BaseModel):
     each input could have multiple generations."""
     llm_output: Optional[dict] = None
     """For arbitrary LLM provider specific output."""
-    run: Optional[RunInfo] = None
+    run: Optional[List[RunInfo]] = None
     """Run metadata."""
 
+    def flatten(self) -> List[LLMResult]:
+        """Flatten generations into a single list."""
+        return [
+            LLMResult(
+                generations=[gen],
+                llm_output=self.llm_output,
+            )
+            for gen in self.generations
+        ]
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, LLMResult):
             return NotImplemented
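flatten() is the glue that makes the split possible: a batched result with N inputs becomes N single-input results, each carrying the same llm_output, and run becomes a list with one RunInfo per input. A toy demonstration of those semantics using a plain dataclass instead of the pydantic model:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyLLMResult:  # simplified stand-in for LLMResult
    generations: List[List[str]]  # one inner list per input prompt
    llm_output: Optional[dict] = None

    def flatten(self) -> List["ToyLLMResult"]:
        """Split a batched result into one single-prompt result per input."""
        return [
            ToyLLMResult(generations=[gen], llm_output=self.llm_output)
            for gen in self.generations
        ]


batched = ToyLLMResult(generations=[["hi"], ["hello"]], llm_output={"model": "x"})
per_prompt = batched.flatten()
assert len(per_prompt) == 2
assert per_prompt[0].generations == [["hi"]]
assert per_prompt[1].llm_output == batched.llm_output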