mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-07 13:40:46 +00:00
Use a less specific return type for | on Runnables (#11762)
<!-- Thank you for contributing to LangChain! Replace this entire comment with: - **Description:** a description of the change, - **Issue:** the issue # it fixes (if applicable), - **Dependencies:** any dependencies required for this change, - **Tag maintainer:** for a quicker response, tag the relevant maintainer (see below), - **Twitter handle:** we announce bigger features on Twitter. If your PR gets announced, and you'd like a mention, we'll gladly shout you out! Please make sure your PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/extras` directory. If no one reviews your PR within a few days, please @-mention one of @baskaryan, @eyurtsev, @hwchase17. --> --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
6c5bb1b2e1
commit
4321d192ea
@ -4,7 +4,7 @@ from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS
|
|||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
from langchain.schema.output_parser import NoOpOutputParser
|
from langchain.schema.output_parser import NoOpOutputParser
|
||||||
from langchain.schema.prompt_template import BasePromptTemplate
|
from langchain.schema.prompt_template import BasePromptTemplate
|
||||||
from langchain.schema.runnable import RunnableParallel, RunnableSequence
|
from langchain.schema.runnable import Runnable, RunnableParallel
|
||||||
from langchain.utilities.sql_database import SQLDatabase
|
from langchain.utilities.sql_database import SQLDatabase
|
||||||
|
|
||||||
|
|
||||||
@ -30,7 +30,7 @@ def create_sql_query_chain(
|
|||||||
db: SQLDatabase,
|
db: SQLDatabase,
|
||||||
prompt: Optional[BasePromptTemplate] = None,
|
prompt: Optional[BasePromptTemplate] = None,
|
||||||
k: int = 5,
|
k: int = 5,
|
||||||
) -> RunnableSequence[Union[SQLInput, SQLInputWithTables], str]:
|
) -> Runnable[Union[SQLInput, SQLInputWithTables], str]:
|
||||||
"""Create a chain that generates SQL queries.
|
"""Create a chain that generates SQL queries.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -242,7 +242,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
Callable[[Iterator[Any]], Iterator[Other]],
|
Callable[[Iterator[Any]], Iterator[Other]],
|
||||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
||||||
],
|
],
|
||||||
) -> RunnableSequence[Input, Other]:
|
) -> Runnable[Input, Other]:
|
||||||
"""Compose this runnable with another object to create a RunnableSequence."""
|
"""Compose this runnable with another object to create a RunnableSequence."""
|
||||||
return RunnableSequence(first=self, last=coerce_to_runnable(other))
|
return RunnableSequence(first=self, last=coerce_to_runnable(other))
|
||||||
|
|
||||||
@ -254,7 +254,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
Callable[[Iterator[Other]], Iterator[Any]],
|
Callable[[Iterator[Other]], Iterator[Any]],
|
||||||
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
|
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
|
||||||
],
|
],
|
||||||
) -> RunnableSequence[Other, Output]:
|
) -> Runnable[Other, Output]:
|
||||||
"""Compose this runnable with another object to create a RunnableSequence."""
|
"""Compose this runnable with another object to create a RunnableSequence."""
|
||||||
return RunnableSequence(first=coerce_to_runnable(other), last=self)
|
return RunnableSequence(first=coerce_to_runnable(other), last=self)
|
||||||
|
|
||||||
@ -1064,7 +1064,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
Callable[[Iterator[Any]], Iterator[Other]],
|
Callable[[Iterator[Any]], Iterator[Other]],
|
||||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
||||||
],
|
],
|
||||||
) -> RunnableSequence[Input, Other]:
|
) -> Runnable[Input, Other]:
|
||||||
if isinstance(other, RunnableSequence):
|
if isinstance(other, RunnableSequence):
|
||||||
return RunnableSequence(
|
return RunnableSequence(
|
||||||
first=self.first,
|
first=self.first,
|
||||||
@ -1086,7 +1086,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
|||||||
Callable[[Iterator[Other]], Iterator[Any]],
|
Callable[[Iterator[Other]], Iterator[Any]],
|
||||||
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
|
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
|
||||||
],
|
],
|
||||||
) -> RunnableSequence[Other, Output]:
|
) -> Runnable[Other, Output]:
|
||||||
if isinstance(other, RunnableSequence):
|
if isinstance(other, RunnableSequence):
|
||||||
return RunnableSequence(
|
return RunnableSequence(
|
||||||
first=other.first,
|
first=other.first,
|
||||||
|
@ -7,6 +7,7 @@ from langchain.prompts import PromptTemplate
|
|||||||
from langchain.schema.runnable import (
|
from langchain.schema.runnable import (
|
||||||
GetLocalVar,
|
GetLocalVar,
|
||||||
PutLocalVar,
|
PutLocalVar,
|
||||||
|
Runnable,
|
||||||
RunnablePassthrough,
|
RunnablePassthrough,
|
||||||
RunnableSequence,
|
RunnableSequence,
|
||||||
)
|
)
|
||||||
@ -52,12 +53,12 @@ def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) ->
|
|||||||
|
|
||||||
|
|
||||||
def test_get_in_map() -> None:
|
def test_get_in_map() -> None:
|
||||||
runnable: RunnableSequence = PutLocalVar("input") | {"bar": GetLocalVar("input")}
|
runnable: Runnable = PutLocalVar("input") | {"bar": GetLocalVar("input")}
|
||||||
assert runnable.invoke("foo") == {"bar": "foo"}
|
assert runnable.invoke("foo") == {"bar": "foo"}
|
||||||
|
|
||||||
|
|
||||||
def test_put_in_map() -> None:
|
def test_put_in_map() -> None:
|
||||||
runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input")
|
runnable: Runnable = {"bar": PutLocalVar("input")} | GetLocalVar("input")
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
runnable.invoke("foo")
|
runnable.invoke("foo")
|
||||||
|
|
||||||
|
@ -1978,7 +1978,7 @@ def test_combining_sequences(
|
|||||||
lambda x: {"question": x[0] + x[1]}
|
lambda x: {"question": x[0] + x[1]}
|
||||||
)
|
)
|
||||||
|
|
||||||
chain2 = input_formatter | prompt2 | chat2 | parser2
|
chain2 = cast(RunnableSequence, input_formatter | prompt2 | chat2 | parser2)
|
||||||
|
|
||||||
assert isinstance(chain, RunnableSequence)
|
assert isinstance(chain, RunnableSequence)
|
||||||
assert chain2.first == input_formatter
|
assert chain2.first == input_formatter
|
||||||
@ -1987,7 +1987,7 @@ def test_combining_sequences(
|
|||||||
if sys.version_info >= (3, 9):
|
if sys.version_info >= (3, 9):
|
||||||
assert dumps(chain2, pretty=True) == snapshot
|
assert dumps(chain2, pretty=True) == snapshot
|
||||||
|
|
||||||
combined_chain = chain | chain2
|
combined_chain = cast(RunnableSequence, chain | chain2)
|
||||||
|
|
||||||
assert combined_chain.first == prompt
|
assert combined_chain.first == prompt
|
||||||
assert combined_chain.middle == [
|
assert combined_chain.middle == [
|
||||||
@ -2972,7 +2972,7 @@ def llm_with_multi_fallbacks() -> RunnableWithFallbacks:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def llm_chain_with_fallbacks() -> RunnableSequence:
|
def llm_chain_with_fallbacks() -> Runnable:
|
||||||
error_llm = FakeListLLM(responses=["foo"], i=1)
|
error_llm = FakeListLLM(responses=["foo"], i=1)
|
||||||
pass_llm = FakeListLLM(responses=["bar"])
|
pass_llm = FakeListLLM(responses=["bar"])
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user