Runnables: Add .map() method (#9445)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. These live is docs/extras
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17, @rlancemartin.
 -->
This commit is contained in:
Nuno Campos 2023-08-23 19:54:12 +01:00 committed by GitHub
commit 64a958c85d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 239 additions and 4 deletions

View File

@ -217,6 +217,13 @@ class Runnable(Generic[Input, Output], ABC):
""" """
return RunnableBinding(bound=self, kwargs=kwargs) return RunnableBinding(bound=self, kwargs=kwargs)
def map(self) -> Runnable[List[Input], List[Output]]:
"""
Return a new Runnable that maps a list of inputs to a list of outputs,
by calling invoke() with each input.
"""
return RunnableEach(bound=self)
def with_fallbacks( def with_fallbacks(
self, self,
fallbacks: Sequence[Runnable[Input, Output]], fallbacks: Sequence[Runnable[Input, Output]],
@ -1360,6 +1367,39 @@ class RunnableLambda(Runnable[Input, Output]):
return self._call_with_config(self.func, input, config) return self._call_with_config(self.func, input, config)
class RunnableEach(Serializable, Runnable[List[Input], List[Output]]):
"""
A runnable that delegates calls to another runnable
with each element of the input sequence.
"""
bound: Runnable[Input, Output]
class Config:
arbitrary_types_allowed = True
@property
def lc_serializable(self) -> bool:
return True
@property
def lc_namespace(self) -> List[str]:
return self.__class__.__module__.split(".")[:-1]
def bind(self, **kwargs: Any) -> RunnableEach[Input, Output]:
return RunnableEach(bound=self.bound.bind(**kwargs))
def invoke(
self, input: List[Input], config: Optional[RunnableConfig] = None
) -> List[Output]:
return self.bound.batch(input, config)
async def ainvoke(
self, input: List[Input], config: Optional[RunnableConfig] = None, **kwargs: Any
) -> List[Output]:
return await self.bound.abatch(input, config, **kwargs)
class RunnableBinding(Serializable, Runnable[Input, Output]): class RunnableBinding(Serializable, Runnable[Input, Output]):
""" """
A runnable that delegates calls to another runnable with a set of kwargs. A runnable that delegates calls to another runnable with a set of kwargs.

View File

@ -54,7 +54,7 @@ if _PYDANTIC_MAJOR_VERSION == 1:
try: try:
from openapi_schema_pydantic import OpenAPI from openapi_schema_pydantic import OpenAPI
except ImportError: except ImportError:
OpenAPI = object OpenAPI = object # type: ignore
class OpenAPISpec(OpenAPI): class OpenAPISpec(OpenAPI):
"""OpenAPI Model that removes mis-formatted parts of the spec.""" """OpenAPI Model that removes mis-formatted parts of the spec."""

File diff suppressed because one or more lines are too long

View File

@ -27,7 +27,7 @@ from langchain.schema.messages import (
HumanMessage, HumanMessage,
SystemMessage, SystemMessage,
) )
from langchain.schema.output_parser import StrOutputParser from langchain.schema.output_parser import BaseOutputParser, StrOutputParser
from langchain.schema.retriever import BaseRetriever from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import ( from langchain.schema.runnable import (
RouterRunnable, RouterRunnable,
@ -171,11 +171,21 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None:
assert spy.call_args_list == [ assert spy.call_args_list == [
mocker.call( mocker.call(
"hello", "hello",
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), dict(
metadata={"key": "value"},
tags=[],
callbacks=None,
_locals={},
),
), ),
mocker.call( mocker.call(
"wooorld", "wooorld",
dict(metadata={"key": "value"}, tags=[], callbacks=None, _locals={}), dict(
metadata={"key": "value"},
tags=[],
callbacks=None,
_locals={},
),
), ),
] ]
@ -1076,3 +1086,53 @@ async def test_llm_with_fallbacks(
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3 assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(await runnable.ainvoke("hello")) == list("bar") assert list(await runnable.ainvoke("hello")) == list("bar")
assert dumps(runnable, pretty=True) == snapshot assert dumps(runnable, pretty=True) == snapshot
class FakeSplitIntoListParser(BaseOutputParser[List[str]]):
"""Parse the output of an LLM call to a comma-separated list."""
@property
def lc_serializable(self) -> bool:
return True
def get_format_instructions(self) -> str:
return (
"Your response should be a list of comma separated values, "
"eg: `foo, bar, baz`"
)
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
return text.strip().split(", ")
def test_each_simple() -> None:
"""Test that each() works with a simple runnable."""
parser = FakeSplitIntoListParser()
assert parser.invoke("first item, second item") == ["first item", "second item"]
assert parser.map().invoke(["a, b", "c"]) == [["a", "b"], ["c"]]
assert parser.map().map().invoke([["a, b", "c"], ["c, e"]]) == [
[["a", "b"], ["c"]],
[["c", "e"]],
]
def test_each(snapshot: SnapshotAssertion) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
first_llm = FakeStreamingListLLM(responses=["first item, second item, third item"])
parser = FakeSplitIntoListParser()
second_llm = FakeStreamingListLLM(responses=["this", "is", "a", "test"])
chain = prompt | first_llm | parser | second_llm.map()
assert dumps(chain, pretty=True) == snapshot
output = chain.invoke({"question": "What up"})
assert output == ["this", "is", "a"]
assert (parser | second_llm.map()).invoke("first item, second item") == [
"test",
"this",
]