diff --git a/libs/core/langchain_core/runnables/fallbacks.py b/libs/core/langchain_core/runnables/fallbacks.py index 494c27a4bcd..73e21593c26 100644 --- a/libs/core/langchain_core/runnables/fallbacks.py +++ b/libs/core/langchain_core/runnables/fallbacks.py @@ -1,4 +1,7 @@ import asyncio +import inspect +import typing +from functools import wraps from typing import ( TYPE_CHECKING, Any, @@ -549,3 +552,77 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]): await run_manager.on_chain_error(e) raise e await run_manager.on_chain_end(output) + + def __getattr__(self, name: str) -> Any: + """Get an attribute from the wrapped runnable and its fallbacks. + + Returns: + If the attribute is anything other than a method that outputs a Runnable, + returns getattr(self.runnable, name). If the attribute is a method that + does return a new Runnable (e.g. llm.bind_tools([...]) outputs a new + RunnableBinding) then self.runnable and each of the runnables in + self.fallbacks is replaced with getattr(x, name). + + Example: + .. code-block:: python + + from langchain_openai import ChatOpenAI + from langchain_anthropic import ChatAnthropic + + gpt_4o = ChatOpenAI(model="gpt-4o") + claude_3_sonnet = ChatAnthropic(model="claude-3-sonnet-20240229") + llm = gpt_4o.with_fallbacks([claude_3_sonnet]) + + llm.model_name + # -> "gpt-4o" + + # .bind_tools() is called on both ChatOpenAI and ChatAnthropic + # Equivalent to: + # gpt_4o.bind_tools([...]).with_fallbacks([claude_3_sonnet.bind_tools([...])]) + llm.bind_tools([...]) + # -> RunnableWithFallbacks( + runnable=RunnableBinding(bound=ChatOpenAI(...), kwargs={"tools": [...]}), + fallbacks=[RunnableBinding(bound=ChatAnthropic(...), kwargs={"tools": [...]})], + ) + + """ # noqa: E501 + attr = getattr(self.runnable, name) + if _returns_runnable(attr): + + @wraps(attr) + def wrapped(*args: Any, **kwargs: Any) -> Any: + new_runnable = attr(*args, **kwargs) + new_fallbacks = [] + for fallback in self.fallbacks: + fallback_attr = getattr(fallback, name) + new_fallbacks.append(fallback_attr(*args, **kwargs)) + + return self.__class__( + **{ + **self.dict(), + **{"runnable": new_runnable, "fallbacks": new_fallbacks}, + } + ) + + return wrapped + + return attr + + +def _returns_runnable(attr: Any) -> bool: + if not callable(attr): + return False + return_type = typing.get_type_hints(attr).get("return") + return bool(return_type and _is_runnable_type(return_type)) + + +def _is_runnable_type(type_: Any) -> bool: + if inspect.isclass(type_): + return issubclass(type_, Runnable) + origin = getattr(type_, "__origin__", None) + if inspect.isclass(origin): + return issubclass(origin, Runnable) + elif origin is typing.Union: + return all(_is_runnable_type(t) for t in type_.__args__) + else: + return False diff --git a/libs/core/tests/unit_tests/runnables/test_fallbacks.py b/libs/core/tests/unit_tests/runnables/test_fallbacks.py index 2dbf213b7bc..ba9091a1902 100644 --- a/libs/core/tests/unit_tests/runnables/test_fallbacks.py +++ b/libs/core/tests/unit_tests/runnables/test_fallbacks.py @@ -1,20 +1,41 @@ import sys -from typing import Any, AsyncIterator, Iterator +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Optional, + Sequence, + Type, + Union, +) import pytest from syrupy import SnapshotAssertion -from langchain_core.language_models import FakeListLLM +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models import ( + BaseChatModel, + FakeListLLM, + LanguageModelInput, +) from langchain_core.load import dumps +from langchain_core.messages import BaseMessage +from langchain_core.outputs import ChatResult from langchain_core.prompts import PromptTemplate +from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables import ( Runnable, + RunnableBinding, RunnableGenerator, RunnableLambda, RunnableParallel, RunnablePassthrough, RunnableWithFallbacks, ) +from langchain_core.tools import BaseTool @pytest.fixture() @@ -288,3 +309,85 @@ async def test_fallbacks_astream() -> None: ) async for c in runnable.astream({}): pass + + +class FakeStructuredOutputModel(BaseChatModel): + foo: int + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Top Level call""" + return ChatResult(generations=[]) + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + return self.bind(tools=tools) + + def with_structured_output( + self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any + ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]: + return self | (lambda x: {"foo": self.foo}) + + @property + def _llm_type(self) -> str: + return "fake1" + + +class FakeModel(BaseChatModel): + bar: int + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Top Level call""" + return ChatResult(generations=[]) + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + return self.bind(tools=tools) + + @property + def _llm_type(self) -> str: + return "fake2" + + +def test_fallbacks_getattr() -> None: + llm_with_fallbacks = FakeStructuredOutputModel(foo=3).with_fallbacks( + [FakeModel(bar=4)] + ) + assert llm_with_fallbacks.foo == 3 + + with pytest.raises(AttributeError): + llm_with_fallbacks.bar + + +def test_fallbacks_getattr_runnable_output() -> None: + llm_with_fallbacks = FakeStructuredOutputModel(foo=3).with_fallbacks( + [FakeModel(bar=4)] + ) + llm_with_fallbacks_with_tools = llm_with_fallbacks.bind_tools([]) + assert isinstance(llm_with_fallbacks_with_tools, RunnableWithFallbacks) + assert isinstance(llm_with_fallbacks_with_tools.runnable, RunnableBinding) + assert all( + isinstance(fallback, RunnableBinding) + for fallback in llm_with_fallbacks_with_tools.fallbacks + ) + assert llm_with_fallbacks_with_tools.runnable.kwargs["tools"] == [] + + with pytest.raises(NotImplementedError): + llm_with_fallbacks.with_structured_output({})