Compare commits

...

3 Commits

Author SHA1 Message Date
Bagatur
2eead4cd5d fix 2024-01-29 16:03:41 -08:00
Bagatur
45fc51ea5f fmt 2024-01-29 16:02:28 -08:00
Bagatur
1bbac10852 rfc: bind collision resolution 2024-01-29 16:00:19 -08:00
2 changed files with 44 additions and 19 deletions

View File

@@ -3967,6 +3967,8 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
f"Configurable key '{key}' not found in runnable with"
f" config keys: {allowed_keys}"
)
if kwargs:
kwargs.pop("__on_collision__", None)
super().__init__(
bound=bound,
kwargs=kwargs or {},
@@ -3977,6 +3979,47 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
**other_kwargs,
)
def bind(
self,
*,
__on_collision__: Literal["overwrite", "add"] = "overwrite",
**kwargs: Any,
) -> Runnable[Input, Output]:
"""Bind additional kwargs to a Runnable, returning a new Runnable.
Args:
__on_collision__: How to handle overlapping keys in self.kwargs and kwargs.
If "overwrite", then for overlapping keys the value from kwargs
overwrites self.kwargs. If "add" then addition operation is attempted
between self.kwargs[key] + kwargs[key].
**kwargs: The kwargs to bind to the Runnable.
Returns:
A new Runnable with the same type and config as the original,
but with the additional kwargs bound.
"""
if __on_collision__ == "overwrite":
_kwargs = {**self.kwargs, **kwargs}
elif __on_collision__ == "add":
_kwargs = {**self.kwargs}
for k, v in kwargs.items():
if k in _kwargs:
_kwargs[k] = _kwargs[k] + v
else:
_kwargs[k] = v
else:
raise ValueError(
f"Unknown value for argument {__on_collision__=}. Expected one of "
f"'overwrite', 'add'."
)
return self.__class__(
bound=self.bound,
config=self.config,
kwargs=_kwargs,
custom_input_type=self.custom_input_type,
custom_output_type=self.custom_output_type,
)
def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
) -> str:
@@ -4218,24 +4261,6 @@ class RunnableBinding(RunnableBindingBase[Input, Output]):
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "runnable"]
def bind(self, **kwargs: Any) -> Runnable[Input, Output]:
"""Bind additional kwargs to a Runnable, returning a new Runnable.
Args:
**kwargs: The kwargs to bind to the Runnable.
Returns:
A new Runnable with the same type and config as the original,
but with the additional kwargs bound.
"""
return self.__class__(
bound=self.bound,
config=self.config,
kwargs={**self.kwargs, **kwargs},
custom_input_type=self.custom_input_type,
custom_output_type=self.custom_output_type,
)
def with_config(
self,
config: Optional[RunnableConfig] = None,

View File

@@ -100,7 +100,7 @@ def create_react_agent(
tools=render_text_description(list(tools)),
tool_names=", ".join([t.name for t in tools]),
)
llm_with_stop = llm.bind(stop=["\nObservation"])
llm_with_stop = llm.bind(stop=["\nObservation"], __on_collision__="add")
agent = (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_log_to_str(x["intermediate_steps"]),