mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-16 06:53:16 +00:00
core[patch]: allow access RunnableWithFallbacks.runnable attrs (#22139)
RFC, candidate fix for #13095 #22134
This commit is contained in:
@@ -1,20 +1,41 @@
|
||||
import sys
|
||||
from typing import Any, AsyncIterator, Iterator
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import pytest
|
||||
from syrupy import SnapshotAssertion
|
||||
|
||||
from langchain_core.language_models import FakeListLLM
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import (
|
||||
BaseChatModel,
|
||||
FakeListLLM,
|
||||
LanguageModelInput,
|
||||
)
|
||||
from langchain_core.load import dumps
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import ChatResult
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import (
|
||||
Runnable,
|
||||
RunnableBinding,
|
||||
RunnableGenerator,
|
||||
RunnableLambda,
|
||||
RunnableParallel,
|
||||
RunnablePassthrough,
|
||||
RunnableWithFallbacks,
|
||||
)
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -288,3 +309,85 @@ async def test_fallbacks_astream() -> None:
|
||||
)
|
||||
async for c in runnable.astream({}):
|
||||
pass
|
||||
|
||||
|
||||
class FakeStructuredOutputModel(BaseChatModel):
|
||||
foo: int
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
return ChatResult(generations=[])
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
return self.bind(tools=tools)
|
||||
|
||||
def with_structured_output(
|
||||
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
|
||||
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
|
||||
return self | (lambda x: {"foo": self.foo})
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake1"
|
||||
|
||||
|
||||
class FakeModel(BaseChatModel):
|
||||
bar: int
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
"""Top Level call"""
|
||||
return ChatResult(generations=[])
|
||||
|
||||
def bind_tools(
|
||||
self,
|
||||
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
|
||||
**kwargs: Any,
|
||||
) -> Runnable[LanguageModelInput, BaseMessage]:
|
||||
return self.bind(tools=tools)
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "fake2"
|
||||
|
||||
|
||||
def test_fallbacks_getattr() -> None:
|
||||
llm_with_fallbacks = FakeStructuredOutputModel(foo=3).with_fallbacks(
|
||||
[FakeModel(bar=4)]
|
||||
)
|
||||
assert llm_with_fallbacks.foo == 3
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
llm_with_fallbacks.bar
|
||||
|
||||
|
||||
def test_fallbacks_getattr_runnable_output() -> None:
|
||||
llm_with_fallbacks = FakeStructuredOutputModel(foo=3).with_fallbacks(
|
||||
[FakeModel(bar=4)]
|
||||
)
|
||||
llm_with_fallbacks_with_tools = llm_with_fallbacks.bind_tools([])
|
||||
assert isinstance(llm_with_fallbacks_with_tools, RunnableWithFallbacks)
|
||||
assert isinstance(llm_with_fallbacks_with_tools.runnable, RunnableBinding)
|
||||
assert all(
|
||||
isinstance(fallback, RunnableBinding)
|
||||
for fallback in llm_with_fallbacks_with_tools.fallbacks
|
||||
)
|
||||
assert llm_with_fallbacks_with_tools.runnable.kwargs["tools"] == []
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
llm_with_fallbacks.with_structured_output({})
|
||||
|
Reference in New Issue
Block a user