mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 07:35:18 +00:00
core[patch]: allow access RunnableWithFallbacks.runnable attrs (#22139)
RFC, candidate fix for #13095 #22134
This commit is contained in:
parent
7496fe2b16
commit
d61bdeba25
@ -1,4 +1,7 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import typing
|
||||
from functools import wraps
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@ -549,3 +552,77 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
||||
await run_manager.on_chain_error(e)
|
||||
raise e
|
||||
await run_manager.on_chain_end(output)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Get an attribute from the wrapped runnable and its fallbacks.
|
||||
|
||||
Returns:
|
||||
If the attribute is anything other than a method that outputs a Runnable,
|
||||
returns getattr(self.runnable, name). If the attribute is a method that
|
||||
does return a new Runnable (e.g. llm.bind_tools([...]) outputs a new
|
||||
RunnableBinding) then self.runnable and each of the runnables in
|
||||
self.fallbacks is replaced with getattr(x, name).
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
gpt_4o = ChatOpenAI(model="gpt-4o")
|
||||
claude_3_sonnet = ChatAnthropic(model="claude-3-sonnet-20240229")
|
||||
llm = gpt_4o.with_fallbacks([claude_3_sonnet])
|
||||
|
||||
llm.model_name
|
||||
# -> "gpt-4o"
|
||||
|
||||
# .bind_tools() is called on both ChatOpenAI and ChatAnthropic
|
||||
# Equivalent to:
|
||||
# gpt_4o.bind_tools([...]).with_fallbacks([claude_3_sonnet.bind_tools([...])])
|
||||
llm.bind_tools([...])
|
||||
# -> RunnableWithFallbacks(
|
||||
runnable=RunnableBinding(bound=ChatOpenAI(...), kwargs={"tools": [...]}),
|
||||
fallbacks=[RunnableBinding(bound=ChatAnthropic(...), kwargs={"tools": [...]})],
|
||||
)
|
||||
|
||||
""" # noqa: E501
|
||||
attr = getattr(self.runnable, name)
|
||||
if _returns_runnable(attr):
|
||||
|
||||
@wraps(attr)
|
||||
def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||
new_runnable = attr(*args, **kwargs)
|
||||
new_fallbacks = []
|
||||
for fallback in self.fallbacks:
|
||||
fallback_attr = getattr(fallback, name)
|
||||
new_fallbacks.append(fallback_attr(*args, **kwargs))
|
||||
|
||||
return self.__class__(
|
||||
**{
|
||||
**self.dict(),
|
||||
**{"runnable": new_runnable, "fallbacks": new_fallbacks},
|
||||
}
|
||||
)
|
||||
|
||||
return wrapped
|
||||
|
||||
return attr
|
||||
|
||||
|
||||
def _returns_runnable(attr: Any) -> bool:
|
||||
if not callable(attr):
|
||||
return False
|
||||
return_type = typing.get_type_hints(attr).get("return")
|
||||
return bool(return_type and _is_runnable_type(return_type))
|
||||
|
||||
|
||||
def _is_runnable_type(type_: Any) -> bool:
|
||||
if inspect.isclass(type_):
|
||||
return issubclass(type_, Runnable)
|
||||
origin = getattr(type_, "__origin__", None)
|
||||
if inspect.isclass(origin):
|
||||
return issubclass(origin, Runnable)
|
||||
elif origin is typing.Union:
|
||||
return all(_is_runnable_type(t) for t in type_.__args__)
|
||||
else:
|
||||
return False
|
||||
|
@ -1,20 +1,41 @@
|
||||
import sys
|
||||
from typing import Any, AsyncIterator, Iterator
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import pytest
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain_core.language_models import FakeListLLM
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import (
|
||||
BaseChatModel,
|
||||
FakeListLLM,
|
||||
LanguageModelInput,
|
||||
)
|
||||
from langchain_core.load import dumps
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import ChatResult
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableBinding,
|
||||
RunnableGenerator,
|
||||
RunnableLambda,
|
||||
RunnableParallel,
|
||||
RunnablePassthrough,
|
||||
RunnableWithFallbacks,
|
||||
)
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@ -288,3 +309,85 @@ async def test_fallbacks_astream() -> None:
|
||||
)
|
||||
async for c in runnable.astream({}):
|
||||
pass
|
||||
|
||||
|
||||
class FakeStructuredOutputModel(BaseChatModel):
|
||||
foo: int
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
return ChatResult(generations=[])
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
return self.bind(tools=tools)
|
||||
|
||||
def with_structured_output(
|
||||
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
return self | (lambda x: {"foo": self.foo})
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake1"
|
||||
|
||||
|
||||
class FakeModel(BaseChatModel):
|
||||
bar: int
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
return ChatResult(generations=[])
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
return self.bind(tools=tools)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake2"
|
||||
|
||||
|
||||
def test_fallbacks_getattr() -> None:
|
||||
llm_with_fallbacks = FakeStructuredOutputModel(foo=3).with_fallbacks(
|
||||
[FakeModel(bar=4)]
|
||||
)
|
||||
assert llm_with_fallbacks.foo == 3
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
llm_with_fallbacks.bar
|
||||
|
||||
|
||||
def test_fallbacks_getattr_runnable_output() -> None:
|
||||
llm_with_fallbacks = FakeStructuredOutputModel(foo=3).with_fallbacks(
|
||||
[FakeModel(bar=4)]
|
||||
)
|
||||
llm_with_fallbacks_with_tools = llm_with_fallbacks.bind_tools([])
|
||||
assert isinstance(llm_with_fallbacks_with_tools, RunnableWithFallbacks)
|
||||
assert isinstance(llm_with_fallbacks_with_tools.runnable, RunnableBinding)
|
||||
assert all(
|
||||
isinstance(fallback, RunnableBinding)
|
||||
for fallback in llm_with_fallbacks_with_tools.fallbacks
|
||||
)
|
||||
assert llm_with_fallbacks_with_tools.runnable.kwargs["tools"] == []
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
llm_with_fallbacks.with_structured_output({})
|
||||
|
Loading…
Reference in New Issue
Block a user