Mirror of https://github.com/hwchase17/langchain.git, synced 2025-07-06 05:08:20 +00:00
Update method and docstring
parent 93bbf67afc
commit 9777c2801d
@@ -217,9 +217,10 @@ class Runnable(Generic[Input, Output], ABC):
         """
         return RunnableBinding(bound=self, kwargs=kwargs)
 
-    def each(self) -> Runnable[List[Input], List[Output]]:
+    def map(self) -> Runnable[List[Input], List[Output]]:
         """
-        Wrap a Runnable to run it on each element of the input sequence.
+        Return a new Runnable that maps a list of inputs to a list of outputs,
+        by calling invoke() with each input.
         """
         return RunnableEach(bound=self)
 
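For context, a minimal sketch of the renamed method (illustrative, not part of the diff; it assumes RunnableLambda from langchain.schema.runnable as a stand-in for any Runnable):

    from langchain.schema.runnable import RunnableLambda

    double = RunnableLambda(lambda x: x * 2)
    doubled_each = double.map()        # Runnable[List[int], List[int]], backed by RunnableEach
    doubled_each.invoke([1, 2, 3])     # -> [2, 4, 6], invoke() called once per element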
@@ -1384,7 +1385,7 @@ class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
     def lc_namespace(self) -> List[str]:
         return self.__class__.__module__.split(".")[:-1]
 
-    def each(self) -> RunnableEach[Input, Output]:  # type: ignore[override]
+    def map(self) -> RunnableEach[Input, Output]:  # type: ignore[override]
         return self
 
     def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]:
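Because RunnableEach already applies its bound runnable to every element of the input list, the override above makes map() idempotent: calling it on a RunnableEach returns the same instance rather than wrapping it again. A small illustrative sketch (assuming the map() from the first hunk):

    from langchain.schema.runnable import RunnableLambda

    mapped = RunnableLambda(lambda x: x + 1).map()   # a RunnableEach
    assert mapped.map() is mapped                    # the override returns self, no double wrapping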
@@ -563,10 +563,9 @@
       "lc": 1,
       "type": "constructor",
       "id": [
-        "langchain",
-        "output_parsers",
-        "list",
-        "CommaSeparatedListOutputParser"
+        "runnable",
+        "test_runnable",
+        "FakeSplitIntoListParser"
       ],
       "kwargs": {}
     }
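This snapshot records the serialized form of the new test parser. With lc_serializable returning True, the dump helpers emit a "constructor" object whose "id" identifies the class by its module path plus class name and whose "kwargs" hold the constructor arguments. A rough sketch of producing such an entry (the exact "id" prefix depends on how the test module is imported):

    from langchain.load.dump import dumpd

    dumpd(FakeSplitIntoListParser())
    # -> {"lc": 1, "type": "constructor",
    #     "id": [..., "test_runnable", "FakeSplitIntoListParser"],
    #     "kwargs": {}}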
@@ -21,7 +21,6 @@ from langchain.prompts.chat import (
     HumanMessagePromptTemplate,
     SystemMessagePromptTemplate,
 )
-from langchain.pydantic_v1 import ValidationError
 from langchain.schema.document import Document
 from langchain.schema.messages import (
     AIMessage,
@@ -29,7 +28,7 @@ from langchain.schema.messages import (
     HumanMessage,
     SystemMessage,
 )
-from langchain.schema.output_parser import StrOutputParser
+from langchain.schema.output_parser import BaseOutputParser, StrOutputParser
 from langchain.schema.retriever import BaseRetriever
 from langchain.schema.runnable import (
     RouterRunnable,
@@ -1090,6 +1089,24 @@ async def test_llm_with_fallbacks(
     assert dumps(runnable, pretty=True) == snapshot
 
 
+class FakeSplitIntoListParser(BaseOutputParser[List[str]]):
+    """Parse the output of an LLM call to a comma-separated list."""
+
+    @property
+    def lc_serializable(self) -> bool:
+        return True
+
+    def get_format_instructions(self) -> str:
+        return (
+            "Your response should be a list of comma separated values, "
+            "eg: `foo, bar, baz`"
+        )
+
+    def parse(self, text: str) -> List[str]:
+        """Parse the output of an LLM call."""
+        return text.strip().split(", ")
+
+
 def test_each(snapshot: SnapshotAssertion) -> None:
     prompt = (
         SystemMessagePromptTemplate.from_template("You are a nice assistant.")
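A quick usage sketch of the fake parser defined above (illustrative only, not part of the diff):

    parser = FakeSplitIntoListParser()
    parser.parse("first item, second item, third item")
    # -> ["first item", "second item", "third item"]
    parser.get_format_instructions()
    # -> "Your response should be a list of comma separated values, eg: `foo, bar, baz`"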
@@ -1098,7 +1115,7 @@ def test_each(snapshot: SnapshotAssertion) -> None:
     first_llm = FakeStreamingListLLM(responses=["first item, second item, third item"])
     second_llm = FakeStreamingListLLM(responses=["this", "is", "a", "test"])
 
-    chain = prompt | first_llm | CommaSeparatedListOutputParser() | second_llm.each()
+    chain = prompt | first_llm | FakeSplitIntoListParser() | second_llm.map()
 
     assert dumps(chain, pretty=True) == snapshot
     output = chain.invoke({"question": "What up"})
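Putting the pieces together, test_each exercises the renamed map() end to end: the first LLM produces a comma-separated string, the parser splits it into a list, and second_llm.map() invokes the second LLM once per element. A hedged walk-through, assuming FakeStreamingListLLM cycles through its responses on successive calls:

    output = chain.invoke({"question": "What up"})
    # prompt | first_llm        -> "first item, second item, third item"
    # FakeSplitIntoListParser() -> ["first item", "second item", "third item"]
    # second_llm.map()          -> one call per element, e.g. output == ["this", "is", "a"]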