Add test

Add test

Lint
This commit is contained in:
Nuno Campos 2023-08-18 14:08:54 +01:00
parent c184be5511
commit 93bbf67afc
3 changed files with 191 additions and 0 deletions

View File

@ -217,6 +217,12 @@ class Runnable(Generic[Input, Output], ABC):
""" """
return RunnableBinding(bound=self, kwargs=kwargs) return RunnableBinding(bound=self, kwargs=kwargs)
def each(self) -> Runnable[List[Input], List[Output]]:
"""
Wrap a Runnable to run it on each element of the input sequence.
"""
return RunnableEach(bound=self)
def with_fallbacks( def with_fallbacks(
self, self,
fallbacks: Sequence[Runnable[Input, Output]], fallbacks: Sequence[Runnable[Input, Output]],
@ -1360,6 +1366,41 @@ class RunnableLambda(Runnable[Input, Output]):
return self._call_with_config(self.func, input, config) return self._call_with_config(self.func, input, config)
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
"""
A runnable that delegates calls to another runnable with each element of the input sequence.
"""
bound: Runnable[Input, Output]
class Config:
arbitrary_types_allowed = True
@property
def lc_serializable(self) -> bool:
return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
def each(self) -> RunnableEach[Input, Output]: # type: ignore[override]
return self
def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]:
return RunnableEach(bound=self.bound.bind(**kwargs))
def invoke(
self, input: List[Input], config: Optional[RunnableConfig] = None
) -> List[Output]:
return self.bound.batch(input, config)
async def ainvoke(
self, input: List[Input], config: Optional[RunnableConfig] = None
) -> List[Output]:
return await self.bound.abatch(input, config)
class RunnableBinding(Serializable, Runnable[Input, Output]): class RunnableBinding(Serializable, Runnable[Input, Output]):
""" """
A runnable that delegates calls to another runnable with a set of kwargs. A runnable that delegates calls to another runnable with a set of kwargs.

File diff suppressed because one or more lines are too long

View File

@ -1,5 +1,6 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from uuid import UUID from uuid import UUID
from xml.dom import ValidationErr
import pytest import pytest
from freezegun import freeze_time from freezegun import freeze_time
@ -20,6 +21,7 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
) )
from langchain.pydantic_v1 import ValidationError
from langchain.schema.document import Document from langchain.schema.document import Document
from langchain.schema.messages import ( from langchain.schema.messages import (
AIMessage, AIMessage,
@ -1086,3 +1088,18 @@ async def test_llm_with_fallbacks(
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3 assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(await runnable.ainvoke("hello")) == list("bar") assert list(await runnable.ainvoke("hello")) == list("bar")
assert dumps(runnable, pretty=True) == snapshot assert dumps(runnable, pretty=True) == snapshot
def test_each(snapshot: SnapshotAssertion) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
first_llm = FakeStreamingListLLM(responses=["first item, second item, third item"])
second_llm = FakeStreamingListLLM(responses=["this", "is", "a", "test"])
chain = prompt | first_llm | CommaSeparatedListOutputParser() | second_llm.each()
assert dumps(chain, pretty=True) == snapshot
output = chain.invoke({"question": "What up"})
assert output == ["this", "is", "a"]