mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-29 11:09:07 +00:00
Runnables: Add .map() method (#9445)
<!-- Thank you for contributing to LangChain! Replace this entire comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. These live is docs/extras directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17, @rlancemartin. -->
This commit is contained in:
commit
64a958c85d
@ -217,6 +217,13 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
"""
|
"""
|
||||||
return RunnableBinding(bound=self, kwargs=kwargs)
|
return RunnableBinding(bound=self, kwargs=kwargs)
|
||||||
|
|
||||||
|
def map(self) -> Runnable[List[Input], List[Output]]:
|
||||||
|
"""
|
||||||
|
Return a new Runnable that maps a list of inputs to a list of outputs,
|
||||||
|
by calling invoke() with each input.
|
||||||
|
"""
|
||||||
|
return RunnableEach(bound=self)
|
||||||
|
|
||||||
def with_fallbacks(
|
def with_fallbacks(
|
||||||
self,
|
self,
|
||||||
fallbacks: Sequence[Runnable[Input, Output]],
|
fallbacks: Sequence[Runnable[Input, Output]],
|
||||||
@ -1360,6 +1367,39 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
return self._call_with_config(self.func, input, config)
|
return self._call_with_config(self.func, input, config)
|
||||||
|
|
||||||
|
|
||||||
|
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
|
||||||
|
"""
|
||||||
|
A runnable that delegates calls to another runnable
|
||||||
|
with each element of the input sequence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
bound: Runnable[Input, Output]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_serializable(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_namespace(self) -> List[str]:
|
||||||
|
return self.__class__.__module__.split(".")[:-1]
|
||||||
|
|
||||||
|
def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]:
|
||||||
|
return RunnableEach(bound=self.bound.bind(**kwargs))
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
self, input: List[Input], config: Optional[RunnableConfig] = None
|
||||||
|
) -> List[Output]:
|
||||||
|
return self.bound.batch(input, config)
|
||||||
|
|
||||||
|
async def ainvoke(
|
||||||
|
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
|
||||||
|
) -> List[Output]:
|
||||||
|
return await self.bound.abatch(input, config, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class RunnableBinding(Serializable, Runnable[Input, Output]):
|
class RunnableBinding(Serializable, Runnable[Input, Output]):
|
||||||
"""
|
"""
|
||||||
A runnable that delegates calls to another runnable with a set of kwargs.
|
A runnable that delegates calls to another runnable with a set of kwargs.
|
||||||
|
@ -54,7 +54,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
|
|||||||
try:
|
try:
|
||||||
from openapi_schema_pydantic import OpenAPI
|
from openapi_schema_pydantic import OpenAPI
|
||||||
except ImportError:
|
except ImportError:
|
||||||
OpenAPI = object
|
OpenAPI = object # type: ignore
|
||||||
|
|
||||||
class OpenAPISpec(OpenAPI):
|
class OpenAPISpec(OpenAPI):
|
||||||
"""OpenAPI Model that removes mis-formatted parts of the spec."""
|
"""OpenAPI Model that removes mis-formatted parts of the spec."""
|
||||||
|
File diff suppressed because one or more lines are too long
@ -27,7 +27,7 @@ from langchain.schema.messages import (
|
|||||||
HumanMessage,
|
HumanMessage,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
)
|
)
|
||||||
from langchain.schema.output_parser import StrOutputParser
|
from langchain.schema.output_parser import BaseOutputParser, StrOutputParser
|
||||||
from langchain.schema.retriever import BaseRetriever
|
from langchain.schema.retriever import BaseRetriever
|
||||||
from langchain.schema.runnable import (
|
from langchain.schema.runnable import (
|
||||||
RouterRunnable,
|
RouterRunnable,
|
||||||
@ -171,11 +171,21 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
|
|||||||
assert spy.call_args_list == [
|
assert spy.call_args_list == [
|
||||||
mocker.call(
|
mocker.call(
|
||||||
"hello",
|
"hello",
|
||||||
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
|
dict(
|
||||||
|
metadata={"key": "value"},
|
||||||
|
tags=[],
|
||||||
|
callbacks=None,
|
||||||
|
_locals={},
|
||||||
|
),
|
||||||
),
|
),
|
||||||
mocker.call(
|
mocker.call(
|
||||||
"wooorld",
|
"wooorld",
|
||||||
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}),
|
dict(
|
||||||
|
metadata={"key": "value"},
|
||||||
|
tags=[],
|
||||||
|
callbacks=None,
|
||||||
|
_locals={},
|
||||||
|
),
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -1076,3 +1086,53 @@ async def test_llm_with_fallbacks(
|
|||||||
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
|
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
|
||||||
assert list(await runnable.ainvoke("hello")) == list("bar")
|
assert list(await runnable.ainvoke("hello")) == list("bar")
|
||||||
assert dumps(runnable, pretty=True) == snapshot
|
assert dumps(runnable, pretty=True) == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
class FakeSplitIntoListParser(BaseOutputParser[List[str]]):
|
||||||
|
"""Parse the output of an LLM call to a comma-separated list."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_serializable(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_format_instructions(self) -> str:
|
||||||
|
return (
|
||||||
|
"Your response should be a list of comma separated values, "
|
||||||
|
"eg: `foo, bar, baz`"
|
||||||
|
)
|
||||||
|
|
||||||
|
def parse(self, text: str) -> List[str]:
|
||||||
|
"""Parse the output of an LLM call."""
|
||||||
|
return text.strip().split(", ")
|
||||||
|
|
||||||
|
|
||||||
|
def test_each_simple() -> None:
|
||||||
|
"""Test that each() works with a simple runnable."""
|
||||||
|
parser = FakeSplitIntoListParser()
|
||||||
|
assert parser.invoke("first item, second item") == ["first item", "second item"]
|
||||||
|
assert parser.map().invoke(["a, b", "c"]) == [["a", "b"], ["c"]]
|
||||||
|
assert parser.map().map().invoke([["a, b", "c"], ["c, e"]]) == [
|
||||||
|
[["a", "b"], ["c"]],
|
||||||
|
[["c", "e"]],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_each(snapshot: SnapshotAssertion) -> None:
|
||||||
|
prompt = (
|
||||||
|
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
|
||||||
|
+ "{question}"
|
||||||
|
)
|
||||||
|
first_llm = FakeStreamingListLLM(responses=["first item, second item, third item"])
|
||||||
|
parser = FakeSplitIntoListParser()
|
||||||
|
second_llm = FakeStreamingListLLM(responses=["this", "is", "a", "test"])
|
||||||
|
|
||||||
|
chain = prompt | first_llm | parser | second_llm.map()
|
||||||
|
|
||||||
|
assert dumps(chain, pretty=True) == snapshot
|
||||||
|
output = chain.invoke({"question": "What up"})
|
||||||
|
assert output == ["this", "is", "a"]
|
||||||
|
|
||||||
|
assert (parser | second_llm.map()).invoke("first item, second item") == [
|
||||||
|
"test",
|
||||||
|
"this",
|
||||||
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user