Update method and docstring

This commit is contained in:
Nuno Campos 2023-08-18 16:18:04 +01:00
parent 93bbf67afc
commit 9777c2801d
3 changed files with 27 additions and 10 deletions

View File

@ -217,9 +217,10 @@ class Runnable(Generic[Input, Output], ABC):
""" """
return RunnableBinding(bound=self, kwargs=kwargs) return RunnableBinding(bound=self, kwargs=kwargs)
def each(self) -> Runnable[List[Input], List[Output]]: def map(self) -> Runnable[List[Input], List[Output]]:
""" """
Wrap a Runnable to run it on each element of the input sequence. Return a new Runnable that maps a list of inputs to a list of outputs,
by calling invoke() with each input.
""" """
return RunnableEach(bound=self) return RunnableEach(bound=self)
@ -1384,7 +1385,7 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
def lc_namespace(self) -> List[str]: def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1] return self.__class__.__module__.split(".")[:-1]
def each(self) -> RunnableEach[Input, Output]: # type: ignore[override] def map(self) -> RunnableEach[Input, Output]: # type: ignore[override]
return self return self
def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]: def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]:

View File

@ -563,10 +563,9 @@
"lc": 1, "lc": 1,
"type": "constructor", "type": "constructor",
"id": [ "id": [
"langchain", "runnable",
"output_parsers", "test_runnable",
"list", "FakeSplitIntoListParser"
"CommaSeparatedListOutputParser"
], ],
"kwargs": {} "kwargs": {}
} }

View File

@ -21,7 +21,6 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
SystemMessagePromptTemplate, SystemMessagePromptTemplate,
) )
from langchain.pydantic_v1 import ValidationError
from langchain.schema.document import Document from langchain.schema.document import Document
from langchain.schema.messages import ( from langchain.schema.messages import (
AIMessage, AIMessage,
@ -29,7 +28,7 @@ from langchain.schema.messages import (
HumanMessage, HumanMessage,
SystemMessage, SystemMessage,
) )
from langchain.schema.output_parser import StrOutputParser from langchain.schema.output_parser import BaseOutputParser, StrOutputParser
from langchain.schema.retriever import BaseRetriever from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import ( from langchain.schema.runnable import (
RouterRunnable, RouterRunnable,
@ -1090,6 +1089,24 @@ async def test_llm_with_fallbacks(
assert dumps(runnable, pretty=True) == snapshot assert dumps(runnable, pretty=True) == snapshot
class FakeSplitIntoListParser(BaseOutputParser[List[str]]):
"""Parse the output of an LLM call to a comma-separated list."""
@property
def lc_serializable(self) -> bool:
return True
def get_format_instructions(self) -> str:
return (
"Your response should be a list of comma separated values, "
"eg: `foo, bar, baz`"
)
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
return text.strip().split(", ")
def test_each(snapshot: SnapshotAssertion) -> None: def test_each(snapshot: SnapshotAssertion) -> None:
prompt = ( prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.") SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@ -1098,7 +1115,7 @@ def test_each(snapshot: SnapshotAssertion) -> None:
first_llm = FakeStreamingListLLM(responses=["first item, second item, third item"]) first_llm = FakeStreamingListLLM(responses=["first item, second item, third item"])
second_llm = FakeStreamingListLLM(responses=["this", "is", "a", "test"]) second_llm = FakeStreamingListLLM(responses=["this", "is", "a", "test"])
chain = prompt | first_llm | CommaSeparatedListOutputParser() | second_llm.each() chain = prompt | first_llm | FakeSplitIntoListParser() | second_llm.map()
assert dumps(chain, pretty=True) == snapshot assert dumps(chain, pretty=True) == snapshot
output = chain.invoke({"question": "What up"}) output = chain.invoke({"question": "What up"})