core: Fix some missing generic types

This commit is contained in:
cbornet 2025-07-03 14:27:22 +02:00
parent 2fdccd789c
commit 8372d41f70
30 changed files with 238 additions and 233 deletions

View File

@ -125,7 +125,7 @@ class LangSmithLoader(BaseLoader):
yield Document(content_str, metadata=metadata)
def _stringify(x: Union[str, dict]) -> str:
def _stringify(x: Union[str, dict[str, Any]]) -> str:
if isinstance(x, str):
return x
try:

View File

@ -202,13 +202,17 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
MessageLikeRepresentation = Union[
BaseMessage, list[str], tuple[str, str], str, dict[str, Any]
BaseMessage,
list[str],
tuple[str, Union[str, list[Union[str, dict[str, Any]]]]],
str,
dict[str, Any],
]
def _create_message_from_message_type(
message_type: str,
content: str,
content: Union[str, list[Union[str, dict[str, Any]]]],
name: Optional[str] = None,
tool_call_id: Optional[str] = None,
tool_calls: Optional[list[dict[str, Any]]] = None,
@ -218,13 +222,13 @@ def _create_message_from_message_type(
"""Create a message from a message type and content string.
Args:
message_type: (str) the type of the message (e.g., "human", "ai", etc.).
content: (str) the content string.
name: (str) the name of the message. Default is None.
tool_call_id: (str) the tool call id. Default is None.
tool_calls: (list[dict[str, Any]]) the tool calls. Default is None.
id: (str) the id of the message. Default is None.
additional_kwargs: (dict[str, Any]) additional keyword arguments.
message_type: the type of the message (e.g., "human", "ai", etc.).
content: the content string.
name: the name of the message. Default is None.
tool_call_id: the tool call id. Default is None.
tool_calls: the tool calls. Default is None.
id: the id of the message. Default is None.
**additional_kwargs: additional keyword arguments.
Returns:
a message of the appropriate type.
@ -1004,12 +1008,13 @@ def convert_to_openai_messages(
oai_messages: list = []
if is_single := isinstance(messages, (BaseMessage, dict, str)):
messages = [messages]
messages_ = (
[messages]
if (is_single := isinstance(messages, (BaseMessage, dict, str, tuple)))
else messages
)
messages = convert_to_messages(messages)
for i, message in enumerate(messages):
for i, message in enumerate(convert_to_messages(messages_)):
oai_msg: dict = {"role": _get_message_openai_role(message)}
tool_messages: list = []
content: Union[str, list[dict]]

View File

@ -225,7 +225,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
@model_validator(mode="before")
@classmethod
def validate_schema(cls, values: dict) -> Any:
def validate_schema(cls, values: dict[str, Any]) -> Any:
"""Validate the pydantic schema.
Args:

View File

@ -3,9 +3,9 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Sequence
from pathlib import Path
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Optional,
@ -51,9 +51,6 @@ from langchain_core.prompts.string import (
from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env
if TYPE_CHECKING:
from collections.abc import Sequence
class MessagesPlaceholder(BaseMessagePromptTemplate):
"""Prompt template that assumes variable is already list of messages.
@ -772,7 +769,7 @@ MessageLikeRepresentation = Union[
MessageLike,
tuple[
Union[str, type],
Union[str, list[dict], list[object]],
Union[str, Sequence[dict], Sequence[object]],
],
str,
dict[str, Any],
@ -1435,11 +1432,13 @@ def _convert_to_message_template(
f" Got: {message}"
)
raise ValueError(msg)
message = (message["role"], message["content"])
if len(message) != 2:
msg = f"Expected 2-tuple of (role, template), got {message}"
raise ValueError(msg)
message_type_str, template = message
message_type_str = message["role"]
template = message["content"]
else:
if len(message) != 2:
msg = f"Expected 2-tuple of (role, template), got {message}"
raise ValueError(msg)
message_type_str, template = message
if isinstance(message_type_str, str):
message_ = _create_template_from_message_type(
message_type_str, template, template_format=template_format

View File

@ -100,7 +100,7 @@ class EvaluatorCallbackHandler(BaseTracer):
)
else:
self.executor = None
self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
self.futures: weakref.WeakSet[Future[None]] = weakref.WeakSet()
self.skip_unfinished = skip_unfinished
self.project_name = project_name
self.logged_eval_results: dict[tuple[str, str], list[EvaluationResult]] = {}

View File

@ -10,7 +10,7 @@ class DummyExampleSelector(BaseExampleSelector):
def add_example(self, example: dict[str, str]) -> None:
self.example = example
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
def select_examples(self, input_variables: dict[str, str]) -> list[dict[str, str]]:
return [input_variables]

View File

@ -276,7 +276,9 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
self.on_retriever_error_common()
# Overriding since BaseModel has __deepcopy__ method as well
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore[override]
def __deepcopy__(
self, memo: Union[dict[int, Any], None] = None
) -> "FakeCallbackHandler":
return self
@ -426,5 +428,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
self.on_text_common()
# Overriding since BaseModel has __deepcopy__ method as well
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore[override]
def __deepcopy__(
self, memo: Union[dict[int, Any], None] = None
) -> "FakeAsyncCallbackHandler":
return self

View File

@ -41,7 +41,7 @@ if TYPE_CHECKING:
@pytest.fixture
def messages() -> list:
def messages() -> list[BaseMessage]:
return [
SystemMessage(content="You are a test user."),
HumanMessage(content="Hello, I am a test user."),
@ -49,14 +49,14 @@ def messages() -> list:
@pytest.fixture
def messages_2() -> list:
def messages_2() -> list[BaseMessage]:
return [
SystemMessage(content="You are a test user."),
HumanMessage(content="Hello, I not a test user."),
]
def test_batch_size(messages: list, messages_2: list) -> None:
def test_batch_size(messages: list[BaseMessage], messages_2: list[BaseMessage]) -> None:
# The base endpoint doesn't support native batching,
# so we expect batch_size to always be 1
llm = FakeListChatModel(responses=[str(i) for i in range(100)])
@ -80,7 +80,9 @@ def test_batch_size(messages: list, messages_2: list) -> None:
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1
async def test_async_batch_size(messages: list, messages_2: list) -> None:
async def test_async_batch_size(
messages: list[BaseMessage], messages_2: list[BaseMessage]
) -> None:
llm = FakeListChatModel(responses=[str(i) for i in range(100)])
# The base endpoint doesn't support native batching,
# so we expect batch_size to always be 1
@ -262,7 +264,7 @@ async def test_astream_implementation_uses_astream() -> None:
class FakeTracer(BaseTracer):
def __init__(self) -> None:
super().__init__()
self.traced_run_ids: list = []
self.traced_run_ids: list[uuid.UUID] = []
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
@ -415,7 +417,7 @@ async def test_disable_streaming_no_streaming_model_async(
class FakeChatModelStartTracer(FakeTracer):
def __init__(self) -> None:
super().__init__()
self.messages: list = []
self.messages: list[list[list[BaseMessage]]] = []
def on_chat_model_start(self, *args: Any, **kwargs: Any) -> Run:
_, messages = args

View File

@ -17,6 +17,7 @@ from langchain_core.messages import (
ToolMessage,
)
from langchain_core.messages.utils import (
MessageLikeRepresentation,
convert_to_messages,
convert_to_openai_messages,
count_tokens_approximately,
@ -153,7 +154,7 @@ def test_merge_messages_tool_messages() -> None:
{"include_names": ["blah", "blur"], "exclude_types": [SystemMessage]},
],
)
def test_filter_message(filters: dict) -> None:
def test_filter_message(filters: dict[str, Any]) -> None:
messages = [
SystemMessage("foo", name="blah", id="1"),
HumanMessage("bar", name="blur", id="2"),
@ -673,7 +674,7 @@ class FakeTokenCountingModel(FakeChatModel):
def test_convert_to_messages() -> None:
message_like: list = [
message_like: list[MessageLikeRepresentation] = [
# BaseMessage
SystemMessage("1"),
SystemMessage("1.1", additional_kwargs={"__openai_role__": "developer"}),
@ -1179,7 +1180,7 @@ def test_convert_to_openai_messages_mixed_content_types() -> None:
def test_convert_to_openai_messages_developer() -> None:
messages: list = [
messages: list[MessageLikeRepresentation] = [
SystemMessage("a", additional_kwargs={"__openai_role__": "developer"}),
{"role": "developer", "content": "a"},
]

View File

@ -17,7 +17,7 @@ from langchain_core.output_parsers.openai_tools import (
)
from langchain_core.outputs import ChatGeneration
STREAMED_MESSAGES: list = [
STREAMED_MESSAGES = [
AIMessageChunk(content=""),
AIMessageChunk(
content="",
@ -331,7 +331,7 @@ for message in STREAMED_MESSAGES:
STREAMED_MESSAGES_WITH_TOOL_CALLS.append(message)
EXPECTED_STREAMED_JSON = [
EXPECTED_STREAMED_JSON: list[dict[str, Any]] = [
{},
{"names": ["suz"]},
{"names": ["suzy"]},
@ -392,7 +392,7 @@ def test_partial_json_output_parser(*, use_tool_calls: bool) -> None:
chain = input_iter | JsonOutputToolsParser()
actual = list(chain.stream(None))
expected: list = [[]] + [
expected: list[list[dict[str, Any]]] = [[]] + [
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected
@ -404,7 +404,7 @@ async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None
chain = input_iter | JsonOutputToolsParser()
actual = [p async for p in chain.astream(None)]
expected: list = [[]] + [
expected: list[list[dict[str, Any]]] = [[]] + [
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected
@ -416,7 +416,7 @@ def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None:
chain = input_iter | JsonOutputToolsParser(return_id=True)
actual = list(chain.stream(None))
expected: list = [[]] + [
expected: list[list[dict[str, Any]]] = [[]] + [
[
{
"type": "NameCollector",
@ -435,7 +435,9 @@ def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None:
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
actual = list(chain.stream(None))
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
expected: list[list[dict[str, Any]]] = [[]] + [
[chunk] for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected
@ -446,7 +448,9 @@ async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) ->
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")
actual = [p async for p in chain.astream(None)]
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
expected: list[list[dict[str, Any]]] = [[]] + [
[chunk] for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected

View File

@ -141,9 +141,7 @@ DEF_EXPECTED_RESULT = TestModel(
def test_pydantic_output_parser() -> None:
"""Test PydanticOutputParser."""
pydantic_parser: PydanticOutputParser = PydanticOutputParser(
pydantic_object=TestModel
)
pydantic_parser = PydanticOutputParser(pydantic_object=TestModel)
result = pydantic_parser.parse(DEF_RESULT)
assert result == DEF_EXPECTED_RESULT
@ -152,9 +150,7 @@ def test_pydantic_output_parser() -> None:
def test_pydantic_output_parser_fail() -> None:
"""Test PydanticOutputParser where completion result fails schema validation."""
pydantic_parser: PydanticOutputParser = PydanticOutputParser(
pydantic_object=TestModel
)
pydantic_parser = PydanticOutputParser(pydantic_object=TestModel)
with pytest.raises(
OutputParserException, match="Failed to parse TestModel from completion"

View File

@ -1,4 +1,4 @@
from typing import Union
from typing import Any, Union
import pytest
@ -19,14 +19,14 @@ from langchain_core.outputs import ChatGeneration
],
],
)
def test_msg_with_text(content: Union[str, list]) -> None:
def test_msg_with_text(content: Union[str, list[Union[str, dict[str, Any]]]]) -> None:
expected = "foo"
actual = ChatGeneration(message=AIMessage(content=content)).text
assert actual == expected
@pytest.mark.parametrize("content", [[], [{"tool_use": {}, "type": "tool_use"}]])
def test_msg_no_text(content: Union[str, list]) -> None:
def test_msg_no_text(content: Union[str, list[Union[str, dict[str, Any]]]]) -> None:
expected = ""
actual = ChatGeneration(message=AIMessage(content=content)).text
assert actual == expected

View File

@ -1,7 +1,7 @@
import re
import warnings
from pathlib import Path
from typing import Any, Union, cast
from typing import Any, Union
import pytest
from packaging import version
@ -121,11 +121,10 @@ def test_create_system_message_prompt_template_from_template_partial() -> None:
History:
{history}
"""
json_prompt_instructions: dict = {}
graph_analyst_template = SystemMessagePromptTemplate.from_template(
template=graph_creator_content,
input_variables=["history"],
partial_variables={"instructions": json_prompt_instructions},
partial_variables={"instructions": {}},
)
assert graph_analyst_template.format(history="history") == SystemMessage(
content="\n Your instructions are:\n {}\n History:\n history\n "
@ -973,46 +972,43 @@ def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
("system", "You are an AI assistant named {name}."),
("system", [{"text": "You are an AI assistant named {name}."}]),
SystemMessagePromptTemplate.from_template("you are {foo}"),
cast(
"tuple",
(
"human",
[
"hello",
{"text": "What's in this image?"},
{"type": "text", "text": "What's in this image?"},
{
"type": "text",
"text": "What's in this image?",
"cache_control": {"type": "{foo}"},
(
"human",
[
"hello",
{"text": "What's in this image?"},
{"type": "text", "text": "What's in this image?"},
{
"type": "text",
"text": "What's in this image?",
"cache_control": {"type": "{foo}"},
},
{
"type": "image_url",
"image_url": "data:image/jpeg;base64,{my_image}",
},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,{my_image}"},
},
{"type": "image_url", "image_url": "{my_other_image}"},
{
"type": "image_url",
"image_url": {
"url": "{my_other_image}",
"detail": "medium",
},
{
"type": "image_url",
"image_url": "data:image/jpeg;base64,{my_image}",
},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,{my_image}"},
},
{"type": "image_url", "image_url": "{my_other_image}"},
{
"type": "image_url",
"image_url": {
"url": "{my_other_image}",
"detail": "medium",
},
},
{
"type": "image_url",
"image_url": {"url": "https://www.langchain.com/image.png"},
},
{
"type": "image_url",
"image_url": {"url": ""},
},
{"image_url": {"url": ""}},
],
),
},
{
"type": "image_url",
"image_url": {"url": "https://www.langchain.com/image.png"},
},
{
"type": "image_url",
"image_url": {"url": ""},
},
{"image_url": {"url": ""}},
],
),
("placeholder", "{chat_history}"),
MessagesPlaceholder("more_history", optional=False),
@ -1179,7 +1175,7 @@ def test_chat_prompt_template_data_prompt_from_message(
cache_control_placeholder: str,
source_data_placeholder: str,
) -> None:
prompt: dict = {
prompt: dict[str, Any] = {
"type": "image",
"source_type": "base64",
"data": f"{source_data_placeholder}",

View File

@ -385,7 +385,7 @@ class AsIsSelector(BaseExampleSelector):
raise NotImplementedError
@override
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
def select_examples(self, input_variables: dict[str, str]) -> list[dict[str, str]]:
return list(self.examples)
@ -480,11 +480,13 @@ class AsyncAsIsSelector(BaseExampleSelector):
def add_example(self, example: dict[str, str]) -> Any:
raise NotImplementedError
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
def select_examples(self, input_variables: dict[str, str]) -> list[dict[str, str]]:
raise NotImplementedError
@override
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
async def aselect_examples(
self, input_variables: dict[str, str]
) -> list[dict[str, str]]:
return list(self.examples)

View File

@ -15,7 +15,7 @@ EXAMPLE_DIR = (Path(__file__).parent.parent / "examples").absolute()
@contextmanager
def change_directory(dir_path: Path) -> Iterator:
def change_directory(dir_path: Path) -> Iterator[None]:
"""Change the working directory to the right folder."""
origin = Path().absolute()
try:

View File

@ -265,7 +265,7 @@ def test_prompt_from_template_with_partial_variables() -> None:
def test_prompt_missing_input_variables() -> None:
"""Test error is raised when input variables are not provided."""
template = "This is a {foo} test."
input_variables: list = []
input_variables: list[str] = []
with pytest.raises(
ValueError,
match=re.escape("check for mismatched or missing input parameters from []"),
@ -509,7 +509,7 @@ Your variable again: {{ foo }}
def test_prompt_jinja2_missing_input_variables() -> None:
"""Test error is raised when input variables are not provided."""
template = "This is a {{ foo }} test."
input_variables: list = []
input_variables: list[str] = []
with pytest.warns(UserWarning, match="Missing variables: {'foo'}"):
PromptTemplate(
input_variables=input_variables,

View File

@ -14,11 +14,15 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
def _fake_runnable(
_: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_kwargs: Any
) -> Union[BaseModel, dict]:
_: Any,
*,
schema: Union[dict[str, Any], type[BaseModel]],
value: Any = 42,
**_kwargs: Any,
) -> Union[BaseModel, dict[str, Any]]:
if isclass(schema) and is_basemodel_subclass(schema):
return schema(name="yo", value=value)
params = cast("dict", schema)["parameters"]
params = cast("dict[str, Any]", schema)["parameters"]
return {k: 1 if k != "value" else value for k, v in params.items()}

View File

@ -6,8 +6,7 @@ from typing import Any
import pytest
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.runnables.base import Runnable
from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda
@pytest.mark.asyncio
@ -97,7 +96,7 @@ def test_batch_concurrency() -> None:
return f"Completed {x}"
runnable: Runnable = RunnableLambda(tracked_function)
runnable = RunnableLambda(tracked_function)
num_tasks = 10
max_concurrency = 3
@ -129,7 +128,7 @@ def test_batch_as_completed_concurrency() -> None:
return f"Completed {x}"
runnable: Runnable = RunnableLambda(tracked_function)
runnable = RunnableLambda(tracked_function)
num_tasks = 10
max_concurrency = 3

View File

@ -26,7 +26,7 @@ from langchain_core.tracers.stdout import ConsoleCallbackHandler
def test_ensure_config() -> None:
run_id = str(uuid.uuid4())
arg: dict = {
arg: dict[str, Any] = {
"something": "else",
"metadata": {"foo": "bar"},
"configurable": {"baz": "qux"},
@ -147,7 +147,7 @@ async def test_merge_config_callbacks() -> None:
def test_config_arbitrary_keys() -> None:
base: RunnablePassthrough[Any] = RunnablePassthrough()
bound = base.with_config(my_custom_key="my custom value")
config = cast("RunnableBinding", bound).config
config = cast("RunnableBinding[Any, Any]", bound).config
assert config.get("my_custom_key") == "my custom value"

View File

@ -332,7 +332,8 @@ test_cases = [
@pytest.mark.parametrize(("runnable", "cases"), test_cases)
def test_context_runnables(
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
runnable: Union[Runnable[Any, Any], Callable[[], Runnable[Any, Any]]],
cases: list[_TestCase],
) -> None:
runnable = runnable if isinstance(runnable, Runnable) else runnable()
assert runnable.invoke(cases[0].input) == cases[0].output
@ -344,7 +345,8 @@ def test_context_runnables(
@pytest.mark.parametrize(("runnable", "cases"), test_cases)
async def test_context_runnables_async(
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
runnable: Union[Runnable[Any, Any], Callable[[], Runnable[Any, Any]]],
cases: list[_TestCase],
) -> None:
runnable = runnable if isinstance(runnable, Runnable) else runnable()
assert await runnable.ainvoke(cases[1].input) == cases[1].output

View File

@ -34,7 +34,7 @@ from langchain_core.tools import BaseTool
@pytest.fixture
def llm() -> RunnableWithFallbacks:
def llm() -> RunnableWithFallbacks[Any, Any]:
error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"])
@ -42,7 +42,7 @@ def llm() -> RunnableWithFallbacks:
@pytest.fixture
def llm_multi() -> RunnableWithFallbacks:
def llm_multi() -> RunnableWithFallbacks[Any, Any]:
error_llm = FakeListLLM(responses=["foo"], i=1)
error_llm_2 = FakeListLLM(responses=["baz"], i=1)
pass_llm = FakeListLLM(responses=["bar"])
@ -51,7 +51,7 @@ def llm_multi() -> RunnableWithFallbacks:
@pytest.fixture
def chain() -> Runnable:
def chain() -> Runnable[Any, str]:
error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"])
@ -61,18 +61,18 @@ def chain() -> Runnable:
)
def _raise_error(_: dict) -> str:
def _raise_error(_: dict[str, Any]) -> str:
raise ValueError
def _dont_raise_error(inputs: dict) -> str:
def _dont_raise_error(inputs: dict[str, Any]) -> str:
if "exception" in inputs:
return "bar"
raise ValueError
@pytest.fixture
def chain_pass_exceptions() -> Runnable:
def chain_pass_exceptions() -> Runnable[Any, str]:
fallback = RunnableLambda(_dont_raise_error)
return {"text": RunnablePassthrough()} | RunnableLambda(
_raise_error
@ -80,13 +80,13 @@ def chain_pass_exceptions() -> Runnable:
@pytest.mark.parametrize(
"runnable",
"runnable_name",
["llm", "llm_multi", "chain", "chain_pass_exceptions"],
)
def test_fallbacks(
runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
runnable_name: str, request: Any, snapshot: SnapshotAssertion
) -> None:
runnable = request.getfixturevalue(runnable)
runnable: Runnable[Any, Any] = request.getfixturevalue(runnable_name)
assert runnable.invoke("hello") == "bar"
assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(runnable.stream("hello")) == ["bar"]
@ -94,17 +94,17 @@ def test_fallbacks(
@pytest.mark.parametrize(
"runnable",
"runnable_name",
["llm", "llm_multi", "chain", "chain_pass_exceptions"],
)
async def test_fallbacks_async(runnable: RunnableWithFallbacks, request: Any) -> None:
runnable = request.getfixturevalue(runnable)
async def test_fallbacks_async(runnable_name: str, request: Any) -> None:
runnable: Runnable[Any, Any] = request.getfixturevalue(runnable_name)
assert await runnable.ainvoke("hello") == "bar"
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(await runnable.ainvoke("hello")) == list("bar")
def _runnable(inputs: dict) -> str:
def _runnable(inputs: dict[str, Any]) -> str:
if inputs["text"] == "foo":
return "first"
if "exception" not in inputs:
@ -117,7 +117,7 @@ def _runnable(inputs: dict) -> str:
return "third"
def _assert_potential_error(actual: list, expected: list) -> None:
def _assert_potential_error(actual: list[Any], expected: list[Any]) -> None:
for x, y in zip(actual, expected):
if isinstance(x, Exception):
assert isinstance(y, type(x))
@ -260,17 +260,17 @@ async def test_abatch() -> None:
_assert_potential_error(actual, expected)
def _generate(_: Iterator) -> Iterator[str]:
def _generate(_: Iterator[Any]) -> Iterator[str]:
yield from "foo bar"
def _generate_immediate_error(_: Iterator) -> Iterator[str]:
def _generate_immediate_error(_: Iterator[Any]) -> Iterator[str]:
msg = "immmediate error"
raise ValueError(msg)
yield ""
def _generate_delayed_error(_: Iterator) -> Iterator[str]:
def _generate_delayed_error(_: Iterator[Any]) -> Iterator[str]:
yield ""
msg = "delayed error"
raise ValueError(msg)
@ -289,18 +289,18 @@ def test_fallbacks_stream() -> None:
list(runnable.stream({}))
async def _agenerate(_: AsyncIterator) -> AsyncIterator[str]:
async def _agenerate(_: AsyncIterator[Any]) -> AsyncIterator[str]:
for c in "foo bar":
yield c
async def _agenerate_immediate_error(_: AsyncIterator) -> AsyncIterator[str]:
async def _agenerate_immediate_error(_: AsyncIterator[Any]) -> AsyncIterator[str]:
msg = "immmediate error"
raise ValueError(msg)
yield ""
async def _agenerate_delayed_error(_: AsyncIterator) -> AsyncIterator[str]:
async def _agenerate_delayed_error(_: AsyncIterator[Any]) -> AsyncIterator[str]:
yield ""
msg = "delayed error"
raise ValueError(msg)
@ -346,7 +346,7 @@ class FakeStructuredOutputModel(BaseChatModel):
@override
def with_structured_output(
self, schema: Union[dict, type[BaseModel]], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
) -> Runnable[Any, dict[str, int]]:
return RunnableLambda(lambda _: {"foo": self.foo})
@property

View File

@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any, Optional, Union
from packaging import version
from pydantic import BaseModel
@ -6,6 +6,7 @@ from syrupy.assertion import SnapshotAssertion
from typing_extensions import override
from langchain_core.language_models import FakeListLLM
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.output_parsers.xml import XMLOutputParser
@ -222,7 +223,7 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
str_parser = StrOutputParser()
xml_parser = XMLOutputParser()
def conditional_str_parser(value: str) -> Runnable:
def conditional_str_parser(value: str) -> Runnable[Union[BaseMessage, str], str]:
if value == "a":
return str_parser
return xml_parser
@ -528,7 +529,7 @@ def test_graph_mermaid_escape_node_label() -> None:
def test_graph_mermaid_duplicate_nodes(snapshot: SnapshotAssertion) -> None:
fake_llm = FakeListLLM(responses=["foo", "bar"])
sequence: Runnable = (
sequence = (
PromptTemplate.from_template("Hello, {input}")
| {
"llm1": fake_llm,

View File

@ -35,7 +35,7 @@ def test_interfaces() -> None:
def _get_get_session_history(
*,
store: Optional[dict[str, Any]] = None,
store: Optional[dict[str, InMemoryChatMessageHistory]] = None,
) -> Callable[..., InMemoryChatMessageHistory]:
chat_history_store = store if store is not None else {}
@ -54,7 +54,7 @@ def test_input_messages() -> None:
lambda messages: "you said: "
+ "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
)
store: dict = {}
store: dict[str, InMemoryChatMessageHistory] = {}
get_session_history = _get_get_session_history(store=store)
with_history = RunnableWithMessageHistory(runnable, get_session_history)
config: RunnableConfig = {"configurable": {"session_id": "1"}}
@ -83,7 +83,7 @@ async def test_input_messages_async() -> None:
lambda messages: "you said: "
+ "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
)
store: dict = {}
store: dict[str, InMemoryChatMessageHistory] = {}
get_session_history = _get_get_session_history(store=store)
with_history = RunnableWithMessageHistory(runnable, get_session_history)
config = {"session_id": "1_async"}
@ -489,7 +489,7 @@ def test_get_output_schema() -> None:
)
output_type = with_history.get_output_schema()
expected_schema: dict = {
expected_schema: dict[str, Any] = {
"title": "RunnableWithChatHistoryOutput",
"type": "object",
}
@ -842,8 +842,7 @@ def test_get_output_messages_no_value_error() -> None:
lambda messages: "you said: "
+ "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
)
store: dict = {}
get_session_history = _get_get_session_history(store=store)
get_session_history = _get_get_session_history()
with_history = RunnableWithMessageHistory(runnable, get_session_history)
config: RunnableConfig = {
"configurable": {"session_id": "1", "message_history": get_session_history("1")}
@ -859,8 +858,7 @@ def test_get_output_messages_no_value_error() -> None:
def test_get_output_messages_with_value_error() -> None:
illegal_bool_message = False
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message)
store: dict = {}
get_session_history = _get_get_session_history(store=store)
get_session_history = _get_get_session_history()
with_history = RunnableWithMessageHistory(runnable, get_session_history) # type: ignore[arg-type]
config: RunnableConfig = {
"configurable": {"session_id": "1", "message_history": get_session_history("1")}

View File

@ -513,7 +513,7 @@ def test_passthrough_assign_schema() -> None:
prompt = PromptTemplate.from_template("{context} {question}")
fake_llm = FakeListLLM(responses=["a"]) # str -> list[list[str]]
seq_w_assign: Runnable = (
seq_w_assign = (
RunnablePassthrough.assign(context=itemgetter("question") | retriever)
| prompt
| fake_llm
@ -530,7 +530,7 @@ def test_passthrough_assign_schema() -> None:
"type": "string",
}
invalid_seq_w_assign: Runnable = (
invalid_seq_w_assign = (
RunnablePassthrough.assign(context=itemgetter("question") | retriever)
| fake_llm
)
@ -1011,7 +1011,7 @@ def test_passthrough_tap(mocker: MockerFixture) -> None:
fake = FakeRunnable()
mock = mocker.Mock()
seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock)
seq = RunnablePassthrough[Any](mock) | fake | RunnablePassthrough[Any](mock)
assert seq.invoke("hello", my_kwarg="value") == 5
assert mock.call_args_list == [
@ -1078,7 +1078,7 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
fake = FakeRunnable()
mock = mocker.Mock()
seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock)
seq = RunnablePassthrough[Any](mock) | fake | RunnablePassthrough[Any](mock)
assert await seq.ainvoke("hello", my_kwarg="value") == 5
assert mock.call_args_list == [
@ -1188,8 +1188,8 @@ def test_with_config(mocker: MockerFixture) -> None:
]
spy.reset_mock()
fake_1: Runnable = RunnablePassthrough()
fake_2: Runnable = RunnablePassthrough()
fake_1 = RunnablePassthrough[Any]()
fake_2 = RunnablePassthrough[Any]()
spy_seq_step = mocker.spy(fake_1.__class__, "invoke")
sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config(
@ -1650,7 +1650,7 @@ def test_with_listeners(mocker: MockerFixture) -> None:
)
chat = FakeListChatModel(responses=["foo"])
chain: Runnable = prompt | chat
chain = prompt | chat
mock_start = mocker.Mock()
mock_end = mocker.Mock()
@ -1683,7 +1683,7 @@ async def test_with_listeners_async(mocker: MockerFixture) -> None:
)
chat = FakeListChatModel(responses=["foo"])
chain: Runnable = prompt | chat
chain = prompt | chat
mock_start = mocker.Mock()
mock_end = mocker.Mock()
@ -1787,7 +1787,7 @@ def test_prompt_with_chat_model(
)
chat = FakeListChatModel(responses=["foo"])
chain: Runnable = prompt | chat
chain = prompt | chat
assert repr(chain) == snapshot
assert isinstance(chain, RunnableSequence)
@ -1893,7 +1893,7 @@ async def test_prompt_with_chat_model_async(
)
chat = FakeListChatModel(responses=["foo"])
chain: Runnable = prompt | chat
chain = prompt | chat
assert repr(chain) == snapshot
assert isinstance(chain, RunnableSequence)
@ -2007,7 +2007,7 @@ async def test_prompt_with_llm(
)
llm = FakeListLLM(responses=["foo", "bar"])
chain: Runnable = prompt | llm
chain = prompt | llm
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
@ -2204,7 +2204,7 @@ async def test_prompt_with_llm_parser(
llm = FakeStreamingListLLM(responses=["bear, dog, cat", "tomato, lettuce, onion"])
parser = CommaSeparatedListOutputParser()
chain: Runnable = prompt | llm | parser
chain = prompt | llm | parser
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
@ -2517,7 +2517,7 @@ async def test_stream_log_lists() -> None:
for i in range(4):
yield AddableDict(alist=[str(i)])
chain: Runnable = RunnableGenerator(list_producer)
chain = RunnableGenerator(list_producer)
stream_log = [
part async for part in chain.astream_log({"question": "What is your name?"})
@ -2697,7 +2697,7 @@ def test_combining_sequences(
chain2 = cast("RunnableSequence", input_formatter | prompt2 | chat2 | parser2)
assert isinstance(chain, RunnableSequence)
assert isinstance(chain2, RunnableSequence)
assert chain2.first == input_formatter
assert chain2.middle == [prompt2, chat2]
assert chain2.last == parser2
@ -2705,6 +2705,7 @@ def test_combining_sequences(
combined_chain = cast("RunnableSequence", chain | chain2)
assert isinstance(combined_chain, RunnableSequence)
assert combined_chain.first == prompt
assert combined_chain.middle == [
chat,
@ -2869,13 +2870,13 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
@freeze_time("2023-01-01")
def test_router_runnable(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
chain1: Runnable = ChatPromptTemplate.from_template(
chain1 = ChatPromptTemplate.from_template(
"You are a math genius. Answer the question: {question}"
) | FakeListLLM(responses=["4"])
chain2: Runnable = ChatPromptTemplate.from_template(
chain2 = ChatPromptTemplate.from_template(
"You are an english major. Answer the question: {question}"
) | FakeListLLM(responses=["2"])
router: Runnable = RouterRunnable({"math": chain1, "english": chain2})
router = RouterRunnable({"math": chain1, "english": chain2})
chain: Runnable = {
"key": lambda x: x["key"],
"input": {"question": lambda x: x["question"]},
@ -2913,13 +2914,13 @@ def test_router_runnable(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
async def test_router_runnable_async() -> None:
chain1: Runnable = ChatPromptTemplate.from_template(
chain1 = ChatPromptTemplate.from_template(
"You are a math genius. Answer the question: {question}"
) | FakeListLLM(responses=["4"])
chain2: Runnable = ChatPromptTemplate.from_template(
chain2 = ChatPromptTemplate.from_template(
"You are an english major. Answer the question: {question}"
) | FakeListLLM(responses=["2"])
router: Runnable = RouterRunnable({"math": chain1, "english": chain2})
router = RouterRunnable({"math": chain1, "english": chain2})
chain: Runnable = {
"key": lambda x: x["key"],
"input": {"question": lambda x: x["question"]},
@ -2941,13 +2942,13 @@ async def test_router_runnable_async() -> None:
def test_higher_order_lambda_runnable(
mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
math_chain: Runnable = ChatPromptTemplate.from_template(
math_chain = ChatPromptTemplate.from_template(
"You are a math genius. Answer the question: {question}"
) | FakeListLLM(responses=["4"])
english_chain: Runnable = ChatPromptTemplate.from_template(
english_chain = ChatPromptTemplate.from_template(
"You are an english major. Answer the question: {question}"
) | FakeListLLM(responses=["2"])
input_map: Runnable = RunnableParallel(
input_map = RunnableParallel(
key=lambda x: x["key"],
input={"question": lambda x: x["question"]},
)
@ -2997,13 +2998,13 @@ def test_higher_order_lambda_runnable(
async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None:
math_chain: Runnable = ChatPromptTemplate.from_template(
math_chain = ChatPromptTemplate.from_template(
"You are a math genius. Answer the question: {question}"
) | FakeListLLM(responses=["4"])
english_chain: Runnable = ChatPromptTemplate.from_template(
english_chain = ChatPromptTemplate.from_template(
"You are an english major. Answer the question: {question}"
) | FakeListLLM(responses=["2"])
input_map: Runnable = RunnableParallel(
input_map = RunnableParallel(
key=lambda x: x["key"],
input={"question": lambda x: x["question"]},
)
@ -3779,7 +3780,7 @@ async def test_deep_astream_assign() -> None:
def test_runnable_sequence_transform() -> None:
llm = FakeStreamingListLLM(responses=["foo-lish"])
chain: Runnable = llm | StrOutputParser()
chain = llm | StrOutputParser()
stream = chain.transform(llm.stream("Hi there!"))
@ -3792,7 +3793,7 @@ def test_runnable_sequence_transform() -> None:
async def test_runnable_sequence_atransform() -> None:
llm = FakeStreamingListLLM(responses=["foo-lish"])
chain: Runnable = llm | StrOutputParser()
chain = llm | StrOutputParser()
stream = chain.atransform(llm.astream("Hi there!"))
@ -3867,7 +3868,7 @@ def test_recursive_lambda() -> None:
def test_retrying(mocker: MockerFixture) -> None:
def _lambda(x: int) -> Union[int, Runnable]:
def _lambda(x: int) -> int:
if x == 1:
msg = "x is 1"
raise ValueError(msg)
@ -3932,7 +3933,7 @@ def test_retrying(mocker: MockerFixture) -> None:
async def test_async_retrying(mocker: MockerFixture) -> None:
def _lambda(x: int) -> Union[int, Runnable]:
def _lambda(x: int) -> int:
if x == 1:
msg = "x is 1"
raise ValueError(msg)
@ -4046,7 +4047,7 @@ async def test_runnable_lambda_astream() -> None:
"""Test that astream works for both normal functions & those returning Runnable."""
# Wrapper to make a normal function async
def awrapper(func: Callable) -> Callable[..., Awaitable[Any]]:
def awrapper(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
async def afunc(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)
@ -4140,8 +4141,8 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
def _batch(
self,
inputs: list[str],
) -> list:
outputs: list[Any] = []
) -> list[Union[str, Exception]]:
outputs: list[Union[str, Exception]] = []
for value in inputs:
if value.startswith(self.fail_starts_with):
outputs.append(
@ -4281,8 +4282,8 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
async def _abatch(
self,
inputs: list[str],
) -> list:
outputs: list[Any] = []
) -> list[Union[str, Exception]]:
outputs: list[Union[str, Exception]] = []
for value in inputs:
if value.startswith(self.fail_starts_with):
outputs.append(
@ -5534,7 +5535,7 @@ def test_listeners() -> None:
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.schemas import Run
def fake_chain(inputs: dict) -> dict:
def fake_chain(inputs: dict[str, str]) -> dict[str, str]:
return {**inputs, "key": "extra"}
shared_state = {}
@ -5564,7 +5565,7 @@ async def test_listeners_async() -> None:
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.schemas import Run
def fake_chain(inputs: dict) -> dict:
def fake_chain(inputs: dict[str, str]) -> dict[str, str]:
return {**inputs, "key": "extra"}
shared_state = {}
@ -5577,7 +5578,7 @@ async def test_listeners_async() -> None:
def on_end(run: Run) -> None:
shared_state[run.id]["outputs"] = run.inputs
chain: Runnable = (
chain = (
RunnableLambda(fake_chain)
.with_listeners(on_end=on_end, on_start=on_start)
.map()
@ -5647,7 +5648,7 @@ def test_pydantic_protected_namespaces() -> None:
with warnings.catch_warnings():
warnings.simplefilter("error")
class CustomChatModel(RunnableSerializable):
class CustomChatModel(RunnableSerializable[str, str]):
model_kwargs: dict[str, Any] = Field(default_factory=dict)

View File

@ -2,7 +2,7 @@
import asyncio
import sys
from collections.abc import AsyncIterator, Sequence
from collections.abc import AsyncIterator, Mapping, Sequence
from itertools import cycle
from typing import Any, cast
@ -44,12 +44,6 @@ def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]:
return cast("list[StreamEvent]", [{**event, "run_id": ""} for event in events])
async def _as_async_iterator(iterable: list) -> AsyncIterator:
"""Converts an iterable into an async iterator."""
for item in iterable:
yield item
async def _collect_events(events: AsyncIterator[StreamEvent]) -> list[StreamEvent]:
"""Collect the events and remove the run ids."""
materialized_events = [event async for event in events]
@ -59,7 +53,9 @@ async def _collect_events(events: AsyncIterator[StreamEvent]) -> list[StreamEven
return events_
def _assert_events_equal_allow_superset_metadata(events: list, expected: list) -> None:
def _assert_events_equal_allow_superset_metadata(
events: Sequence[Mapping[str, Any]], expected: Sequence[Mapping[str, Any]]
) -> None:
"""Assert that the events are equal."""
assert len(events) == len(expected)
for i, (event, expected_event) in enumerate(zip(events, expected)):
@ -1910,7 +1906,7 @@ async def test_runnable_with_message_history() -> None:
# Here we use a global variable to store the chat message history.
# This will make it easier to inspect it to see the underlying results.
store: dict = {}
store: dict[str, list[BaseMessage]] = {}
def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
"""Get a chat message history."""

View File

@ -74,12 +74,6 @@ def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]:
)
async def _as_async_iterator(iterable: list) -> AsyncIterator:
"""Converts an iterable into an async iterator."""
for item in iterable:
yield item
async def _collect_events(
events: AsyncIterator[StreamEvent], *, with_nulled_ids: bool = True
) -> list[StreamEvent]:
@ -1866,7 +1860,7 @@ async def test_runnable_with_message_history() -> None:
# Here we use a global variable to store the chat message history.
# This will make it easier to inspect it to see the underlying results.
store: dict = {}
store: dict[str, list[BaseMessage]] = {}
def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
"""Get a chat message history."""

View File

@ -20,7 +20,7 @@ from langchain_core.runnables.base import RunnableLambda, RunnableParallel
from langchain_core.tracers.langchain import LangChainTracer
def _get_posts(client: Client) -> list:
def _get_posts(client: Client) -> list[dict[str, Any]]:
mock_calls = client.session.request.mock_calls # type: ignore[attr-defined]
posts = []
for call in mock_calls:
@ -274,7 +274,7 @@ class TestRunnableSequenceParallelTraceNesting:
def before(x: int) -> int:
return x
def after(x: dict) -> int:
def after(x: dict[str, Any]) -> int:
return x["chain_result"]
sequence = before | parallel | after

View File

@ -1,5 +1,5 @@
import sys
from typing import Callable
from typing import Any, Callable
import pytest
@ -22,7 +22,7 @@ from langchain_core.runnables.utils import (
(lambda x: x if x > 0 else 0, "lambda x: x if x > 0 else 0"), # noqa: FURB136
],
)
def test_get_lambda_source(func: Callable, expected_source: str) -> None:
def test_get_lambda_source(func: Callable[..., Any], expected_source: str) -> None:
"""Test get_lambda_source function."""
source = get_lambda_source(func)
assert source == expected_source

View File

@ -1,3 +1,5 @@
from typing import Any
import pytest
from langchain_tests.integration_tests.base_store import (
BaseStoreAsyncTests,
@ -8,7 +10,7 @@ from langchain_core.stores import InMemoryStore
# Check against standard tests
class TestSyncInMemoryStore(BaseStoreSyncTests):
class TestSyncInMemoryStore(BaseStoreSyncTests[Any]):
@pytest.fixture
def kv_store(self) -> InMemoryStore:
return InMemoryStore()

View File

@ -39,7 +39,6 @@ from langchain_core.messages import ToolCall, ToolMessage
from langchain_core.messages.tool import ToolOutputMixin
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableLambda,
ensure_config,
@ -72,7 +71,7 @@ from tests.unit_tests.fake.callbacks import FakeCallbackHandler
from tests.unit_tests.pydantic_utils import _schema
def _get_tool_call_json_schema(tool: BaseTool) -> dict:
def _get_tool_call_json_schema(tool: BaseTool) -> dict[str, Any]:
tool_schema = tool.tool_call_schema
if isinstance(tool_schema, dict):
return tool_schema
@ -1402,15 +1401,15 @@ class _MockStructuredToolWithRawOutput(BaseTool):
self,
arg1: int,
arg2: bool, # noqa: FBT001
arg3: Optional[dict] = None,
) -> tuple[str, dict]:
arg3: Optional[dict[str, str]] = None,
) -> tuple[str, dict[str, Any]]:
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
@tool("structured_api", response_format="content_and_artifact")
def _mock_structured_tool_with_artifact(
*, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> tuple[str, dict]:
*, arg1: int, arg2: bool, arg3: Optional[dict[str, str]] = None
) -> tuple[str, dict[str, Any]]:
"""A Structured Tool."""
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
@ -1419,7 +1418,7 @@ def _mock_structured_tool_with_artifact(
"tool", [_MockStructuredToolWithRawOutput(), _mock_structured_tool_with_artifact]
)
def test_tool_call_input_tool_message_with_artifact(tool: BaseTool) -> None:
tool_call: dict = {
tool_call: dict[str, Any] = {
"name": "structured_api",
"args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
"id": "123",
@ -1448,7 +1447,7 @@ def test_convert_from_runnable_dict() -> None:
def f(x: Args) -> str:
return str(x["a"] * max(x["b"]))
runnable: Runnable = RunnableLambda(f)
runnable = RunnableLambda(f)
as_tool = runnable.as_tool()
args_schema = as_tool.args_schema
assert args_schema is not None
@ -1480,14 +1479,14 @@ def test_convert_from_runnable_dict() -> None:
a: int = Field(..., description="Integer")
b: list[int] = Field(..., description="List of ints")
runnable = RunnableLambda(g)
as_tool = runnable.as_tool(GSchema)
as_tool.invoke({"a": 3, "b": [1, 2]})
runnable2 = RunnableLambda(g)
as_tool2 = runnable2.as_tool(GSchema)
as_tool2.invoke({"a": 3, "b": [1, 2]})
# Specify via arg_types:
runnable = RunnableLambda(g)
as_tool = runnable.as_tool(arg_types={"a": int, "b": list[int]})
result = as_tool.invoke({"a": 3, "b": [1, 2]})
runnable3 = RunnableLambda(g)
as_tool3 = runnable3.as_tool(arg_types={"a": int, "b": list[int]})
result = as_tool3.invoke({"a": 3, "b": [1, 2]})
assert result == "6"
# Test with config
@ -1496,9 +1495,9 @@ def test_convert_from_runnable_dict() -> None:
assert config["configurable"]["foo"] == "not-bar"
return str(x["a"] * max(x["b"]))
runnable = RunnableLambda(h)
as_tool = runnable.as_tool(arg_types={"a": int, "b": list[int]})
result = as_tool.invoke(
runnable4 = RunnableLambda(h)
as_tool4 = runnable4.as_tool(arg_types={"a": int, "b": list[int]})
result = as_tool4.invoke(
{"a": 3, "b": [1, 2]}, config={"configurable": {"foo": "not-bar"}}
)
assert result == "6"
@ -1512,7 +1511,7 @@ def test_convert_from_runnable_other() -> None:
def g(x: str) -> str:
return x + "z"
runnable: Runnable = RunnableLambda(f) | g
runnable = RunnableLambda(f) | g
as_tool = runnable.as_tool()
args_schema = as_tool.args_schema
assert args_schema is None
@ -1527,10 +1526,10 @@ def test_convert_from_runnable_other() -> None:
assert config["configurable"]["foo"] == "not-bar"
return x + "a"
runnable = RunnableLambda(h)
as_tool = runnable.as_tool()
result = as_tool.invoke("b", config={"configurable": {"foo": "not-bar"}})
assert result == "ba"
runnable2 = RunnableLambda(h)
as_tool2 = runnable2.as_tool()
result2 = as_tool2.invoke("b", config={"configurable": {"foo": "not-bar"}})
assert result2 == "ba"
@tool("foo", parse_docstring=True)
@ -1785,7 +1784,7 @@ def test_tool_inherited_injected_arg() -> None:
}
def _get_parametrized_tools() -> list:
def _get_parametrized_tools() -> list[Callable[..., Any]]:
def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str:
"""my_tool."""
return some_tool
@ -1800,7 +1799,7 @@ def _get_parametrized_tools() -> list:
@pytest.mark.parametrize("tool_", _get_parametrized_tools())
def test_fn_injected_arg_with_schema(tool_: Callable) -> None:
def test_fn_injected_arg_with_schema(tool_: Callable[..., Any]) -> None:
assert convert_to_openai_function(tool_) == {
"name": tool_.__name__,
"description": "my_tool.",
@ -2528,13 +2527,13 @@ def test_tool_decorator_description() -> None:
assert foo_args_jsons_schema.description == "JSON Schema."
assert (
cast("dict", foo_args_jsons_schema.tool_call_schema)["description"]
cast("dict[str, Any]", foo_args_jsons_schema.tool_call_schema)["description"]
== "JSON Schema."
)
assert foo_args_jsons_schema_with_description.description == "description"
assert (
cast("dict", foo_args_jsons_schema_with_description.tool_call_schema)[
cast("dict[str, Any]", foo_args_jsons_schema_with_description.tool_call_schema)[
"description"
]
== "description"