mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 21:50:25 +00:00
Use a less specific return type for | on Runnables (#11762)
<!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. --> --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
6c5bb1b2e1
commit
4321d192ea
@ -4,7 +4,7 @@ from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
from langchain.schema.output_parser import NoOpOutputParser
|
||||
from langchain.schema.prompt_template import BasePromptTemplate
|
||||
from langchain.schema.runnable import RunnableParallel, RunnableSequence
|
||||
from langchain.schema.runnable import Runnable, RunnableParallel
|
||||
from langchain.utilities.sql_database import SQLDatabase
|
||||
|
||||
|
||||
@ -30,7 +30,7 @@ def create_sql_query_chain(
|
||||
db: SQLDatabase,
|
||||
prompt: Optional[BasePromptTemplate] = None,
|
||||
k: int = 5,
|
||||
) -> RunnableSequence[Union[SQLInput, SQLInputWithTables], str]:
|
||||
) -> Runnable[Union[SQLInput, SQLInputWithTables], str]:
|
||||
"""Create a chain that generates SQL queries.
|
||||
|
||||
Args:
|
||||
|
@ -242,7 +242,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
Callable[[Iterator[Any]], Iterator[Other]],
|
||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
||||
],
|
||||
) -> RunnableSequence[Input, Other]:
|
||||
) -> Runnable[Input, Other]:
|
||||
"""Compose this runnable with another object to create a RunnableSequence."""
|
||||
return RunnableSequence(first=self, last=coerce_to_runnable(other))
|
||||
|
||||
@ -254,7 +254,7 @@ class Runnable(Generic[Input, Output], ABC):
|
||||
Callable[[Iterator[Other]], Iterator[Any]],
|
||||
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
|
||||
],
|
||||
) -> RunnableSequence[Other, Output]:
|
||||
) -> Runnable[Other, Output]:
|
||||
"""Compose this runnable with another object to create a RunnableSequence."""
|
||||
return RunnableSequence(first=coerce_to_runnable(other), last=self)
|
||||
|
||||
@ -1064,7 +1064,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
Callable[[Iterator[Any]], Iterator[Other]],
|
||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
||||
],
|
||||
) -> RunnableSequence[Input, Other]:
|
||||
) -> Runnable[Input, Other]:
|
||||
if isinstance(other, RunnableSequence):
|
||||
return RunnableSequence(
|
||||
first=self.first,
|
||||
@ -1086,7 +1086,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
Callable[[Iterator[Other]], Iterator[Any]],
|
||||
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
|
||||
],
|
||||
) -> RunnableSequence[Other, Output]:
|
||||
) -> Runnable[Other, Output]:
|
||||
if isinstance(other, RunnableSequence):
|
||||
return RunnableSequence(
|
||||
first=other.first,
|
||||
|
@ -7,6 +7,7 @@ from langchain.prompts import PromptTemplate
|
||||
from langchain.schema.runnable import (
|
||||
GetLocalVar,
|
||||
PutLocalVar,
|
||||
Runnable,
|
||||
RunnablePassthrough,
|
||||
RunnableSequence,
|
||||
)
|
||||
@ -52,12 +53,12 @@ def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) ->
|
||||
|
||||
|
||||
def test_get_in_map() -> None:
|
||||
runnable: RunnableSequence = PutLocalVar("input") | {"bar": GetLocalVar("input")}
|
||||
runnable: Runnable = PutLocalVar("input") | {"bar": GetLocalVar("input")}
|
||||
assert runnable.invoke("foo") == {"bar": "foo"}
|
||||
|
||||
|
||||
def test_put_in_map() -> None:
|
||||
runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input")
|
||||
runnable: Runnable = {"bar": PutLocalVar("input")} | GetLocalVar("input")
|
||||
with pytest.raises(KeyError):
|
||||
runnable.invoke("foo")
|
||||
|
||||
|
@ -1978,7 +1978,7 @@ def test_combining_sequences(
|
||||
lambda x: {"question": x[0] + x[1]}
|
||||
)
|
||||
|
||||
chain2 = input_formatter | prompt2 | chat2 | parser2
|
||||
chain2 = cast(RunnableSequence, input_formatter | prompt2 | chat2 | parser2)
|
||||
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
assert chain2.first == input_formatter
|
||||
@ -1987,7 +1987,7 @@ def test_combining_sequences(
|
||||
if sys.version_info >= (3, 9):
|
||||
assert dumps(chain2, pretty=True) == snapshot
|
||||
|
||||
combined_chain = chain | chain2
|
||||
combined_chain = cast(RunnableSequence, chain | chain2)
|
||||
|
||||
assert combined_chain.first == prompt
|
||||
assert combined_chain.middle == [
|
||||
@ -2972,7 +2972,7 @@ def llm_with_multi_fallbacks() -> RunnableWithFallbacks:
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def llm_chain_with_fallbacks() -> RunnableSequence:
|
||||
def llm_chain_with_fallbacks() -> Runnable:
|
||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||
pass_llm = FakeListLLM(responses=["bar"])
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user