mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-29 11:09:07 +00:00
Runnables: Add .map() method (#9445)
<!-- Thank you for contributing to LangChain! Replace this entire comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. These live in the docs/extras directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17, @rlancemartin. -->
This commit is contained in:
commit
64a958c85d
@ -217,6 +217,13 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
"""
|
||||
return RunnableBinding(bound=self, kwargs=kwargs)
|
||||
|
||||
def map(self) -> Runnable[List[Input], List[Output]]:
    """Lift this runnable to operate element-wise over a list.

    The returned runnable accepts a list of inputs and yields the list
    of corresponding outputs, applying this runnable to each element
    (delegation is handled by RunnableEach).
    """
    return RunnableEach(bound=self)
|
||||
|
||||
def with_fallbacks(
|
||||
self,
|
||||
fallbacks: Sequence[Runnable[Input, Output]],
|
||||
@ -1360,6 +1367,39 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
return self._call_with_config(self.func, input, config)
|
||||
|
||||
|
||||
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
    """A runnable that applies a wrapped runnable to every element of its input.

    Given a list input, each element is forwarded to ``bound``; the results
    are collected into a list in the same order.
    """

    # The runnable invoked once per input element.
    bound: Runnable[Input, Output]

    class Config:
        # ``bound`` is an arbitrary Runnable, not a pydantic model.
        arbitrary_types_allowed = True

    @property
    def lc_serializable(self) -> bool:
        """Whether this object participates in langchain serialization."""
        return True

    @property
    def lc_namespace(self) -> List[str]:
        """Serialization namespace derived from the defining module path."""
        return self.__class__.__module__.split(".")[:-1]

    def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]:
        """Bind kwargs onto the wrapped runnable, preserving the Each wrapper."""
        return RunnableEach(bound=self.bound.bind(**kwargs))

    def invoke(
        self, input: List[Input], config: Optional[RunnableConfig] = None
    ) -> List[Output]:
        """Run the wrapped runnable over every element via batch()."""
        return self.bound.batch(input, config)

    async def ainvoke(
        self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> List[Output]:
        """Async counterpart of invoke(), delegating to abatch()."""
        return await self.bound.abatch(input, config, **kwargs)
|
||||
|
||||
|
||||
class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||
"""
|
||||
A runnable that delegates calls to another runnable with a set of kwargs.
|
||||
|
@ -54,7 +54,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
|
||||
try:
|
||||
from openapi_schema_pydantic import OpenAPI
|
||||
except ImportError:
|
||||
OpenAPI = object
|
||||
OpenAPI = object # type: ignore
|
||||
|
||||
class OpenAPISpec(OpenAPI):
|
||||
"""OpenAPI Model that removes mis-formatted parts of the spec."""
|
||||
|
File diff suppressed because one or more lines are too long
@ -27,7 +27,7 @@ from langchain.schema.messages import (
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.schema.output_parser import StrOutputParser
|
||||
from langchain.schema.output_parser import BaseOutputParser, StrOutputParser
|
||||
from langchain.schema.retriever import BaseRetriever
|
||||
from langchain.schema.runnable import (
|
||||
RouterRunnable,
|
||||
@ -171,11 +171,21 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
||||
assert spy.call_args_list == [
|
||||
mocker.call(
|
||||
"hello",
|
||||
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
|
||||
dict(
|
||||
metadata={"key": "value"},
|
||||
tags=[],
|
||||
callbacks=None,
|
||||
_locals={},
|
||||
),
|
||||
),
|
||||
mocker.call(
|
||||
"wooorld",
|
||||
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
|
||||
dict(
|
||||
metadata={"key": "value"},
|
||||
tags=[],
|
||||
callbacks=None,
|
||||
_locals={},
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
@ -1076,3 +1086,53 @@ async def test_llm_with_fallbacks(
|
||||
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
|
||||
assert list(await runnable.ainvoke("hello")) == list("bar")
|
||||
assert dumps(runnable, pretty=True) == snapshot
|
||||
|
||||
|
||||
class FakeSplitIntoListParser(BaseOutputParser[List[str]]):
    """Parse the output of an LLM call to a comma-separated list."""

    @property
    def lc_serializable(self) -> bool:
        # Opt in to langchain serialization so dumps()/snapshots work.
        return True

    def get_format_instructions(self) -> str:
        instructions = (
            "Your response should be a list of comma separated values, "
            "eg: `foo, bar, baz`"
        )
        return instructions

    def parse(self, text: str) -> List[str]:
        """Parse the output of an LLM call."""
        cleaned = text.strip()
        return cleaned.split(", ")
|
||||
|
||||
|
||||
def test_each_simple() -> None:
    """Test that each() works with a simple runnable."""
    parser = FakeSplitIntoListParser()

    # Plain invoke splits one comma-separated string.
    single = parser.invoke("first item, second item")
    assert single == ["first item", "second item"]

    # map() lifts the parser to operate element-wise over a list.
    mapped = parser.map()
    assert mapped.invoke(["a, b", "c"]) == [["a", "b"], ["c"]]

    # map() composes: a doubly-mapped parser handles nested lists.
    doubly_mapped = mapped.map()
    assert doubly_mapped.invoke([["a, b", "c"], ["c, e"]]) == [
        [["a", "b"], ["c"]],
        [["c", "e"]],
    ]
|
||||
|
||||
|
||||
def test_each(snapshot: SnapshotAssertion) -> None:
    """Test map() at the tail of a chain, plus serialization of the chain."""
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    first_llm = FakeStreamingListLLM(responses=["first item, second item, third item"])
    parser = FakeSplitIntoListParser()
    second_llm = FakeStreamingListLLM(responses=["this", "is", "a", "test"])

    # The parser emits a list, so the final LLM must be mapped element-wise.
    chain = prompt | first_llm | parser | second_llm.map()

    assert dumps(chain, pretty=True) == snapshot
    assert chain.invoke({"question": "What up"}) == ["this", "is", "a"]

    # A fresh sub-chain continues consuming the fake LLM's response cycle.
    assert (parser | second_llm.map()).invoke("first item, second item") == [
        "test",
        "this",
    ]
|
||||
|
Loading…
Reference in New Issue
Block a user