mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-30 11:39:03 +00:00
WIP
Add test Add test Lint
This commit is contained in:
parent
c184be5511
commit
93bbf67afc
@ -217,6 +217,12 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
"""
|
"""
|
||||||
return RunnableBinding(bound=self, kwargs=kwargs)
|
return RunnableBinding(bound=self, kwargs=kwargs)
|
||||||
|
|
||||||
|
def each(self) -> Runnable[List[Input], List[Output]]:
|
||||||
|
"""
|
||||||
|
Wrap a Runnable to run it on each element of the input sequence.
|
||||||
|
"""
|
||||||
|
return RunnableEach(bound=self)
|
||||||
|
|
||||||
def with_fallbacks(
|
def with_fallbacks(
|
||||||
self,
|
self,
|
||||||
fallbacks: Sequence[Runnable[Input, Output]],
|
fallbacks: Sequence[Runnable[Input, Output]],
|
||||||
@ -1360,6 +1366,41 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
return self._call_with_config(self.func, input, config)
|
return self._call_with_config(self.func, input, config)
|
||||||
|
|
||||||
|
|
||||||
|
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
|
||||||
|
"""
|
||||||
|
A runnable that delegates calls to another runnable with each element of the input sequence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
bound: Runnable[Input, Output]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_serializable(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_namespace(self) -> List[str]:
|
||||||
|
return self.__class__.__module__.split(".")[:-1]
|
||||||
|
|
||||||
|
def each(self) -> RunnableEach[Input, Output]: # type: ignore[override]
|
||||||
|
return self
|
||||||
|
|
||||||
|
def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]:
|
||||||
|
return RunnableEach(bound=self.bound.bind(**kwargs))
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
self, input: List[Input], config: Optional[RunnableConfig] = None
|
||||||
|
) -> List[Output]:
|
||||||
|
return self.bound.batch(input, config)
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self, input: List[Input], config: Optional[RunnableConfig] = None
|
||||||
|
) -> List[Output]:
|
||||||
|
return await self.bound.abatch(input, config)
|
||||||
|
|
||||||
|
|
||||||
class RunnableBinding(Serializable, Runnable[Input, Output]):
|
class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||||
"""
|
"""
|
||||||
A runnable that delegates calls to another runnable with a set of kwargs.
|
A runnable that delegates calls to another runnable with a set of kwargs.
|
||||||
|
File diff suppressed because one or more lines are too long
@ -1,5 +1,6 @@
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
from xml.dom import ValidationErr
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
@ -20,6 +21,7 @@ from langchain.prompts.chat import (
|
|||||||
HumanMessagePromptTemplate,
|
HumanMessagePromptTemplate,
|
||||||
SystemMessagePromptTemplate,
|
SystemMessagePromptTemplate,
|
||||||
)
|
)
|
||||||
|
from langchain.pydantic_v1 import ValidationError
|
||||||
from langchain.schema.document import Document
|
from langchain.schema.document import Document
|
||||||
from langchain.schema.messages import (
|
from langchain.schema.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
@ -1086,3 +1088,18 @@ async def test_llm_with_fallbacks(
|
|||||||
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
|
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
|
||||||
assert list(await runnable.ainvoke("hello")) == list("bar")
|
assert list(await runnable.ainvoke("hello")) == list("bar")
|
||||||
assert dumps(runnable, pretty=True) == snapshot
|
assert dumps(runnable, pretty=True) == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def test_each(snapshot: SnapshotAssertion) -> None:
|
||||||
|
prompt = (
|
||||||
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
|
+ "{question}"
|
||||||
|
)
|
||||||
|
first_llm = FakeStreamingListLLM(responses=["first item, second item, third item"])
|
||||||
|
second_llm = FakeStreamingListLLM(responses=["this", "is", "a", "test"])
|
||||||
|
|
||||||
|
chain = prompt | first_llm | CommaSeparatedListOutputParser() | second_llm.each()
|
||||||
|
|
||||||
|
assert dumps(chain, pretty=True) == snapshot
|
||||||
|
output = chain.invoke({"question": "What up"})
|
||||||
|
assert output == ["this", "is", "a"]
|
||||||
|
Loading…
Reference in New Issue
Block a user