Allow calls to batch() with 0 length arrays (#10627)

This can happen if eg the input to batch is a list generated dynamically, where a 0-length list might be a valid use case
This commit is contained in:
Nuno Campos 2023-09-15 17:37:27 +01:00 committed by GitHub
parent a50e62e44b
commit 029b2f6aac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 38 additions and 2 deletions

View File

@ -267,6 +267,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Any, **kwargs: Any,
) -> List[str]: ) -> List[str]:
if not inputs:
return []
config = get_config_list(config, len(inputs)) config = get_config_list(config, len(inputs))
max_concurrency = config[0].get("max_concurrency") max_concurrency = config[0].get("max_concurrency")
@ -306,6 +309,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Any, **kwargs: Any,
) -> List[str]: ) -> List[str]:
if not inputs:
return []
if type(self)._agenerate == BaseLLM._agenerate: if type(self)._agenerate == BaseLLM._agenerate:
# model doesn't implement async batch, so use default implementation # model doesn't implement async batch, so use default implementation
return await asyncio.get_running_loop().run_in_executor( return await asyncio.get_running_loop().run_in_executor(

View File

@ -114,6 +114,9 @@ class Runnable(Generic[Input, Output], ABC):
Default implementation of batch, which calls invoke N times. Default implementation of batch, which calls invoke N times.
Subclasses should override this method if they can batch more efficiently. Subclasses should override this method if they can batch more efficiently.
""" """
if not inputs:
return []
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]: def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]:
@ -144,6 +147,9 @@ class Runnable(Generic[Input, Output], ABC):
Default implementation of abatch, which calls ainvoke N times. Default implementation of abatch, which calls ainvoke N times.
Subclasses should override this method if they can batch more efficiently. Subclasses should override this method if they can batch more efficiently.
""" """
if not inputs:
return []
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
async def ainvoke( async def ainvoke(
@ -376,6 +382,9 @@ class Runnable(Generic[Input, Output], ABC):
) -> List[Output]: ) -> List[Output]:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.""" with callbacks. Use this method to implement invoke() in subclasses."""
if not input:
return []
configs = get_config_list(config, len(input)) configs = get_config_list(config, len(input))
callback_managers = [get_callback_manager_for_config(c) for c in configs] callback_managers = [get_callback_manager_for_config(c) for c in configs]
run_managers = [ run_managers = [
@ -444,6 +453,9 @@ class Runnable(Generic[Input, Output], ABC):
) -> List[Output]: ) -> List[Output]:
"""Helper method to transform an Input value to an Output value, """Helper method to transform an Input value to an Output value,
with callbacks. Use this method to implement invoke() in subclasses.""" with callbacks. Use this method to implement invoke() in subclasses."""
if not input:
return []
configs = get_config_list(config, len(input)) configs = get_config_list(config, len(input))
callback_managers = [get_async_callback_manager_for_config(c) for c in configs] callback_managers = [get_async_callback_manager_for_config(c) for c in configs]
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather( run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
@ -748,6 +760,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
if return_exceptions: if return_exceptions:
raise NotImplementedError() raise NotImplementedError()
if not inputs:
return []
# setup callbacks # setup callbacks
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
callback_managers = [ callback_managers = [
@ -813,6 +828,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
if return_exceptions: if return_exceptions:
raise NotImplementedError() raise NotImplementedError()
if not inputs:
return []
# setup callbacks # setup callbacks
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
callback_managers = [ callback_managers = [
@ -1004,6 +1022,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
) -> List[Output]: ) -> List[Output]:
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager
if not inputs:
return []
# setup callbacks # setup callbacks
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
callback_managers = [ callback_managers = [
@ -1122,6 +1143,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
AsyncCallbackManager, AsyncCallbackManager,
) )
if not inputs:
return []
# setup callbacks # setup callbacks
configs = get_config_list(config, len(inputs)) configs = get_config_list(config, len(inputs))
callback_managers = [ callback_managers = [

View File

@ -97,8 +97,8 @@ def get_config_list(
Helper method to get a list of configs from a single config or a list of Helper method to get a list of configs from a single config or a list of
configs, useful for subclasses overriding batch() or abatch(). configs, useful for subclasses overriding batch() or abatch().
""" """
if length < 1: if length < 0:
raise ValueError(f"length must be >= 1, but got {length}") raise ValueError(f"length must be >= 0, but got {length}")
if isinstance(config, list) and len(config) != length: if isinstance(config, list) and len(config) != length:
raise ValueError( raise ValueError(
f"config must be a list of the same length as inputs, " f"config must be a list of the same length as inputs, "

View File

@ -129,6 +129,9 @@ class RouterRunnable(
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
if not inputs:
return []
keys = [input["key"] for input in inputs] keys = [input["key"] for input in inputs]
actual_inputs = [input["input"] for input in inputs] actual_inputs = [input["input"] for input in inputs]
if any(key not in self.runnables for key in keys): if any(key not in self.runnables for key in keys):
@ -161,6 +164,9 @@ class RouterRunnable(
return_exceptions: bool = False, return_exceptions: bool = False,
**kwargs: Optional[Any], **kwargs: Optional[Any],
) -> List[Output]: ) -> List[Output]:
if not inputs:
return []
keys = [input["key"] for input in inputs] keys = [input["key"] for input in inputs]
actual_inputs = [input["input"] for input in inputs] actual_inputs = [input["input"] for input in inputs]
if any(key not in self.runnables for key in keys): if any(key not in self.runnables for key in keys):