core[patch]: allow access RunnableWithFallbacks.runnable attrs (#22139)

RFC, candidate fix for #13095 #22134
Bagatur 2024-05-28 13:18:09 -07:00 committed by GitHub
parent 7496fe2b16
commit d61bdeba25
2 changed files with 182 additions and 2 deletions
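
As context for the change, a minimal sketch of the behavior being added (model classes and names follow the new docstring example below; this sketch is not part of the diff):

    from langchain_openai import ChatOpenAI
    from langchain_anthropic import ChatAnthropic

    llm = ChatOpenAI(model="gpt-4o").with_fallbacks(
        [ChatAnthropic(model="claude-3-sonnet-20240229")]
    )

    # Previously this raised AttributeError, since RunnableWithFallbacks
    # defined no __getattr__; with this patch it is proxied to the
    # primary runnable.
    llm.model_name  # -> "gpt-4o"

    # Methods annotated as returning a Runnable (e.g. bind_tools) are
    # applied to the primary runnable and every fallback, then rewrapped.
    llm.bind_tools([...])  # -> RunnableWithFallbacks(...)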

libs/core/langchain_core/runnables/fallbacks.py

@@ -1,4 +1,7 @@
import asyncio
import inspect
import typing
from functools import wraps
from typing import (
    TYPE_CHECKING,
    Any,
@@ -549,3 +552,77 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
await run_manager.on_chain_error(e)
raise e
await run_manager.on_chain_end(output)

    def __getattr__(self, name: str) -> Any:
        """Get an attribute from the wrapped runnable and its fallbacks.

        Returns:
            If the attribute is anything other than a method that outputs a Runnable,
            returns getattr(self.runnable, name). If the attribute is a method that
            does return a new Runnable (e.g. llm.bind_tools([...]) outputs a new
            RunnableBinding) then self.runnable and each of the runnables in
            self.fallbacks is replaced with getattr(x, name).

        Example:
            .. code-block:: python

                from langchain_openai import ChatOpenAI
                from langchain_anthropic import ChatAnthropic

                gpt_4o = ChatOpenAI(model="gpt-4o")
                claude_3_sonnet = ChatAnthropic(model="claude-3-sonnet-20240229")
                llm = gpt_4o.with_fallbacks([claude_3_sonnet])

                llm.model_name
                # -> "gpt-4o"

                # .bind_tools() is called on both ChatOpenAI and ChatAnthropic
                # Equivalent to:
                # gpt_4o.bind_tools([...]).with_fallbacks([claude_3_sonnet.bind_tools([...])])

                llm.bind_tools([...])
                # -> RunnableWithFallbacks(
                #     runnable=RunnableBinding(bound=ChatOpenAI(...), kwargs={"tools": [...]}),
                #     fallbacks=[RunnableBinding(bound=ChatAnthropic(...), kwargs={"tools": [...]})],
                # )
        """  # noqa: E501
        attr = getattr(self.runnable, name)
        if _returns_runnable(attr):

            @wraps(attr)
            def wrapped(*args: Any, **kwargs: Any) -> Any:
                # Apply the method to the primary runnable...
                new_runnable = attr(*args, **kwargs)
                # ...and to every fallback, so each gets e.g. the same
                # tools bound.
                new_fallbacks = []
                for fallback in self.fallbacks:
                    fallback_attr = getattr(fallback, name)
                    new_fallbacks.append(fallback_attr(*args, **kwargs))
                # Rewrap the results so fallback behavior is preserved.
                return self.__class__(
                    **{
                        **self.dict(),
                        **{"runnable": new_runnable, "fallbacks": new_fallbacks},
                    }
                )

            return wrapped
        return attr


def _returns_runnable(attr: Any) -> bool:
    if not callable(attr):
        return False
    return_type = typing.get_type_hints(attr).get("return")
    return bool(return_type and _is_runnable_type(return_type))


def _is_runnable_type(type_: Any) -> bool:
    if inspect.isclass(type_):
        return issubclass(type_, Runnable)
    origin = getattr(type_, "__origin__", None)
    if inspect.isclass(origin):
        return issubclass(origin, Runnable)
    elif origin is typing.Union:
        return all(_is_runnable_type(t) for t in type_.__args__)
    else:
        return False
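
A quick note on the helpers (not part of the diff): _returns_runnable inspects the callable's return annotation via typing.get_type_hints, and _is_runnable_type resolves plain classes, parameterized generics, and unions. A sketch of the expected results, assuming the usual imports from langchain_core.runnables:

    from typing import Optional, Union
    from langchain_core.runnables import Runnable, RunnableBinding, RunnableSequence

    _is_runnable_type(RunnableBinding)  # True: a Runnable subclass
    _is_runnable_type(Runnable[str, str])  # True: __origin__ is Runnable
    _is_runnable_type(Union[RunnableBinding, RunnableSequence])  # True: all members qualify
    _is_runnable_type(Optional[Runnable])  # False: NoneType is not a Runnable
    _is_runnable_type(int)  # False: not a Runnable at all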

libs/core/tests/unit_tests/runnables/test_fallbacks.py

@@ -1,20 +1,41 @@
import sys
from typing import Any, AsyncIterator, Iterator
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Sequence,
    Type,
    Union,
)
import pytest
from syrupy import SnapshotAssertion
from langchain_core.language_models import FakeListLLM
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import (
    BaseChatModel,
    FakeListLLM,
    LanguageModelInput,
)
from langchain_core.load import dumps
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from langchain_core.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import (
    Runnable,
    RunnableBinding,
    RunnableGenerator,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
    RunnableWithFallbacks,
)
from langchain_core.tools import BaseTool


@pytest.fixture()
@@ -288,3 +309,85 @@ async def test_fallbacks_astream() -> None:
    )
    async for c in runnable.astream({}):
        pass


class FakeStructuredOutputModel(BaseChatModel):
    foo: int

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Top Level call"""
        return ChatResult(generations=[])

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        return self.bind(tools=tools)

    def with_structured_output(
        self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        return self | (lambda x: {"foo": self.foo})

    @property
    def _llm_type(self) -> str:
        return "fake1"


class FakeModel(BaseChatModel):
    bar: int

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Top Level call"""
        return ChatResult(generations=[])

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        return self.bind(tools=tools)

    @property
    def _llm_type(self) -> str:
        return "fake2"


def test_fallbacks_getattr() -> None:
    llm_with_fallbacks = FakeStructuredOutputModel(foo=3).with_fallbacks(
        [FakeModel(bar=4)]
    )
    # Plain attributes are proxied to the primary runnable only.
    assert llm_with_fallbacks.foo == 3
    with pytest.raises(AttributeError):
        llm_with_fallbacks.bar


def test_fallbacks_getattr_runnable_output() -> None:
    llm_with_fallbacks = FakeStructuredOutputModel(foo=3).with_fallbacks(
        [FakeModel(bar=4)]
    )
    # bind_tools returns a Runnable, so it is applied to the primary
    # runnable and every fallback, then rewrapped.
    llm_with_fallbacks_with_tools = llm_with_fallbacks.bind_tools([])
    assert isinstance(llm_with_fallbacks_with_tools, RunnableWithFallbacks)
    assert isinstance(llm_with_fallbacks_with_tools.runnable, RunnableBinding)
    assert all(
        isinstance(fallback, RunnableBinding)
        for fallback in llm_with_fallbacks_with_tools.fallbacks
    )
    assert llm_with_fallbacks_with_tools.runnable.kwargs["tools"] == []
    # The FakeModel fallback does not implement with_structured_output,
    # so the wrapper surfaces its NotImplementedError.
    with pytest.raises(NotImplementedError):
        llm_with_fallbacks.with_structured_output({})
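
A closing note (not part of the diff): the last assertion raises because the FakeModel fallback inherits the base language model's default with_structured_output, which raises NotImplementedError when the wrapper applies the method to every fallback. The rewrap behavior asserted above is, per the new docstring, equivalent to binding first and wrapping after; a sketch using the fakes from these tests:

    primary = FakeStructuredOutputModel(foo=3)
    fallback = FakeModel(bar=4)

    # Behaviorally equivalent ways to bind tools on primary and fallback:
    a = primary.with_fallbacks([fallback]).bind_tools([])
    b = primary.bind_tools([]).with_fallbacks([fallback.bind_tools([])])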