mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-18 16:16:33 +00:00
Make it easier to subclass runnable binding with custom init args (#13189)
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.schema.runnable.base import Input, Output, RunnableBinding
|
||||
from langchain.schema.runnable.base import Input, Output, RunnableBindingBase
|
||||
|
||||
|
||||
class HubRunnable(RunnableBinding[Input, Output]):
|
||||
class HubRunnable(RunnableBindingBase[Input, Output]):
|
||||
"""
|
||||
An instance of a runnable stored in the LangChain Hub.
|
||||
"""
|
||||
|
@@ -5,7 +5,8 @@ from typing_extensions import TypedDict
|
||||
|
||||
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.runnable import RouterRunnable, Runnable, RunnableBinding
|
||||
from langchain.schema.runnable import RouterRunnable, Runnable
|
||||
from langchain.schema.runnable.base import RunnableBindingBase
|
||||
|
||||
|
||||
class OpenAIFunction(TypedDict):
|
||||
@@ -19,7 +20,7 @@ class OpenAIFunction(TypedDict):
|
||||
"""The parameters to the function."""
|
||||
|
||||
|
||||
class OpenAIFunctionsRouter(RunnableBinding[BaseMessage, Any]):
|
||||
class OpenAIFunctionsRouter(RunnableBindingBase[BaseMessage, Any]):
|
||||
"""A runnable that routes to the selected function."""
|
||||
|
||||
functions: Optional[List[OpenAIFunction]]
|
||||
|
@@ -2581,11 +2581,6 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
return self.bound.config_specs
|
||||
|
||||
def config_schema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Type[BaseModel]:
|
||||
return self.bound.config_schema(include=include)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
@@ -2659,7 +2654,7 @@ class RunnableEach(RunnableSerializable[List[Input], List[Output]]):
|
||||
return await self._acall_with_config(self._ainvoke, input, config, **kwargs)
|
||||
|
||||
|
||||
class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
"""
|
||||
A runnable that delegates calls to another runnable with a set of kwargs.
|
||||
"""
|
||||
@@ -2749,11 +2744,6 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
|
||||
return self.bound.config_specs
|
||||
|
||||
def config_schema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Type[BaseModel]:
|
||||
return self.bound.config_schema(include=include)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
@@ -2762,93 +2752,6 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
return cls.__module__.split(".")[:-1]
|
||||
|
||||
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||
return self.__class__(
|
||||
bound=self.bound,
|
||||
config=self.config,
|
||||
kwargs={**self.kwargs, **kwargs},
|
||||
custom_input_type=self.custom_input_type,
|
||||
custom_output_type=self.custom_output_type,
|
||||
)
|
||||
|
||||
def with_config(
|
||||
self,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
# Sadly Unpack is not well supported by mypy so this will have to be untyped
|
||||
**kwargs: Any,
|
||||
) -> Runnable[Input, Output]:
|
||||
return self.__class__(
|
||||
bound=self.bound,
|
||||
kwargs=self.kwargs,
|
||||
config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}),
|
||||
custom_input_type=self.custom_input_type,
|
||||
custom_output_type=self.custom_output_type,
|
||||
)
|
||||
|
||||
def with_listeners(
|
||||
self,
|
||||
*,
|
||||
on_start: Optional[Listener] = None,
|
||||
on_end: Optional[Listener] = None,
|
||||
on_error: Optional[Listener] = None,
|
||||
) -> Runnable[Input, Output]:
|
||||
"""
|
||||
Bind lifecycle listeners to a Runnable, returning a new Runnable.
|
||||
|
||||
on_start: Called before the runnable starts running, with the Run object.
|
||||
on_end: Called after the runnable finishes running, with the Run object.
|
||||
on_error: Called if the runnable throws an error, with the Run object.
|
||||
|
||||
The Run object contains information about the run, including its id,
|
||||
type, input, output, error, start_time, end_time, and any tags or metadata
|
||||
added to the run.
|
||||
"""
|
||||
from langchain.callbacks.tracers.root_listeners import RootListenersTracer
|
||||
|
||||
return self.__class__(
|
||||
bound=self.bound,
|
||||
kwargs=self.kwargs,
|
||||
config=self.config,
|
||||
config_factories=[
|
||||
lambda config: {
|
||||
"callbacks": [
|
||||
RootListenersTracer(
|
||||
config=config,
|
||||
on_start=on_start,
|
||||
on_end=on_end,
|
||||
on_error=on_error,
|
||||
)
|
||||
],
|
||||
}
|
||||
],
|
||||
custom_input_type=self.custom_input_type,
|
||||
custom_output_type=self.custom_output_type,
|
||||
)
|
||||
|
||||
def with_types(
|
||||
self,
|
||||
input_type: Optional[Union[Type[Input], BaseModel]] = None,
|
||||
output_type: Optional[Union[Type[Output], BaseModel]] = None,
|
||||
) -> Runnable[Input, Output]:
|
||||
return self.__class__(
|
||||
bound=self.bound,
|
||||
kwargs=self.kwargs,
|
||||
config=self.config,
|
||||
custom_input_type=input_type
|
||||
if input_type is not None
|
||||
else self.custom_input_type,
|
||||
custom_output_type=output_type
|
||||
if output_type is not None
|
||||
else self.custom_output_type,
|
||||
)
|
||||
|
||||
def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||
return self.__class__(
|
||||
bound=self.bound.with_retry(**kwargs),
|
||||
kwargs=self.kwargs,
|
||||
config=self.config,
|
||||
)
|
||||
|
||||
def _merge_configs(self, *configs: Optional[RunnableConfig]) -> RunnableConfig:
|
||||
config = merge_configs(self.config, *configs)
|
||||
return merge_configs(config, *(f(config) for f in self.config_factories))
|
||||
@@ -2972,7 +2875,97 @@ class RunnableBinding(RunnableSerializable[Input, Output]):
|
||||
yield item
|
||||
|
||||
|
||||
RunnableBinding.update_forward_refs(RunnableConfig=RunnableConfig)
|
||||
RunnableBindingBase.update_forward_refs(RunnableConfig=RunnableConfig)
|
||||
|
||||
|
||||
class RunnableBinding(RunnableBindingBase[Input, Output]):
|
||||
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||
return self.__class__(
|
||||
bound=self.bound,
|
||||
config=self.config,
|
||||
kwargs={**self.kwargs, **kwargs},
|
||||
custom_input_type=self.custom_input_type,
|
||||
custom_output_type=self.custom_output_type,
|
||||
)
|
||||
|
||||
def with_config(
|
||||
self,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
# Sadly Unpack is not well supported by mypy so this will have to be untyped
|
||||
**kwargs: Any,
|
||||
) -> Runnable[Input, Output]:
|
||||
return self.__class__(
|
||||
bound=self.bound,
|
||||
kwargs=self.kwargs,
|
||||
config=cast(RunnableConfig, {**self.config, **(config or {}), **kwargs}),
|
||||
custom_input_type=self.custom_input_type,
|
||||
custom_output_type=self.custom_output_type,
|
||||
)
|
||||
|
||||
def with_listeners(
|
||||
self,
|
||||
*,
|
||||
on_start: Optional[Listener] = None,
|
||||
on_end: Optional[Listener] = None,
|
||||
on_error: Optional[Listener] = None,
|
||||
) -> Runnable[Input, Output]:
|
||||
"""
|
||||
Bind lifecycle listeners to a Runnable, returning a new Runnable.
|
||||
|
||||
on_start: Called before the runnable starts running, with the Run object.
|
||||
on_end: Called after the runnable finishes running, with the Run object.
|
||||
on_error: Called if the runnable throws an error, with the Run object.
|
||||
|
||||
The Run object contains information about the run, including its id,
|
||||
type, input, output, error, start_time, end_time, and any tags or metadata
|
||||
added to the run.
|
||||
"""
|
||||
from langchain.callbacks.tracers.root_listeners import RootListenersTracer
|
||||
|
||||
return self.__class__(
|
||||
bound=self.bound,
|
||||
kwargs=self.kwargs,
|
||||
config=self.config,
|
||||
config_factories=[
|
||||
lambda config: {
|
||||
"callbacks": [
|
||||
RootListenersTracer(
|
||||
config=config,
|
||||
on_start=on_start,
|
||||
on_end=on_end,
|
||||
on_error=on_error,
|
||||
)
|
||||
],
|
||||
}
|
||||
],
|
||||
custom_input_type=self.custom_input_type,
|
||||
custom_output_type=self.custom_output_type,
|
||||
)
|
||||
|
||||
def with_types(
|
||||
self,
|
||||
input_type: Optional[Union[Type[Input], BaseModel]] = None,
|
||||
output_type: Optional[Union[Type[Output], BaseModel]] = None,
|
||||
) -> Runnable[Input, Output]:
|
||||
return self.__class__(
|
||||
bound=self.bound,
|
||||
kwargs=self.kwargs,
|
||||
config=self.config,
|
||||
custom_input_type=input_type
|
||||
if input_type is not None
|
||||
else self.custom_input_type,
|
||||
custom_output_type=output_type
|
||||
if output_type is not None
|
||||
else self.custom_output_type,
|
||||
)
|
||||
|
||||
def with_retry(self, **kwargs: Any) -> Runnable[Input, Output]:
|
||||
return self.__class__(
|
||||
bound=self.bound.with_retry(**kwargs),
|
||||
kwargs=self.kwargs,
|
||||
config=self.config,
|
||||
)
|
||||
|
||||
|
||||
RunnableLike = Union[
|
||||
Runnable[Input, Output],
|
||||
|
@@ -119,11 +119,6 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
for spec in step.config_specs
|
||||
)
|
||||
|
||||
def config_schema(
|
||||
self, *, include: Optional[Sequence[str]] = None
|
||||
) -> Type[BaseModel]:
|
||||
return self.runnable.config_schema(include=include)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
return True
|
||||
|
@@ -21,7 +21,7 @@ from tenacity import (
|
||||
wait_exponential_jitter,
|
||||
)
|
||||
|
||||
from langchain.schema.runnable.base import Input, Output, RunnableBinding
|
||||
from langchain.schema.runnable.base import Input, Output, RunnableBindingBase
|
||||
from langchain.schema.runnable.config import RunnableConfig, patch_config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
|
||||
U = TypeVar("U")
|
||||
|
||||
|
||||
class RunnableRetry(RunnableBinding[Input, Output]):
|
||||
class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
"""Retry a Runnable if it fails.
|
||||
|
||||
A RunnableRetry helps can be used to add retry logic to any object
|
||||
|
Reference in New Issue
Block a user