mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-18 12:58:59 +00:00
Allow calls to batch() with 0 length arrays (#10627)
This can happen if eg the input to batch is a list generated dynamically, where a 0-length list might be a valid use case
This commit is contained in:
parent
a50e62e44b
commit
029b2f6aac
@ -267,6 +267,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
return_exceptions: bool = False,
|
return_exceptions: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
|
if not inputs:
|
||||||
|
return []
|
||||||
|
|
||||||
config = get_config_list(config, len(inputs))
|
config = get_config_list(config, len(inputs))
|
||||||
max_concurrency = config[0].get("max_concurrency")
|
max_concurrency = config[0].get("max_concurrency")
|
||||||
|
|
||||||
@ -306,6 +309,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
|||||||
return_exceptions: bool = False,
|
return_exceptions: bool = False,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
|
if not inputs:
|
||||||
|
return []
|
||||||
|
|
||||||
if type(self)._agenerate == BaseLLM._agenerate:
|
if type(self)._agenerate == BaseLLM._agenerate:
|
||||||
# model doesn't implement async batch, so use default implementation
|
# model doesn't implement async batch, so use default implementation
|
||||||
return await asyncio.get_running_loop().run_in_executor(
|
return await asyncio.get_running_loop().run_in_executor(
|
||||||
|
@ -114,6 +114,9 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
Default implementation of batch, which calls invoke N times.
|
Default implementation of batch, which calls invoke N times.
|
||||||
Subclasses should override this method if they can batch more efficiently.
|
Subclasses should override this method if they can batch more efficiently.
|
||||||
"""
|
"""
|
||||||
|
if not inputs:
|
||||||
|
return []
|
||||||
|
|
||||||
configs = get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
|
|
||||||
def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]:
|
def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]:
|
||||||
@ -144,6 +147,9 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
Default implementation of abatch, which calls ainvoke N times.
|
Default implementation of abatch, which calls ainvoke N times.
|
||||||
Subclasses should override this method if they can batch more efficiently.
|
Subclasses should override this method if they can batch more efficiently.
|
||||||
"""
|
"""
|
||||||
|
if not inputs:
|
||||||
|
return []
|
||||||
|
|
||||||
configs = get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
|
|
||||||
async def ainvoke(
|
async def ainvoke(
|
||||||
@ -376,6 +382,9 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
"""Helper method to transform an Input value to an Output value,
|
"""Helper method to transform an Input value to an Output value,
|
||||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||||
|
if not input:
|
||||||
|
return []
|
||||||
|
|
||||||
configs = get_config_list(config, len(input))
|
configs = get_config_list(config, len(input))
|
||||||
callback_managers = [get_callback_manager_for_config(c) for c in configs]
|
callback_managers = [get_callback_manager_for_config(c) for c in configs]
|
||||||
run_managers = [
|
run_managers = [
|
||||||
@ -444,6 +453,9 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
"""Helper method to transform an Input value to an Output value,
|
"""Helper method to transform an Input value to an Output value,
|
||||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||||
|
if not input:
|
||||||
|
return []
|
||||||
|
|
||||||
configs = get_config_list(config, len(input))
|
configs = get_config_list(config, len(input))
|
||||||
callback_managers = [get_async_callback_manager_for_config(c) for c in configs]
|
callback_managers = [get_async_callback_manager_for_config(c) for c in configs]
|
||||||
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||||
@ -748,6 +760,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
if not inputs:
|
||||||
|
return []
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
configs = get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
callback_managers = [
|
callback_managers = [
|
||||||
@ -813,6 +828,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
|||||||
if return_exceptions:
|
if return_exceptions:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
if not inputs:
|
||||||
|
return []
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
configs = get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
callback_managers = [
|
callback_managers = [
|
||||||
@ -1004,6 +1022,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
from langchain.callbacks.manager import CallbackManager
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
|
if not inputs:
|
||||||
|
return []
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
configs = get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
callback_managers = [
|
callback_managers = [
|
||||||
@ -1122,6 +1143,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
|||||||
AsyncCallbackManager,
|
AsyncCallbackManager,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not inputs:
|
||||||
|
return []
|
||||||
|
|
||||||
# setup callbacks
|
# setup callbacks
|
||||||
configs = get_config_list(config, len(inputs))
|
configs = get_config_list(config, len(inputs))
|
||||||
callback_managers = [
|
callback_managers = [
|
||||||
|
@ -97,8 +97,8 @@ def get_config_list(
|
|||||||
Helper method to get a list of configs from a single config or a list of
|
Helper method to get a list of configs from a single config or a list of
|
||||||
configs, useful for subclasses overriding batch() or abatch().
|
configs, useful for subclasses overriding batch() or abatch().
|
||||||
"""
|
"""
|
||||||
if length < 1:
|
if length < 0:
|
||||||
raise ValueError(f"length must be >= 1, but got {length}")
|
raise ValueError(f"length must be >= 0, but got {length}")
|
||||||
if isinstance(config, list) and len(config) != length:
|
if isinstance(config, list) and len(config) != length:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"config must be a list of the same length as inputs, "
|
f"config must be a list of the same length as inputs, "
|
||||||
|
@ -129,6 +129,9 @@ class RouterRunnable(
|
|||||||
return_exceptions: bool = False,
|
return_exceptions: bool = False,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
|
if not inputs:
|
||||||
|
return []
|
||||||
|
|
||||||
keys = [input["key"] for input in inputs]
|
keys = [input["key"] for input in inputs]
|
||||||
actual_inputs = [input["input"] for input in inputs]
|
actual_inputs = [input["input"] for input in inputs]
|
||||||
if any(key not in self.runnables for key in keys):
|
if any(key not in self.runnables for key in keys):
|
||||||
@ -161,6 +164,9 @@ class RouterRunnable(
|
|||||||
return_exceptions: bool = False,
|
return_exceptions: bool = False,
|
||||||
**kwargs: Optional[Any],
|
**kwargs: Optional[Any],
|
||||||
) -> List[Output]:
|
) -> List[Output]:
|
||||||
|
if not inputs:
|
||||||
|
return []
|
||||||
|
|
||||||
keys = [input["key"] for input in inputs]
|
keys = [input["key"] for input in inputs]
|
||||||
actual_inputs = [input["input"] for input in inputs]
|
actual_inputs = [input["input"] for input in inputs]
|
||||||
if any(key not in self.runnables for key in keys):
|
if any(key not in self.runnables for key in keys):
|
||||||
|
Loading…
Reference in New Issue
Block a user