core[patch]: allow access RunnableWithFallbacks.runnable attrs (#22139)

RFC: candidate fix for #13095 and #22134
Bagatur
2024-05-28 13:18:09 -07:00
committed by GitHub
parent 7496fe2b16
commit d61bdeba25
2 changed files with 182 additions and 2 deletions
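
Note: only the test-file diff (apparently tests/unit_tests/runnables/test_fallbacks.py, judging by the imports) is reproduced below; the second changed file, the RunnableWithFallbacks implementation itself, is not shown here. Going by the commit title and the new tests, the core change adds __getattr__ delegation: plain attributes are read off the wrapped runnable, while methods that produce a new Runnable are applied to the primary runnable and to every fallback, and the results are re-wrapped. What follows is a minimal, self-contained sketch of that idea, not the actual library code; the class name is illustrative, and a naive callable() check stands in for whatever test the real patch uses to detect methods that return Runnables.

from functools import wraps
from typing import Any, List


class WithFallbacksSketch:
    """Illustrative sketch of attribute delegation on a fallbacks wrapper."""

    def __init__(self, runnable: Any, fallbacks: List[Any]) -> None:
        self.runnable = runnable
        self.fallbacks = fallbacks

    def __getattr__(self, name: str) -> Any:
        # Only reached for names not found by normal lookup, so there is
        # no recursion through self.runnable / self.fallbacks.
        attr = getattr(self.runnable, name)  # AttributeError propagates as-is
        if not callable(attr):
            # Plain attribute: read it straight off the primary runnable.
            return attr

        @wraps(attr)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            # Apply the method to the primary runnable *and* every fallback,
            # then re-wrap so fallback behavior is preserved on the result.
            return WithFallbacksSketch(
                attr(*args, **kwargs),
                [getattr(f, name)(*args, **kwargs) for f in self.fallbacks],
            )

        return wrapped

Checked against the tests below, this shape reproduces the expected behavior: .foo resolves to 3, .bar raises AttributeError (FakeStructuredOutputModel has no bar), bind_tools([]) yields a new wrapper built from the bound results, and with_structured_output({}) fails, presumably because the FakeModel fallback inherits the default implementation, which raises NotImplementedError.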


@@ -1,20 +1,41 @@
 import sys
-from typing import Any, AsyncIterator, Iterator
+from typing import (
+    Any,
+    AsyncIterator,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Sequence,
+    Type,
+    Union,
+)
 
 import pytest
 from syrupy import SnapshotAssertion
 
-from langchain_core.language_models import FakeListLLM
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models import (
+    BaseChatModel,
+    FakeListLLM,
+    LanguageModelInput,
+)
 from langchain_core.load import dumps
+from langchain_core.messages import BaseMessage
+from langchain_core.outputs import ChatResult
 from langchain_core.prompts import PromptTemplate
 from langchain_core.pydantic_v1 import BaseModel
 from langchain_core.runnables import (
+    Runnable,
+    RunnableBinding,
     RunnableGenerator,
     RunnableLambda,
     RunnableParallel,
     RunnablePassthrough,
+    RunnableWithFallbacks,
 )
+from langchain_core.tools import BaseTool
 
 
 @pytest.fixture()
@@ -288,3 +309,85 @@ async def test_fallbacks_astream() -> None:
     )
     async for c in runnable.astream({}):
         pass
+
+
+class FakeStructuredOutputModel(BaseChatModel):
+    foo: int
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Top Level call"""
+        return ChatResult(generations=[])
+
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        return self.bind(tools=tools)
+
+    def with_structured_output(
+        self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
+    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
+        return self | (lambda x: {"foo": self.foo})
+
+    @property
+    def _llm_type(self) -> str:
+        return "fake1"
+
+
+class FakeModel(BaseChatModel):
+    bar: int
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Top Level call"""
+        return ChatResult(generations=[])
+
+    def bind_tools(
+        self,
+        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
+        **kwargs: Any,
+    ) -> Runnable[LanguageModelInput, BaseMessage]:
+        return self.bind(tools=tools)
+
+    @property
+    def _llm_type(self) -> str:
+        return "fake2"
+
+
+def test_fallbacks_getattr() -> None:
+    llm_with_fallbacks = FakeStructuredOutputModel(foo=3).with_fallbacks(
+        [FakeModel(bar=4)]
+    )
+    assert llm_with_fallbacks.foo == 3
+
+    with pytest.raises(AttributeError):
+        llm_with_fallbacks.bar
+
+
+def test_fallbacks_getattr_runnable_output() -> None:
+    llm_with_fallbacks = FakeStructuredOutputModel(foo=3).with_fallbacks(
+        [FakeModel(bar=4)]
+    )
+    llm_with_fallbacks_with_tools = llm_with_fallbacks.bind_tools([])
+    assert isinstance(llm_with_fallbacks_with_tools, RunnableWithFallbacks)
+    assert isinstance(llm_with_fallbacks_with_tools.runnable, RunnableBinding)
+    assert all(
+        isinstance(fallback, RunnableBinding)
+        for fallback in llm_with_fallbacks_with_tools.fallbacks
+    )
+    assert llm_with_fallbacks_with_tools.runnable.kwargs["tools"] == []
+
+    with pytest.raises(NotImplementedError):
+        llm_with_fallbacks.with_structured_output({})
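
In end-user terms, the behavior these tests pin down looks like this (a sketch reusing the fakes above; the comments state the tested expectations, not guaranteed library semantics):

llm = FakeStructuredOutputModel(foo=3).with_fallbacks([FakeModel(bar=4)])

llm.foo                 # 3: plain attributes come from the primary runnable
llm.bind_tools([])      # a new RunnableWithFallbacks whose runnable and
                        # fallbacks are each RunnableBinding instances
llm.with_structured_output({})  # raises NotImplementedError, presumably
                                # because the FakeModel fallback only has the
                                # default BaseChatModel implementation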