mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-30 11:39:03 +00:00
Allow calls to batch() with 0 length arrays (#10627)
This can happen if eg the input to batch is a list generated dynamically, where a 0-length list might be a valid use case
This commit is contained in:
parent
a50e62e44b
commit
029b2f6aac
@ -267,6 +267,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
config = get_config_list(config, len(inputs))
|
||||
max_concurrency = config[0].get("max_concurrency")
|
||||
|
||||
@ -306,6 +309,9 @@ class BaseLLM(BaseLanguageModel[str], ABC):
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> List[str]:
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
if type(self)._agenerate == BaseLLM._agenerate:
|
||||
# model doesn't implement async batch, so use default implementation
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
|
@ -114,6 +114,9 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
Default implementation of batch, which calls invoke N times.
|
||||
Subclasses should override this method if they can batch more efficiently.
|
||||
"""
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
configs = get_config_list(config, len(inputs))
|
||||
|
||||
def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]:
|
||||
@ -144,6 +147,9 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
Default implementation of abatch, which calls ainvoke N times.
|
||||
Subclasses should override this method if they can batch more efficiently.
|
||||
"""
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
configs = get_config_list(config, len(inputs))
|
||||
|
||||
async def ainvoke(
|
||||
@ -376,6 +382,9 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
) -> List[Output]:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||
if not input:
|
||||
return []
|
||||
|
||||
configs = get_config_list(config, len(input))
|
||||
callback_managers = [get_callback_manager_for_config(c) for c in configs]
|
||||
run_managers = [
|
||||
@ -444,6 +453,9 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
) -> List[Output]:
|
||||
"""Helper method to transform an Input value to an Output value,
|
||||
with callbacks. Use this method to implement invoke() in subclasses."""
|
||||
if not input:
|
||||
return []
|
||||
|
||||
configs = get_config_list(config, len(input))
|
||||
callback_managers = [get_async_callback_manager_for_config(c) for c in configs]
|
||||
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
|
||||
@ -748,6 +760,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
if return_exceptions:
|
||||
raise NotImplementedError()
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
# setup callbacks
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
@ -813,6 +828,9 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
|
||||
if return_exceptions:
|
||||
raise NotImplementedError()
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
# setup callbacks
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
@ -1004,6 +1022,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
) -> List[Output]:
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
# setup callbacks
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
@ -1122,6 +1143,9 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
|
||||
AsyncCallbackManager,
|
||||
)
|
||||
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
# setup callbacks
|
||||
configs = get_config_list(config, len(inputs))
|
||||
callback_managers = [
|
||||
|
@ -97,8 +97,8 @@ def get_config_list(
|
||||
Helper method to get a list of configs from a single config or a list of
|
||||
configs, useful for subclasses overriding batch() or abatch().
|
||||
"""
|
||||
if length < 1:
|
||||
raise ValueError(f"length must be >= 1, but got {length}")
|
||||
if length < 0:
|
||||
raise ValueError(f"length must be >= 0, but got {length}")
|
||||
if isinstance(config, list) and len(config) != length:
|
||||
raise ValueError(
|
||||
f"config must be a list of the same length as inputs, "
|
||||
|
@ -129,6 +129,9 @@ class RouterRunnable(
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
keys = [input["key"] for input in inputs]
|
||||
actual_inputs = [input["input"] for input in inputs]
|
||||
if any(key not in self.runnables for key in keys):
|
||||
@ -161,6 +164,9 @@ class RouterRunnable(
|
||||
return_exceptions: bool = False,
|
||||
**kwargs: Optional[Any],
|
||||
) -> List[Output]:
|
||||
if not inputs:
|
||||
return []
|
||||
|
||||
keys = [input["key"] for input in inputs]
|
||||
actual_inputs = [input["input"] for input in inputs]
|
||||
if any(key not in self.runnables for key in keys):
|
||||
|
Loading…
Reference in New Issue
Block a user