mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-19 00:58:32 +00:00
Make it easier to subclass runnable binding with custom init args (#13189)
This commit is contained in:
@@ -1,9 +1,9 @@
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from langchain.schema.runnable.base import Input, Output, RunnableBinding
|
from langchain.schema.runnable.base import Input, Output, RunnableBindingBase
|
||||||
|
|
||||||
|
|
||||||
class HubRunnable(RunnableBinding[Input, Output]):
|
class HubRunnable(RunnableBindingBase[Input, Output]):
|
||||||
"""
|
"""
|
||||||
An instance of a runnable stored in the LangChain Hub.
|
An instance of a runnable stored in the LangChain Hub.
|
||||||
"""
|
"""
|
||||||
|
@@ -5,7 +5,8 @@ from typing_extensions import TypedDict
|
|||||||
|
|
||||||
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||||
from langchain.schema.messages import BaseMessage
|
from langchain.schema.messages import BaseMessage
|
||||||
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
|
from langchain.schema.runnable import RouterRunnable, Runnable
|
||||||
|
from langchain.schema.runnable.base import RunnableBindingBase
|
||||||
|
|
||||||
|
|
||||||
class OpenAIFunction(TypedDict):
|
class OpenAIFunction(TypedDict):
|
||||||
@@ -19,7 +20,7 @@ class OpenAIFunction(TypedDict):
|
|||||||
"""The parameters to the function."""
|
"""The parameters to the function."""
|
||||||
|
|
||||||
|
|
||||||
class OpenAIFunctionsRouter(RunnableBinding[BaseMessage, Any]):
|
class OpenAIFunctionsRouter(RunnableBindingBase[BaseMessage, Any]):
|
||||||
"""A runnable that routes to the selected function."""
|
"""A runnable that routes to the selected function."""
|
||||||
|
|
||||||
functions: Optional[List[OpenAIFunction]]
|
functions: Optional[List[OpenAIFunction]]
|
||||||
|
@@ -2581,11 +2581,6 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
|||||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||||
return self.bound.config_specs
|
return self.bound.config_specs
|
||||||
|
|
||||||
def config_schema(
|
|
||||||
self, *, include: Optional[Sequence[str]] = None
|
|
||||||
) -> Type[BaseModel]:
|
|
||||||
return self.bound.config_schema(include=include)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
@@ -2659,7 +2654,7 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
|||||||
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
|
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class RunnableBinding(RunnableSerializable[Input, Output]):
|
class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||||
"""
|
"""
|
||||||
A runnable that delegates calls to another runnable with a set of kwargs.
|
A runnable that delegates calls to another runnable with a set of kwargs.
|
||||||
"""
|
"""
|
||||||
@@ -2749,11 +2744,6 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||||
return self.bound.config_specs
|
return self.bound.config_specs
|
||||||
|
|
||||||
def config_schema(
|
|
||||||
self, *, include: Optional[Sequence[str]] = None
|
|
||||||
) -> Type[BaseModel]:
|
|
||||||
return self.bound.config_schema(include=include)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
@@ -2762,93 +2752,6 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
return cls.__module__.split(".")[:-1]
|
return cls.__module__.split(".")[:-1]
|
||||||
|
|
||||||
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
|
||||||
return self.__class__(
|
|
||||||
bound=self.bound,
|
|
||||||
config=self.config,
|
|
||||||
kwargs={**self.kwargs, **kwargs},
|
|
||||||
custom_input_type=self.custom_input_type,
|
|
||||||
custom_output_type=self.custom_output_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
def with_config(
|
|
||||||
self,
|
|
||||||
config: Optional[RunnableConfig] = None,
|
|
||||||
# Sadly Unpack is not well supported by mypy so this will have to be untyped
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Runnable[Input, Output]:
|
|
||||||
return self.__class__(
|
|
||||||
bound=self.bound,
|
|
||||||
kwargs=self.kwargs,
|
|
||||||
config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}),
|
|
||||||
custom_input_type=self.custom_input_type,
|
|
||||||
custom_output_type=self.custom_output_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
def with_listeners(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
on_start: Optional[Listener] = None,
|
|
||||||
on_end: Optional[Listener] = None,
|
|
||||||
on_error: Optional[Listener] = None,
|
|
||||||
) -> Runnable[Input, Output]:
|
|
||||||
"""
|
|
||||||
Bind lifecycle listeners to a Runnable, returning a new Runnable.
|
|
||||||
|
|
||||||
on_start: Called before the runnable starts running, with the Run object.
|
|
||||||
on_end: Called after the runnable finishes running, with the Run object.
|
|
||||||
on_error: Called if the runnable throws an error, with the Run object.
|
|
||||||
|
|
||||||
The Run object contains information about the run, including its id,
|
|
||||||
type, input, output, error, start_time, end_time, and any tags or metadata
|
|
||||||
added to the run.
|
|
||||||
"""
|
|
||||||
from langchain.callbacks.tracers.root_listeners import RootListenersTracer
|
|
||||||
|
|
||||||
return self.__class__(
|
|
||||||
bound=self.bound,
|
|
||||||
kwargs=self.kwargs,
|
|
||||||
config=self.config,
|
|
||||||
config_factories=[
|
|
||||||
lambda config: {
|
|
||||||
"callbacks": [
|
|
||||||
RootListenersTracer(
|
|
||||||
config=config,
|
|
||||||
on_start=on_start,
|
|
||||||
on_end=on_end,
|
|
||||||
on_error=on_error,
|
|
||||||
)
|
|
||||||
],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
custom_input_type=self.custom_input_type,
|
|
||||||
custom_output_type=self.custom_output_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
def with_types(
|
|
||||||
self,
|
|
||||||
input_type: Optional[Union[Type[Input], BaseModel]] = None,
|
|
||||||
output_type: Optional[Union[Type[Output], BaseModel]] = None,
|
|
||||||
) -> Runnable[Input, Output]:
|
|
||||||
return self.__class__(
|
|
||||||
bound=self.bound,
|
|
||||||
kwargs=self.kwargs,
|
|
||||||
config=self.config,
|
|
||||||
custom_input_type=input_type
|
|
||||||
if input_type is not None
|
|
||||||
else self.custom_input_type,
|
|
||||||
custom_output_type=output_type
|
|
||||||
if output_type is not None
|
|
||||||
else self.custom_output_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
|
|
||||||
return self.__class__(
|
|
||||||
bound=self.bound.with_retry(**kwargs),
|
|
||||||
kwargs=self.kwargs,
|
|
||||||
config=self.config,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
|
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||||
config = merge_configs(self.config, *configs)
|
config = merge_configs(self.config, *configs)
|
||||||
return merge_configs(config, *(f(config) for f in self.config_factories))
|
return merge_configs(config, *(f(config) for f in self.config_factories))
|
||||||
@@ -2972,7 +2875,97 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
|||||||
yield item
|
yield item
|
||||||
|
|
||||||
|
|
||||||
RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig)
|
RunnableBindingBase.update_forward_refs(RunnableConfig=RunnableConfig)
|
||||||
|
|
||||||
|
|
||||||
|
class RunnableBinding(RunnableBindingBase[Input, Output]):
|
||||||
|
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||||
|
return self.__class__(
|
||||||
|
bound=self.bound,
|
||||||
|
config=self.config,
|
||||||
|
kwargs={**self.kwargs, **kwargs},
|
||||||
|
custom_input_type=self.custom_input_type,
|
||||||
|
custom_output_type=self.custom_output_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def with_config(
|
||||||
|
self,
|
||||||
|
config: Optional[RunnableConfig] = None,
|
||||||
|
# Sadly Unpack is not well supported by mypy so this will have to be untyped
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[Input, Output]:
|
||||||
|
return self.__class__(
|
||||||
|
bound=self.bound,
|
||||||
|
kwargs=self.kwargs,
|
||||||
|
config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}),
|
||||||
|
custom_input_type=self.custom_input_type,
|
||||||
|
custom_output_type=self.custom_output_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def with_listeners(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
on_start: Optional[Listener] = None,
|
||||||
|
on_end: Optional[Listener] = None,
|
||||||
|
on_error: Optional[Listener] = None,
|
||||||
|
) -> Runnable[Input, Output]:
|
||||||
|
"""
|
||||||
|
Bind lifecycle listeners to a Runnable, returning a new Runnable.
|
||||||
|
|
||||||
|
on_start: Called before the runnable starts running, with the Run object.
|
||||||
|
on_end: Called after the runnable finishes running, with the Run object.
|
||||||
|
on_error: Called if the runnable throws an error, with the Run object.
|
||||||
|
|
||||||
|
The Run object contains information about the run, including its id,
|
||||||
|
type, input, output, error, start_time, end_time, and any tags or metadata
|
||||||
|
added to the run.
|
||||||
|
"""
|
||||||
|
from langchain.callbacks.tracers.root_listeners import RootListenersTracer
|
||||||
|
|
||||||
|
return self.__class__(
|
||||||
|
bound=self.bound,
|
||||||
|
kwargs=self.kwargs,
|
||||||
|
config=self.config,
|
||||||
|
config_factories=[
|
||||||
|
lambda config: {
|
||||||
|
"callbacks": [
|
||||||
|
RootListenersTracer(
|
||||||
|
config=config,
|
||||||
|
on_start=on_start,
|
||||||
|
on_end=on_end,
|
||||||
|
on_error=on_error,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
custom_input_type=self.custom_input_type,
|
||||||
|
custom_output_type=self.custom_output_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def with_types(
|
||||||
|
self,
|
||||||
|
input_type: Optional[Union[Type[Input], BaseModel]] = None,
|
||||||
|
output_type: Optional[Union[Type[Output], BaseModel]] = None,
|
||||||
|
) -> Runnable[Input, Output]:
|
||||||
|
return self.__class__(
|
||||||
|
bound=self.bound,
|
||||||
|
kwargs=self.kwargs,
|
||||||
|
config=self.config,
|
||||||
|
custom_input_type=input_type
|
||||||
|
if input_type is not None
|
||||||
|
else self.custom_input_type,
|
||||||
|
custom_output_type=output_type
|
||||||
|
if output_type is not None
|
||||||
|
else self.custom_output_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||||
|
return self.__class__(
|
||||||
|
bound=self.bound.with_retry(**kwargs),
|
||||||
|
kwargs=self.kwargs,
|
||||||
|
config=self.config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
RunnableLike = Union[
|
RunnableLike = Union[
|
||||||
Runnable[Input, Output],
|
Runnable[Input, Output],
|
||||||
|
@@ -119,11 +119,6 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
for spec in step.config_specs
|
for spec in step.config_specs
|
||||||
)
|
)
|
||||||
|
|
||||||
def config_schema(
|
|
||||||
self, *, include: Optional[Sequence[str]] = None
|
|
||||||
) -> Type[BaseModel]:
|
|
||||||
return self.runnable.config_schema(include=include)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_lc_serializable(cls) -> bool:
|
def is_lc_serializable(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
|
@@ -21,7 +21,7 @@ from tenacity import (
|
|||||||
wait_exponential_jitter,
|
wait_exponential_jitter,
|
||||||
)
|
)
|
||||||
|
|
||||||
from langchain.schema.runnable.base import Input, Output, RunnableBinding
|
from langchain.schema.runnable.base import Input, Output, RunnableBindingBase
|
||||||
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
|
|||||||
U = TypeVar("U")
|
U = TypeVar("U")
|
||||||
|
|
||||||
|
|
||||||
class RunnableRetry(RunnableBinding[Input, Output]):
|
class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||||
"""Retry a Runnable if it fails.
|
"""Retry a Runnable if it fails.
|
||||||
|
|
||||||
A RunnableRetry helps can be used to add retry logic to any object
|
A RunnableRetry helps can be used to add retry logic to any object
|
||||||
|
Reference in New Issue
Block a user