Update method and docstring

This commit is contained in:
Nuno Campos 2023-08-18 16:18:04 +01:00
parent 93bbf67afc
commit 9777c2801d
3 changed files with 27 additions and 10 deletions

View File

@ -217,9 +217,10 @@ class Runnable(Generic[Input, Output], ABC):
"""
return RunnableBinding(bound=self, kwargs=kwargs)
def each(self) -> Runnable[List[Input], List[Output]]:
def map(self) -> Runnable[List[Input], List[Output]]:
"""
Wrap a Runnable to run it on each element of the input sequence.
Return a new Runnable that maps a list of inputs to a list of outputs,
by calling invoke() with each input.
"""
return RunnableEach(bound=self)
@ -1384,7 +1385,7 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
def each(self) -> RunnableEach[Input, Output]: # type: ignore[override]
def map(self) -> RunnableEach[Input, Output]: # type: ignore[override]
return self
def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]:

View File

@ -563,10 +563,9 @@
"lc": 1,
"type": "constructor",
"id": [
"langchain",
"output_parsers",
"list",
"CommaSeparatedListOutputParser"
"runnable",
"test_runnable",
"FakeSplitIntoListParser"
],
"kwargs": {}
}

View File

@ -21,7 +21,6 @@ from langchain.prompts.chat import (
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.pydantic_v1 import ValidationError
from langchain.schema.document import Document
from langchain.schema.messages import (
AIMessage,
@ -29,7 +28,7 @@ from langchain.schema.messages import (
HumanMessage,
SystemMessage,
)
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.output_parser import BaseOutputParser, StrOutputParser
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import (
RouterRunnable,
@ -1090,6 +1089,24 @@ async def test_llm_with_fallbacks(
assert dumps(runnable, pretty=True) == snapshot
class FakeSplitIntoListParser(BaseOutputParser[List[str]]):
"""Parse the output of an LLM call to a comma-separated list."""
@property
def lc_serializable(self) -> bool:
return True
def get_format_instructions(self) -> str:
return (
"Your response should be a list of comma separated values, "
"eg: `foo, bar, baz`"
)
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
return text.strip().split(", ")
def test_each(snapshot: SnapshotAssertion) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
@ -1098,7 +1115,7 @@ def test_each(snapshot: SnapshotAssertion) -> None:
first_llm = FakeStreamingListLLM(responses=["first item, second item, third item"])
second_llm = FakeStreamingListLLM(responses=["this", "is", "a", "test"])
chain = prompt | first_llm | CommaSeparatedListOutputParser() | second_llm.each()
chain = prompt | first_llm | FakeSplitIntoListParser() | second_llm.map()
assert dumps(chain, pretty=True) == snapshot
output = chain.invoke({"question": "What up"})