diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 7631c54dd30..b477ecdaada 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -217,9 +217,10 @@ class Runnable(Generic[Input, Output], ABC): """ return RunnableBinding(bound=self, kwargs=kwargs) - def each(self) -> Runnable[List[Input], List[Output]]: + def map(self) -> Runnable[List[Input], List[Output]]: """ - Wrap a Runnable to run it on each element of the input sequence. + Return a new Runnable that maps a list of inputs to a list of outputs, + by calling invoke() with each input. """ return RunnableEach(bound=self) @@ -1384,7 +1385,7 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]): def lc_namespace(self) -> List[str]: return self.__class__.__module__.split(".")[:-1] - def each(self) -> RunnableEach[Input, Output]: # type: ignore[override] + def map(self) -> RunnableEach[Input, Output]: # type: ignore[override] return self def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]: diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index 44793ce3a03..fe06e568f3c 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -563,10 +563,9 @@ "lc": 1, "type": "constructor", "id": [ - "langchain", - "output_parsers", - "list", - "CommaSeparatedListOutputParser" + "runnable", + "test_runnable", + "FakeSplitIntoListParser" ], "kwargs": {} } diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index cb5c0147250..c6630d7a6d6 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -21,7 +21,6 @@ from langchain.prompts.chat import ( HumanMessagePromptTemplate, SystemMessagePromptTemplate, ) -from langchain.pydantic_v1 import ValidationError from langchain.schema.document import Document from langchain.schema.messages import ( AIMessage, @@ -29,7 +28,7 @@ from langchain.schema.messages import ( HumanMessage, SystemMessage, ) -from langchain.schema.output_parser import StrOutputParser +from langchain.schema.output_parser import BaseOutputParser, StrOutputParser from langchain.schema.retriever import BaseRetriever from langchain.schema.runnable import ( RouterRunnable, @@ -1090,6 +1089,24 @@ async def test_llm_with_fallbacks( assert dumps(runnable, pretty=True) == snapshot +class FakeSplitIntoListParser(BaseOutputParser[List[str]]): + """Parse the output of an LLM call to a comma-separated list.""" + + @property + def lc_serializable(self) -> bool: + return True + + def get_format_instructions(self) -> str: + return ( + "Your response should be a list of comma separated values, " + "eg: `foo, bar, baz`" + ) + + def parse(self, text: str) -> List[str]: + """Parse the output of an LLM call.""" + return text.strip().split(", ") + + def test_each(snapshot: SnapshotAssertion) -> None: prompt = ( SystemMessagePromptTemplate.from_template("You are a nice assistant.") @@ -1098,7 +1115,7 @@ def test_each(snapshot: SnapshotAssertion) -> None: first_llm = FakeStreamingListLLM(responses=["first item, second item, third item"]) second_llm = FakeStreamingListLLM(responses=["this", "is", "a", "test"]) - chain = prompt | first_llm | CommaSeparatedListOutputParser() | second_llm.each() + chain = prompt | first_llm | FakeSplitIntoListParser() | second_llm.map() assert dumps(chain, pretty=True) == snapshot output = chain.invoke({"question": "What up"})