mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
core[patch]: allow access RunnableWithFallbacks.runnable attrs (#22139)
RFC, candidate fix for #13095 #22134
This commit is contained in:
parent
7496fe2b16
commit
d61bdeba25
@ -1,4 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
import typing
|
||||||
|
from functools import wraps
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
@ -549,3 +552,77 @@ class RunnableWithFallbacks(RunnableSerializable[Input, Output]):
|
|||||||
await run_manager.on_chain_error(e)
|
await run_manager.on_chain_error(e)
|
||||||
raise e
|
raise e
|
||||||
await run_manager.on_chain_end(output)
|
await run_manager.on_chain_end(output)
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
"""Get an attribute from the wrapped runnable and its fallbacks.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
If the attribute is anything other than a method that outputs a Runnable,
|
||||||
|
returns getattr(self.runnable, name). If the attribute is a method that
|
||||||
|
does return a new Runnable (e.g. llm.bind_tools([...]) outputs a new
|
||||||
|
RunnableBinding) then self.runnable and each of the runnables in
|
||||||
|
self.fallbacks is replaced with getattr(x, name).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_anthropic import ChatAnthropic
|
||||||
|
|
||||||
|
gpt_4o = ChatOpenAI(model="gpt-4o")
|
||||||
|
claude_3_sonnet = ChatAnthropic(model="claude-3-sonnet-20240229")
|
||||||
|
llm = gpt_4o.with_fallbacks([claude_3_sonnet])
|
||||||
|
|
||||||
|
llm.model_name
|
||||||
|
# -> "gpt-4o"
|
||||||
|
|
||||||
|
# .bind_tools() is called on both ChatOpenAI and ChatAnthropic
|
||||||
|
# Equivalent to:
|
||||||
|
# gpt_4o.bind_tools([...]).with_fallbacks([claude_3_sonnet.bind_tools([...])])
|
||||||
|
llm.bind_tools([...])
|
||||||
|
# -> RunnableWithFallbacks(
|
||||||
|
runnable=RunnableBinding(bound=ChatOpenAI(...), kwargs={"tools": [...]}),
|
||||||
|
fallbacks=[RunnableBinding(bound=ChatAnthropic(...), kwargs={"tools": [...]})],
|
||||||
|
)
|
||||||
|
|
||||||
|
""" # noqa: E501
|
||||||
|
attr = getattr(self.runnable, name)
|
||||||
|
if _returns_runnable(attr):
|
||||||
|
|
||||||
|
@wraps(attr)
|
||||||
|
def wrapped(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
new_runnable = attr(*args, **kwargs)
|
||||||
|
new_fallbacks = []
|
||||||
|
for fallback in self.fallbacks:
|
||||||
|
fallback_attr = getattr(fallback, name)
|
||||||
|
new_fallbacks.append(fallback_attr(*args, **kwargs))
|
||||||
|
|
||||||
|
return self.__class__(
|
||||||
|
**{
|
||||||
|
**self.dict(),
|
||||||
|
**{"runnable": new_runnable, "fallbacks": new_fallbacks},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
return attr
|
||||||
|
|
||||||
|
|
||||||
|
def _returns_runnable(attr: Any) -> bool:
|
||||||
|
if not callable(attr):
|
||||||
|
return False
|
||||||
|
return_type = typing.get_type_hints(attr).get("return")
|
||||||
|
return bool(return_type and _is_runnable_type(return_type))
|
||||||
|
|
||||||
|
|
||||||
|
def _is_runnable_type(type_: Any) -> bool:
|
||||||
|
if inspect.isclass(type_):
|
||||||
|
return issubclass(type_, Runnable)
|
||||||
|
origin = getattr(type_, "__origin__", None)
|
||||||
|
if inspect.isclass(origin):
|
||||||
|
return issubclass(origin, Runnable)
|
||||||
|
elif origin is typing.Union:
|
||||||
|
return all(_is_runnable_type(t) for t in type_.__args__)
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
@ -1,20 +1,41 @@
|
|||||||
import sys
|
import sys
|
||||||
from typing import Any, AsyncIterator, Iterator
|
from typing import (
|
||||||
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from syrupy import SnapshotAssertion
|
from syrupy import SnapshotAssertion
|
||||||
|
|
||||||
from langchain_core.language_models import FakeListLLM
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||||
|
from langchain_core.language_models import (
|
||||||
|
BaseChatModel,
|
||||||
|
FakeListLLM,
|
||||||
|
LanguageModelInput,
|
||||||
|
)
|
||||||
from langchain_core.load import dumps
|
from langchain_core.load import dumps
|
||||||
|
from langchain_core.messages import BaseMessage
|
||||||
|
from langchain_core.outputs import ChatResult
|
||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
from langchain_core.runnables import (
|
from langchain_core.runnables import (
|
||||||
Runnable,
|
Runnable,
|
||||||
|
RunnableBinding,
|
||||||
RunnableGenerator,
|
RunnableGenerator,
|
||||||
RunnableLambda,
|
RunnableLambda,
|
||||||
RunnableParallel,
|
RunnableParallel,
|
||||||
RunnablePassthrough,
|
RunnablePassthrough,
|
||||||
RunnableWithFallbacks,
|
RunnableWithFallbacks,
|
||||||
)
|
)
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
@ -288,3 +309,85 @@ async def test_fallbacks_astream() -> None:
|
|||||||
)
|
)
|
||||||
async for c in runnable.astream({}):
|
async for c in runnable.astream({}):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FakeStructuredOutputModel(BaseChatModel):
|
||||||
|
foo: int
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
"""Top Level call"""
|
||||||
|
return ChatResult(generations=[])
|
||||||
|
|
||||||
|
def bind_tools(
|
||||||
|
self,
|
||||||
|
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||||
|
return self.bind(tools=tools)
|
||||||
|
|
||||||
|
def with_structured_output(
|
||||||
|
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
|
||||||
|
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||||
|
return self | (lambda x: {"foo": self.foo})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "fake1"
|
||||||
|
|
||||||
|
|
||||||
|
class FakeModel(BaseChatModel):
|
||||||
|
bar: int
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
"""Top Level call"""
|
||||||
|
return ChatResult(generations=[])
|
||||||
|
|
||||||
|
def bind_tools(
|
||||||
|
self,
|
||||||
|
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||||
|
return self.bind(tools=tools)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
return "fake2"
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallbacks_getattr() -> None:
|
||||||
|
llm_with_fallbacks = FakeStructuredOutputModel(foo=3).with_fallbacks(
|
||||||
|
[FakeModel(bar=4)]
|
||||||
|
)
|
||||||
|
assert llm_with_fallbacks.foo == 3
|
||||||
|
|
||||||
|
with pytest.raises(AttributeError):
|
||||||
|
llm_with_fallbacks.bar
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallbacks_getattr_runnable_output() -> None:
|
||||||
|
llm_with_fallbacks = FakeStructuredOutputModel(foo=3).with_fallbacks(
|
||||||
|
[FakeModel(bar=4)]
|
||||||
|
)
|
||||||
|
llm_with_fallbacks_with_tools = llm_with_fallbacks.bind_tools([])
|
||||||
|
assert isinstance(llm_with_fallbacks_with_tools, RunnableWithFallbacks)
|
||||||
|
assert isinstance(llm_with_fallbacks_with_tools.runnable, RunnableBinding)
|
||||||
|
assert all(
|
||||||
|
isinstance(fallback, RunnableBinding)
|
||||||
|
for fallback in llm_with_fallbacks_with_tools.fallbacks
|
||||||
|
)
|
||||||
|
assert llm_with_fallbacks_with_tools.runnable.kwargs["tools"] == []
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
llm_with_fallbacks.with_structured_output({})
|
||||||
|
Loading…
Reference in New Issue
Block a user