Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-30 00:04:19 +00:00)
core: Fix some missing generic types
This commit is contained in:
parent 2fdccd789c
commit 8372d41f70
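Most of the hunks below follow one pattern: bare generic annotations such as dict, list, Callable, and Runnable are given explicit type parameters (dict[str, Any], list[BaseMessage], Callable[..., Any], Runnable[Any, Any]) so type checkers can see element and return types. A minimal sketch of the idea, using hypothetical names not taken from this diff:

from typing import Any

# Before: the checker knows nothing about keys, values, or list elements.
def collect(config: dict) -> list:
    return sorted(config)

# After: parameterized generics state what goes in and what comes out.
def collect_typed(config: dict[str, Any]) -> list[str]:
    return sorted(config)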
@@ -125,7 +125,7 @@ class LangSmithLoader(BaseLoader):
yield Document(content_str, metadata=metadata)

def _stringify(x: Union[str, dict]) -> str:
def _stringify(x: Union[str, dict[str, Any]]) -> str:
if isinstance(x, str):
return x
try:

@@ -202,13 +202,17 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
MessageLikeRepresentation = Union[
BaseMessage, list[str], tuple[str, str], str, dict[str, Any]
BaseMessage,
list[str],
tuple[str, Union[str, list[Union[str, dict[str, Any]]]]],
str,
dict[str, Any],
]

def _create_message_from_message_type(
message_type: str,
content: str,
content: Union[str, list[Union[str, dict[str, Any]]]],
name: Optional[str] = None,
tool_call_id: Optional[str] = None,
tool_calls: Optional[list[dict[str, Any]]] = None,

@@ -218,13 +222,13 @@ def _create_message_from_message_type(
"""Create a message from a message type and content string.

Args:
message_type: (str) the type of the message (e.g., "human", "ai", etc.).
content: (str) the content string.
name: (str) the name of the message. Default is None.
tool_call_id: (str) the tool call id. Default is None.
tool_calls: (list[dict[str, Any]]) the tool calls. Default is None.
id: (str) the id of the message. Default is None.
additional_kwargs: (dict[str, Any]) additional keyword arguments.
message_type: the type of the message (e.g., "human", "ai", etc.).
content: the content string.
name: the name of the message. Default is None.
tool_call_id: the tool call id. Default is None.
tool_calls: the tool calls. Default is None.
id: the id of the message. Default is None.
**additional_kwargs: additional keyword arguments.

Returns:
a message of the appropriate type.

@@ -1004,12 +1008,13 @@ def convert_to_openai_messages(
oai_messages: list = []

if is_single := isinstance(messages, (BaseMessage, dict, str)):
messages = [messages]
messages_ = (
[messages]
if (is_single := isinstance(messages, (BaseMessage, dict, str, tuple)))
else messages
)

messages = convert_to_messages(messages)

for i, message in enumerate(messages):
for i, message in enumerate(convert_to_messages(messages_)):
oai_msg: dict = {"role": _get_message_openai_role(message)}
tool_messages: list = []
content: Union[str, list[dict]]

@@ -225,7 +225,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
@model_validator(mode="before")
@classmethod
def validate_schema(cls, values: dict) -> Any:
def validate_schema(cls, values: dict[str, Any]) -> Any:
"""Validate the pydantic schema.

Args:
@@ -3,9 +3,9 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Sequence
from pathlib import Path
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Optional,

@@ -51,9 +51,6 @@ from langchain_core.prompts.string import (
from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env

if TYPE_CHECKING:
from collections.abc import Sequence

class MessagesPlaceholder(BaseMessagePromptTemplate):
"""Prompt template that assumes variable is already list of messages.

@@ -772,7 +769,7 @@ MessageLikeRepresentation = Union[
MessageLike,
tuple[
Union[str, type],
Union[str, list[dict], list[object]],
Union[str, Sequence[dict], Sequence[object]],
],
str,
dict[str, Any],

@@ -1435,11 +1432,13 @@ def _convert_to_message_template(
f" Got: {message}"
)
raise ValueError(msg)
message = (message["role"], message["content"])
if len(message) != 2:
msg = f"Expected 2-tuple of (role, template), got {message}"
raise ValueError(msg)
message_type_str, template = message
message_type_str = message["role"]
template = message["content"]
else:
if len(message) != 2:
msg = f"Expected 2-tuple of (role, template), got {message}"
raise ValueError(msg)
message_type_str, template = message
if isinstance(message_type_str, str):
message_ = _create_template_from_message_type(
message_type_str, template, template_format=template_format

@@ -100,7 +100,7 @@ class EvaluatorCallbackHandler(BaseTracer):
)
else:
self.executor = None
self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
self.futures: weakref.WeakSet[Future[None]] = weakref.WeakSet()
self.skip_unfinished = skip_unfinished
self.project_name = project_name
self.logged_eval_results: dict[tuple[str, str], list[EvaluationResult]] = {}
@@ -10,7 +10,7 @@ class DummyExampleSelector(BaseExampleSelector):
def add_example(self, example: dict[str, str]) -> None:
self.example = example

def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
def select_examples(self, input_variables: dict[str, str]) -> list[dict[str, str]]:
return [input_variables]

@@ -276,7 +276,9 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
self.on_retriever_error_common()

# Overriding since BaseModel has __deepcopy__ method as well
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore[override]
def __deepcopy__(
self, memo: Union[dict[int, Any], None] = None
) -> "FakeCallbackHandler":
return self

@@ -426,5 +428,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
self.on_text_common()

# Overriding since BaseModel has __deepcopy__ method as well
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore[override]
def __deepcopy__(
self, memo: Union[dict[int, Any], None] = None
) -> "FakeAsyncCallbackHandler":
return self
@@ -41,7 +41,7 @@ if TYPE_CHECKING:

@pytest.fixture
def messages() -> list:
def messages() -> list[BaseMessage]:
return [
SystemMessage(content="You are a test user."),
HumanMessage(content="Hello, I am a test user."),

@@ -49,14 +49,14 @@ def messages() -> list:

@pytest.fixture
def messages_2() -> list:
def messages_2() -> list[BaseMessage]:
return [
SystemMessage(content="You are a test user."),
HumanMessage(content="Hello, I not a test user."),
]

def test_batch_size(messages: list, messages_2: list) -> None:
def test_batch_size(messages: list[BaseMessage], messages_2: list[BaseMessage]) -> None:
# The base endpoint doesn't support native batching,
# so we expect batch_size to always be 1
llm = FakeListChatModel(responses=[str(i) for i in range(100)])

@@ -80,7 +80,9 @@ def test_batch_size(messages: list, messages_2: list) -> None:
assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

async def test_async_batch_size(messages: list, messages_2: list) -> None:
async def test_async_batch_size(
messages: list[BaseMessage], messages_2: list[BaseMessage]
) -> None:
llm = FakeListChatModel(responses=[str(i) for i in range(100)])
# The base endpoint doesn't support native batching,
# so we expect batch_size to always be 1

@@ -262,7 +264,7 @@ async def test_astream_implementation_uses_astream() -> None:
class FakeTracer(BaseTracer):
def __init__(self) -> None:
super().__init__()
self.traced_run_ids: list = []
self.traced_run_ids: list[uuid.UUID] = []

def _persist_run(self, run: Run) -> None:
"""Persist a run."""

@@ -415,7 +417,7 @@ async def test_disable_streaming_no_streaming_model_async(
class FakeChatModelStartTracer(FakeTracer):
def __init__(self) -> None:
super().__init__()
self.messages: list = []
self.messages: list[list[list[BaseMessage]]] = []

def on_chat_model_start(self, *args: Any, **kwargs: Any) -> Run:
_, messages = args
@@ -17,6 +17,7 @@ from langchain_core.messages import (
ToolMessage,
)
from langchain_core.messages.utils import (
MessageLikeRepresentation,
convert_to_messages,
convert_to_openai_messages,
count_tokens_approximately,

@@ -153,7 +154,7 @@ def test_merge_messages_tool_messages() -> None:
{"include_names": ["blah", "blur"], "exclude_types": [SystemMessage]},
],
)
def test_filter_message(filters: dict) -> None:
def test_filter_message(filters: dict[str, Any]) -> None:
messages = [
SystemMessage("foo", name="blah", id="1"),
HumanMessage("bar", name="blur", id="2"),

@@ -673,7 +674,7 @@ class FakeTokenCountingModel(FakeChatModel):

def test_convert_to_messages() -> None:
message_like: list = [
message_like: list[MessageLikeRepresentation] = [
# BaseMessage
SystemMessage("1"),
SystemMessage("1.1", additional_kwargs={"__openai_role__": "developer"}),

@@ -1179,7 +1180,7 @@ def test_convert_to_openai_messages_mixed_content_types() -> None:

def test_convert_to_openai_messages_developer() -> None:
messages: list = [
messages: list[MessageLikeRepresentation] = [
SystemMessage("a", additional_kwargs={"__openai_role__": "developer"}),
{"role": "developer", "content": "a"},
]
@@ -17,7 +17,7 @@ from langchain_core.output_parsers.openai_tools import (
)
from langchain_core.outputs import ChatGeneration

STREAMED_MESSAGES: list = [
STREAMED_MESSAGES = [
AIMessageChunk(content=""),
AIMessageChunk(
content="",

@@ -331,7 +331,7 @@ for message in STREAMED_MESSAGES:
STREAMED_MESSAGES_WITH_TOOL_CALLS.append(message)

EXPECTED_STREAMED_JSON = [
EXPECTED_STREAMED_JSON: list[dict[str, Any]] = [
{},
{"names": ["suz"]},
{"names": ["suzy"]},

@@ -392,7 +392,7 @@ def test_partial_json_output_parser(*, use_tool_calls: bool) -> None:
chain = input_iter | JsonOutputToolsParser()

actual = list(chain.stream(None))
expected: list = [[]] + [
expected: list[list[dict[str, Any]]] = [[]] + [
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected

@@ -404,7 +404,7 @@ async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None
chain = input_iter | JsonOutputToolsParser()

actual = [p async for p in chain.astream(None)]
expected: list = [[]] + [
expected: list[list[dict[str, Any]]] = [[]] + [
[{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected

@@ -416,7 +416,7 @@ def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None:
chain = input_iter | JsonOutputToolsParser(return_id=True)

actual = list(chain.stream(None))
expected: list = [[]] + [
expected: list[list[dict[str, Any]]] = [[]] + [
[
{
"type": "NameCollector",

@@ -435,7 +435,9 @@ def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None:
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")

actual = list(chain.stream(None))
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
expected: list[list[dict[str, Any]]] = [[]] + [
[chunk] for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected

@@ -446,7 +448,9 @@ async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) ->
chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")

actual = [p async for p in chain.astream(None)]
expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
expected: list[list[dict[str, Any]]] = [[]] + [
[chunk] for chunk in EXPECTED_STREAMED_JSON
]
assert actual == expected
@@ -141,9 +141,7 @@ DEF_EXPECTED_RESULT = TestModel(

def test_pydantic_output_parser() -> None:
"""Test PydanticOutputParser."""
pydantic_parser: PydanticOutputParser = PydanticOutputParser(
pydantic_object=TestModel
)
pydantic_parser = PydanticOutputParser(pydantic_object=TestModel)

result = pydantic_parser.parse(DEF_RESULT)
assert result == DEF_EXPECTED_RESULT

@@ -152,9 +150,7 @@ def test_pydantic_output_parser() -> None:

def test_pydantic_output_parser_fail() -> None:
"""Test PydanticOutputParser where completion result fails schema validation."""
pydantic_parser: PydanticOutputParser = PydanticOutputParser(
pydantic_object=TestModel
)
pydantic_parser = PydanticOutputParser(pydantic_object=TestModel)

with pytest.raises(
OutputParserException, match="Failed to parse TestModel from completion"
@@ -1,4 +1,4 @@
from typing import Union
from typing import Any, Union

import pytest

@@ -19,14 +19,14 @@ from langchain_core.outputs import ChatGeneration
],
],
)
def test_msg_with_text(content: Union[str, list]) -> None:
def test_msg_with_text(content: Union[str, list[Union[str, dict[str, Any]]]]) -> None:
expected = "foo"
actual = ChatGeneration(message=AIMessage(content=content)).text
assert actual == expected

@pytest.mark.parametrize("content", [[], [{"tool_use": {}, "type": "tool_use"}]])
def test_msg_no_text(content: Union[str, list]) -> None:
def test_msg_no_text(content: Union[str, list[Union[str, dict[str, Any]]]]) -> None:
expected = ""
actual = ChatGeneration(message=AIMessage(content=content)).text
assert actual == expected
@@ -1,7 +1,7 @@
import re
import warnings
from pathlib import Path
from typing import Any, Union, cast
from typing import Any, Union

import pytest
from packaging import version

@@ -121,11 +121,10 @@ def test_create_system_message_prompt_template_from_template_partial() -> None:
History:
{history}
"""
json_prompt_instructions: dict = {}
graph_analyst_template = SystemMessagePromptTemplate.from_template(
template=graph_creator_content,
input_variables=["history"],
partial_variables={"instructions": json_prompt_instructions},
partial_variables={"instructions": {}},
)
assert graph_analyst_template.format(history="history") == SystemMessage(
content="\n Your instructions are:\n {}\n History:\n history\n "
@@ -973,46 +972,43 @@ def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
("system", "You are an AI assistant named {name}."),
("system", [{"text": "You are an AI assistant named {name}."}]),
SystemMessagePromptTemplate.from_template("you are {foo}"),
cast(
"tuple",
(
"human",
[
"hello",
{"text": "What's in this image?"},
{"type": "text", "text": "What's in this image?"},
{
"type": "text",
"text": "What's in this image?",
"cache_control": {"type": "{foo}"},
(
"human",
[
"hello",
{"text": "What's in this image?"},
{"type": "text", "text": "What's in this image?"},
{
"type": "text",
"text": "What's in this image?",
"cache_control": {"type": "{foo}"},
},
{
"type": "image_url",
"image_url": "data:image/jpeg;base64,{my_image}",
},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,{my_image}"},
},
{"type": "image_url", "image_url": "{my_other_image}"},
{
"type": "image_url",
"image_url": {
"url": "{my_other_image}",
"detail": "medium",
},
{
"type": "image_url",
"image_url": "data:image/jpeg;base64,{my_image}",
},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,{my_image}"},
},
{"type": "image_url", "image_url": "{my_other_image}"},
{
"type": "image_url",
"image_url": {
"url": "{my_other_image}",
"detail": "medium",
},
},
{
"type": "image_url",
"image_url": {"url": "https://www.langchain.com/image.png"},
},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,foobar"},
},
{"image_url": {"url": "data:image/jpeg;base64,foobar"}},
],
),
},
{
"type": "image_url",
"image_url": {"url": "https://www.langchain.com/image.png"},
},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,foobar"},
},
{"image_url": {"url": "data:image/jpeg;base64,foobar"}},
],
),
("placeholder", "{chat_history}"),
MessagesPlaceholder("more_history", optional=False),

@@ -1179,7 +1175,7 @@ def test_chat_prompt_template_data_prompt_from_message(
cache_control_placeholder: str,
source_data_placeholder: str,
) -> None:
prompt: dict = {
prompt: dict[str, Any] = {
"type": "image",
"source_type": "base64",
"data": f"{source_data_placeholder}",
@@ -385,7 +385,7 @@ class AsIsSelector(BaseExampleSelector):
raise NotImplementedError

@override
def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
def select_examples(self, input_variables: dict[str, str]) -> list[dict[str, str]]:
return list(self.examples)

@@ -480,11 +480,13 @@ class AsyncAsIsSelector(BaseExampleSelector):
def add_example(self, example: dict[str, str]) -> Any:
raise NotImplementedError

def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
def select_examples(self, input_variables: dict[str, str]) -> list[dict[str, str]]:
raise NotImplementedError

@override
async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
async def aselect_examples(
self, input_variables: dict[str, str]
) -> list[dict[str, str]]:
return list(self.examples)
@@ -15,7 +15,7 @@ EXAMPLE_DIR = (Path(__file__).parent.parent / "examples").absolute()

@contextmanager
def change_directory(dir_path: Path) -> Iterator:
def change_directory(dir_path: Path) -> Iterator[None]:
"""Change the working directory to the right folder."""
origin = Path().absolute()
try:

@@ -265,7 +265,7 @@ def test_prompt_from_template_with_partial_variables() -> None:
def test_prompt_missing_input_variables() -> None:
"""Test error is raised when input variables are not provided."""
template = "This is a {foo} test."
input_variables: list = []
input_variables: list[str] = []
with pytest.raises(
ValueError,
match=re.escape("check for mismatched or missing input parameters from []"),

@@ -509,7 +509,7 @@ Your variable again: {{ foo }}
def test_prompt_jinja2_missing_input_variables() -> None:
"""Test error is raised when input variables are not provided."""
template = "This is a {{ foo }} test."
input_variables: list = []
input_variables: list[str] = []
with pytest.warns(UserWarning, match="Missing variables: {'foo'}"):
PromptTemplate(
input_variables=input_variables,
@@ -14,11 +14,15 @@ from langchain_core.utils.pydantic import is_basemodel_subclass

def _fake_runnable(
_: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_kwargs: Any
) -> Union[BaseModel, dict]:
_: Any,
*,
schema: Union[dict[str, Any], type[BaseModel]],
value: Any = 42,
**_kwargs: Any,
) -> Union[BaseModel, dict[str, Any]]:
if isclass(schema) and is_basemodel_subclass(schema):
return schema(name="yo", value=value)
params = cast("dict", schema)["parameters"]
params = cast("dict[str, Any]", schema)["parameters"]
return {k: 1 if k != "value" else value for k, v in params.items()}
@@ -6,8 +6,7 @@ from typing import Any

import pytest

from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.runnables.base import Runnable
from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda

@pytest.mark.asyncio

@@ -97,7 +96,7 @@ def test_batch_concurrency() -> None:
return f"Completed {x}"

runnable: Runnable = RunnableLambda(tracked_function)
runnable = RunnableLambda(tracked_function)
num_tasks = 10
max_concurrency = 3

@@ -129,7 +128,7 @@ def test_batch_as_completed_concurrency() -> None:
return f"Completed {x}"

runnable: Runnable = RunnableLambda(tracked_function)
runnable = RunnableLambda(tracked_function)
num_tasks = 10
max_concurrency = 3
@@ -26,7 +26,7 @@ from langchain_core.tracers.stdout import ConsoleCallbackHandler

def test_ensure_config() -> None:
run_id = str(uuid.uuid4())
arg: dict = {
arg: dict[str, Any] = {
"something": "else",
"metadata": {"foo": "bar"},
"configurable": {"baz": "qux"},

@@ -147,7 +147,7 @@ async def test_merge_config_callbacks() -> None:
def test_config_arbitrary_keys() -> None:
base: RunnablePassthrough[Any] = RunnablePassthrough()
bound = base.with_config(my_custom_key="my custom value")
config = cast("RunnableBinding", bound).config
config = cast("RunnableBinding[Any, Any]", bound).config

assert config.get("my_custom_key") == "my custom value"
@@ -332,7 +332,8 @@ test_cases = [

@pytest.mark.parametrize(("runnable", "cases"), test_cases)
def test_context_runnables(
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
runnable: Union[Runnable[Any, Any], Callable[[], Runnable[Any, Any]]],
cases: list[_TestCase],
) -> None:
runnable = runnable if isinstance(runnable, Runnable) else runnable()
assert runnable.invoke(cases[0].input) == cases[0].output

@@ -344,7 +345,8 @@ def test_context_runnables(

@pytest.mark.parametrize(("runnable", "cases"), test_cases)
async def test_context_runnables_async(
runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
runnable: Union[Runnable[Any, Any], Callable[[], Runnable[Any, Any]]],
cases: list[_TestCase],
) -> None:
runnable = runnable if isinstance(runnable, Runnable) else runnable()
assert await runnable.ainvoke(cases[1].input) == cases[1].output
@@ -34,7 +34,7 @@ from langchain_core.tools import BaseTool

@pytest.fixture
def llm() -> RunnableWithFallbacks:
def llm() -> RunnableWithFallbacks[Any, Any]:
error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"])

@@ -42,7 +42,7 @@ def llm() -> RunnableWithFallbacks:

@pytest.fixture
def llm_multi() -> RunnableWithFallbacks:
def llm_multi() -> RunnableWithFallbacks[Any, Any]:
error_llm = FakeListLLM(responses=["foo"], i=1)
error_llm_2 = FakeListLLM(responses=["baz"], i=1)
pass_llm = FakeListLLM(responses=["bar"])

@@ -51,7 +51,7 @@ def llm_multi() -> RunnableWithFallbacks:

@pytest.fixture
def chain() -> Runnable:
def chain() -> Runnable[Any, str]:
error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"])

@@ -61,18 +61,18 @@ def chain() -> Runnable:
)

def _raise_error(_: dict) -> str:
def _raise_error(_: dict[str, Any]) -> str:
raise ValueError

def _dont_raise_error(inputs: dict) -> str:
def _dont_raise_error(inputs: dict[str, Any]) -> str:
if "exception" in inputs:
return "bar"
raise ValueError

@pytest.fixture
def chain_pass_exceptions() -> Runnable:
def chain_pass_exceptions() -> Runnable[Any, str]:
fallback = RunnableLambda(_dont_raise_error)
return {"text": RunnablePassthrough()} | RunnableLambda(
_raise_error

@@ -80,13 +80,13 @@ def chain_pass_exceptions() -> Runnable:

@pytest.mark.parametrize(
"runnable",
"runnable_name",
["llm", "llm_multi", "chain", "chain_pass_exceptions"],
)
def test_fallbacks(
runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
runnable_name: str, request: Any, snapshot: SnapshotAssertion
) -> None:
runnable = request.getfixturevalue(runnable)
runnable: Runnable[Any, Any] = request.getfixturevalue(runnable_name)
assert runnable.invoke("hello") == "bar"
assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(runnable.stream("hello")) == ["bar"]
@@ -94,17 +94,17 @@ def test_fallbacks(

@pytest.mark.parametrize(
"runnable",
"runnable_name",
["llm", "llm_multi", "chain", "chain_pass_exceptions"],
)
async def test_fallbacks_async(runnable: RunnableWithFallbacks, request: Any) -> None:
runnable = request.getfixturevalue(runnable)
async def test_fallbacks_async(runnable_name: str, request: Any) -> None:
runnable: Runnable[Any, Any] = request.getfixturevalue(runnable_name)
assert await runnable.ainvoke("hello") == "bar"
assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
assert list(await runnable.ainvoke("hello")) == list("bar")

def _runnable(inputs: dict) -> str:
def _runnable(inputs: dict[str, Any]) -> str:
if inputs["text"] == "foo":
return "first"
if "exception" not in inputs:

@@ -117,7 +117,7 @@ def _runnable(inputs: dict) -> str:
return "third"

def _assert_potential_error(actual: list, expected: list) -> None:
def _assert_potential_error(actual: list[Any], expected: list[Any]) -> None:
for x, y in zip(actual, expected):
if isinstance(x, Exception):
assert isinstance(y, type(x))

@@ -260,17 +260,17 @@ async def test_abatch() -> None:
_assert_potential_error(actual, expected)

def _generate(_: Iterator) -> Iterator[str]:
def _generate(_: Iterator[Any]) -> Iterator[str]:
yield from "foo bar"

def _generate_immediate_error(_: Iterator) -> Iterator[str]:
def _generate_immediate_error(_: Iterator[Any]) -> Iterator[str]:
msg = "immmediate error"
raise ValueError(msg)
yield ""

def _generate_delayed_error(_: Iterator) -> Iterator[str]:
def _generate_delayed_error(_: Iterator[Any]) -> Iterator[str]:
yield ""
msg = "delayed error"
raise ValueError(msg)
@@ -289,18 +289,18 @@ def test_fallbacks_stream() -> None:
list(runnable.stream({}))

async def _agenerate(_: AsyncIterator) -> AsyncIterator[str]:
async def _agenerate(_: AsyncIterator[Any]) -> AsyncIterator[str]:
for c in "foo bar":
yield c

async def _agenerate_immediate_error(_: AsyncIterator) -> AsyncIterator[str]:
async def _agenerate_immediate_error(_: AsyncIterator[Any]) -> AsyncIterator[str]:
msg = "immmediate error"
raise ValueError(msg)
yield ""

async def _agenerate_delayed_error(_: AsyncIterator) -> AsyncIterator[str]:
async def _agenerate_delayed_error(_: AsyncIterator[Any]) -> AsyncIterator[str]:
yield ""
msg = "delayed error"
raise ValueError(msg)

@@ -346,7 +346,7 @@ class FakeStructuredOutputModel(BaseChatModel):
@override
def with_structured_output(
self, schema: Union[dict, type[BaseModel]], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
) -> Runnable[Any, dict[str, int]]:
return RunnableLambda(lambda _: {"foo": self.foo})

@property
@@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any, Optional, Union

from packaging import version
from pydantic import BaseModel

@@ -6,6 +6,7 @@ from syrupy.assertion import SnapshotAssertion
from typing_extensions import override

from langchain_core.language_models import FakeListLLM
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.output_parsers.xml import XMLOutputParser

@@ -222,7 +223,7 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
str_parser = StrOutputParser()
xml_parser = XMLOutputParser()

def conditional_str_parser(value: str) -> Runnable:
def conditional_str_parser(value: str) -> Runnable[Union[BaseMessage, str], str]:
if value == "a":
return str_parser
return xml_parser

@@ -528,7 +529,7 @@ def test_graph_mermaid_escape_node_label() -> None:

def test_graph_mermaid_duplicate_nodes(snapshot: SnapshotAssertion) -> None:
fake_llm = FakeListLLM(responses=["foo", "bar"])
sequence: Runnable = (
sequence = (
PromptTemplate.from_template("Hello, {input}")
| {
"llm1": fake_llm,
@@ -35,7 +35,7 @@ def test_interfaces() -> None:

def _get_get_session_history(
*,
store: Optional[dict[str, Any]] = None,
store: Optional[dict[str, InMemoryChatMessageHistory]] = None,
) -> Callable[..., InMemoryChatMessageHistory]:
chat_history_store = store if store is not None else {}

@@ -54,7 +54,7 @@ def test_input_messages() -> None:
lambda messages: "you said: "
+ "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
)
store: dict = {}
store: dict[str, InMemoryChatMessageHistory] = {}
get_session_history = _get_get_session_history(store=store)
with_history = RunnableWithMessageHistory(runnable, get_session_history)
config: RunnableConfig = {"configurable": {"session_id": "1"}}

@@ -83,7 +83,7 @@ async def test_input_messages_async() -> None:
lambda messages: "you said: "
+ "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
)
store: dict = {}
store: dict[str, InMemoryChatMessageHistory] = {}
get_session_history = _get_get_session_history(store=store)
with_history = RunnableWithMessageHistory(runnable, get_session_history)
config = {"session_id": "1_async"}

@@ -489,7 +489,7 @@ def test_get_output_schema() -> None:
)
output_type = with_history.get_output_schema()

expected_schema: dict = {
expected_schema: dict[str, Any] = {
"title": "RunnableWithChatHistoryOutput",
"type": "object",
}

@@ -842,8 +842,7 @@ def test_get_output_messages_no_value_error() -> None:
lambda messages: "you said: "
+ "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
)
store: dict = {}
get_session_history = _get_get_session_history(store=store)
get_session_history = _get_get_session_history()
with_history = RunnableWithMessageHistory(runnable, get_session_history)
config: RunnableConfig = {
"configurable": {"session_id": "1", "message_history": get_session_history("1")}

@@ -859,8 +858,7 @@ def test_get_output_messages_no_value_error() -> None:
def test_get_output_messages_with_value_error() -> None:
illegal_bool_message = False
runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message)
store: dict = {}
get_session_history = _get_get_session_history(store=store)
get_session_history = _get_get_session_history()
with_history = RunnableWithMessageHistory(runnable, get_session_history) # type: ignore[arg-type]
config: RunnableConfig = {
"configurable": {"session_id": "1", "message_history": get_session_history("1")}
@@ -513,7 +513,7 @@ def test_passthrough_assign_schema() -> None:
prompt = PromptTemplate.from_template("{context} {question}")
fake_llm = FakeListLLM(responses=["a"]) # str -> list[list[str]]

seq_w_assign: Runnable = (
seq_w_assign = (
RunnablePassthrough.assign(context=itemgetter("question") | retriever)
| prompt
| fake_llm

@@ -530,7 +530,7 @@ def test_passthrough_assign_schema() -> None:
"type": "string",
}

invalid_seq_w_assign: Runnable = (
invalid_seq_w_assign = (
RunnablePassthrough.assign(context=itemgetter("question") | retriever)
| fake_llm
)

@@ -1011,7 +1011,7 @@ def test_passthrough_tap(mocker: MockerFixture) -> None:
fake = FakeRunnable()
mock = mocker.Mock()

seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock)
seq = RunnablePassthrough[Any](mock) | fake | RunnablePassthrough[Any](mock)

assert seq.invoke("hello", my_kwarg="value") == 5
assert mock.call_args_list == [

@@ -1078,7 +1078,7 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
fake = FakeRunnable()
mock = mocker.Mock()

seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock)
seq = RunnablePassthrough[Any](mock) | fake | RunnablePassthrough[Any](mock)

assert await seq.ainvoke("hello", my_kwarg="value") == 5
assert mock.call_args_list == [

@@ -1188,8 +1188,8 @@ def test_with_config(mocker: MockerFixture) -> None:
]
spy.reset_mock()

fake_1: Runnable = RunnablePassthrough()
fake_2: Runnable = RunnablePassthrough()
fake_1 = RunnablePassthrough[Any]()
fake_2 = RunnablePassthrough[Any]()
spy_seq_step = mocker.spy(fake_1.__class__, "invoke")

sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config(
@@ -1650,7 +1650,7 @@ def test_with_listeners(mocker: MockerFixture) -> None:
)
chat = FakeListChatModel(responses=["foo"])

chain: Runnable = prompt | chat
chain = prompt | chat

mock_start = mocker.Mock()
mock_end = mocker.Mock()

@@ -1683,7 +1683,7 @@ async def test_with_listeners_async(mocker: MockerFixture) -> None:
)
chat = FakeListChatModel(responses=["foo"])

chain: Runnable = prompt | chat
chain = prompt | chat

mock_start = mocker.Mock()
mock_end = mocker.Mock()

@@ -1787,7 +1787,7 @@ def test_prompt_with_chat_model(
)
chat = FakeListChatModel(responses=["foo"])

chain: Runnable = prompt | chat
chain = prompt | chat

assert repr(chain) == snapshot
assert isinstance(chain, RunnableSequence)

@@ -1893,7 +1893,7 @@ async def test_prompt_with_chat_model_async(
)
chat = FakeListChatModel(responses=["foo"])

chain: Runnable = prompt | chat
chain = prompt | chat

assert repr(chain) == snapshot
assert isinstance(chain, RunnableSequence)

@@ -2007,7 +2007,7 @@ async def test_prompt_with_llm(
)
llm = FakeListLLM(responses=["foo", "bar"])

chain: Runnable = prompt | llm
chain = prompt | llm

assert isinstance(chain, RunnableSequence)
assert chain.first == prompt

@@ -2204,7 +2204,7 @@ async def test_prompt_with_llm_parser(
llm = FakeStreamingListLLM(responses=["bear, dog, cat", "tomato, lettuce, onion"])
parser = CommaSeparatedListOutputParser()

chain: Runnable = prompt | llm | parser
chain = prompt | llm | parser

assert isinstance(chain, RunnableSequence)
assert chain.first == prompt

@@ -2517,7 +2517,7 @@ async def test_stream_log_lists() -> None:
for i in range(4):
yield AddableDict(alist=[str(i)])

chain: Runnable = RunnableGenerator(list_producer)
chain = RunnableGenerator(list_producer)

stream_log = [
part async for part in chain.astream_log({"question": "What is your name?"})
@@ -2697,7 +2697,7 @@ def test_combining_sequences(

chain2 = cast("RunnableSequence", input_formatter | prompt2 | chat2 | parser2)

assert isinstance(chain, RunnableSequence)
assert isinstance(chain2, RunnableSequence)
assert chain2.first == input_formatter
assert chain2.middle == [prompt2, chat2]
assert chain2.last == parser2

@@ -2705,6 +2705,7 @@ def test_combining_sequences(

combined_chain = cast("RunnableSequence", chain | chain2)

assert isinstance(combined_chain, RunnableSequence)
assert combined_chain.first == prompt
assert combined_chain.middle == [
chat,

@@ -2869,13 +2870,13 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->

@freeze_time("2023-01-01")
def test_router_runnable(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
chain1: Runnable = ChatPromptTemplate.from_template(
chain1 = ChatPromptTemplate.from_template(
"You are a math genius. Answer the question: {question}"
) | FakeListLLM(responses=["4"])
chain2: Runnable = ChatPromptTemplate.from_template(
chain2 = ChatPromptTemplate.from_template(
"You are an english major. Answer the question: {question}"
) | FakeListLLM(responses=["2"])
router: Runnable = RouterRunnable({"math": chain1, "english": chain2})
router = RouterRunnable({"math": chain1, "english": chain2})
chain: Runnable = {
"key": lambda x: x["key"],
"input": {"question": lambda x: x["question"]},

@@ -2913,13 +2914,13 @@ def test_router_runnable(mocker: MockerFixture, snapshot: SnapshotAssertion) ->

async def test_router_runnable_async() -> None:
chain1: Runnable = ChatPromptTemplate.from_template(
chain1 = ChatPromptTemplate.from_template(
"You are a math genius. Answer the question: {question}"
) | FakeListLLM(responses=["4"])
chain2: Runnable = ChatPromptTemplate.from_template(
chain2 = ChatPromptTemplate.from_template(
"You are an english major. Answer the question: {question}"
) | FakeListLLM(responses=["2"])
router: Runnable = RouterRunnable({"math": chain1, "english": chain2})
router = RouterRunnable({"math": chain1, "english": chain2})
chain: Runnable = {
"key": lambda x: x["key"],
"input": {"question": lambda x: x["question"]},
@@ -2941,13 +2942,13 @@ async def test_router_runnable_async() -> None:
def test_higher_order_lambda_runnable(
mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
math_chain: Runnable = ChatPromptTemplate.from_template(
math_chain = ChatPromptTemplate.from_template(
"You are a math genius. Answer the question: {question}"
) | FakeListLLM(responses=["4"])
english_chain: Runnable = ChatPromptTemplate.from_template(
english_chain = ChatPromptTemplate.from_template(
"You are an english major. Answer the question: {question}"
) | FakeListLLM(responses=["2"])
input_map: Runnable = RunnableParallel(
input_map = RunnableParallel(
key=lambda x: x["key"],
input={"question": lambda x: x["question"]},
)

@@ -2997,13 +2998,13 @@ def test_higher_order_lambda_runnable(

async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None:
math_chain: Runnable = ChatPromptTemplate.from_template(
math_chain = ChatPromptTemplate.from_template(
"You are a math genius. Answer the question: {question}"
) | FakeListLLM(responses=["4"])
english_chain: Runnable = ChatPromptTemplate.from_template(
english_chain = ChatPromptTemplate.from_template(
"You are an english major. Answer the question: {question}"
) | FakeListLLM(responses=["2"])
input_map: Runnable = RunnableParallel(
input_map = RunnableParallel(
key=lambda x: x["key"],
input={"question": lambda x: x["question"]},
)
@@ -3779,7 +3780,7 @@ async def test_deep_astream_assign() -> None:
def test_runnable_sequence_transform() -> None:
llm = FakeStreamingListLLM(responses=["foo-lish"])

chain: Runnable = llm | StrOutputParser()
chain = llm | StrOutputParser()

stream = chain.transform(llm.stream("Hi there!"))

@@ -3792,7 +3793,7 @@ def test_runnable_sequence_transform() -> None:
async def test_runnable_sequence_atransform() -> None:
llm = FakeStreamingListLLM(responses=["foo-lish"])

chain: Runnable = llm | StrOutputParser()
chain = llm | StrOutputParser()

stream = chain.atransform(llm.astream("Hi there!"))

@@ -3867,7 +3868,7 @@ def test_recursive_lambda() -> None:

def test_retrying(mocker: MockerFixture) -> None:
def _lambda(x: int) -> Union[int, Runnable]:
def _lambda(x: int) -> int:
if x == 1:
msg = "x is 1"
raise ValueError(msg)

@@ -3932,7 +3933,7 @@ def test_retrying(mocker: MockerFixture) -> None:

async def test_async_retrying(mocker: MockerFixture) -> None:
def _lambda(x: int) -> Union[int, Runnable]:
def _lambda(x: int) -> int:
if x == 1:
msg = "x is 1"
raise ValueError(msg)

@@ -4046,7 +4047,7 @@ async def test_runnable_lambda_astream() -> None:
"""Test that astream works for both normal functions & those returning Runnable."""

# Wrapper to make a normal function async
def awrapper(func: Callable) -> Callable[..., Awaitable[Any]]:
def awrapper(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
async def afunc(*args: Any, **kwargs: Any) -> Any:
return func(*args, **kwargs)
@@ -4140,8 +4141,8 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
def _batch(
self,
inputs: list[str],
) -> list:
outputs: list[Any] = []
) -> list[Union[str, Exception]]:
outputs: list[Union[str, Exception]] = []
for value in inputs:
if value.startswith(self.fail_starts_with):
outputs.append(

@@ -4281,8 +4282,8 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
async def _abatch(
self,
inputs: list[str],
) -> list:
outputs: list[Any] = []
) -> list[Union[str, Exception]]:
outputs: list[Union[str, Exception]] = []
for value in inputs:
if value.startswith(self.fail_starts_with):
outputs.append(

@@ -5534,7 +5535,7 @@ def test_listeners() -> None:
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.schemas import Run

def fake_chain(inputs: dict) -> dict:
def fake_chain(inputs: dict[str, str]) -> dict[str, str]:
return {**inputs, "key": "extra"}

shared_state = {}

@@ -5564,7 +5565,7 @@ async def test_listeners_async() -> None:
from langchain_core.runnables import RunnableLambda
from langchain_core.tracers.schemas import Run

def fake_chain(inputs: dict) -> dict:
def fake_chain(inputs: dict[str, str]) -> dict[str, str]:
return {**inputs, "key": "extra"}

shared_state = {}

@@ -5577,7 +5578,7 @@ async def test_listeners_async() -> None:
def on_end(run: Run) -> None:
shared_state[run.id]["outputs"] = run.inputs

chain: Runnable = (
chain = (
RunnableLambda(fake_chain)
.with_listeners(on_end=on_end, on_start=on_start)
.map()

@@ -5647,7 +5648,7 @@ def test_pydantic_protected_namespaces() -> None:
with warnings.catch_warnings():
warnings.simplefilter("error")

class CustomChatModel(RunnableSerializable):
class CustomChatModel(RunnableSerializable[str, str]):
model_kwargs: dict[str, Any] = Field(default_factory=dict)
@@ -2,7 +2,7 @@

import asyncio
import sys
from collections.abc import AsyncIterator, Sequence
from collections.abc import AsyncIterator, Mapping, Sequence
from itertools import cycle
from typing import Any, cast

@@ -44,12 +44,6 @@ def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]:
return cast("list[StreamEvent]", [{**event, "run_id": ""} for event in events])

async def _as_async_iterator(iterable: list) -> AsyncIterator:
"""Converts an iterable into an async iterator."""
for item in iterable:
yield item

async def _collect_events(events: AsyncIterator[StreamEvent]) -> list[StreamEvent]:
"""Collect the events and remove the run ids."""
materialized_events = [event async for event in events]

@@ -59,7 +53,9 @@ async def _collect_events(events: AsyncIterator[StreamEven
return events_

def _assert_events_equal_allow_superset_metadata(events: list, expected: list) -> None:
def _assert_events_equal_allow_superset_metadata(
events: Sequence[Mapping[str, Any]], expected: Sequence[Mapping[str, Any]]
) -> None:
"""Assert that the events are equal."""
assert len(events) == len(expected)
for i, (event, expected_event) in enumerate(zip(events, expected)):

@@ -1910,7 +1906,7 @@ async def test_runnable_with_message_history() -> None:

# Here we use a global variable to store the chat message history.
# This will make it easier to inspect it to see the underlying results.
store: dict = {}
store: dict[str, list[BaseMessage]] = {}

def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
"""Get a chat message history."""

@@ -74,12 +74,6 @@ def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]:
)

async def _as_async_iterator(iterable: list) -> AsyncIterator:
"""Converts an iterable into an async iterator."""
for item in iterable:
yield item

async def _collect_events(
events: AsyncIterator[StreamEvent], *, with_nulled_ids: bool = True
) -> list[StreamEvent]:

@@ -1866,7 +1860,7 @@ async def test_runnable_with_message_history() -> None:

# Here we use a global variable to store the chat message history.
# This will make it easier to inspect it to see the underlying results.
store: dict = {}
store: dict[str, list[BaseMessage]] = {}

def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
"""Get a chat message history."""
@@ -20,7 +20,7 @@ from langchain_core.runnables.base import RunnableLambda, RunnableParallel
from langchain_core.tracers.langchain import LangChainTracer

def _get_posts(client: Client) -> list:
def _get_posts(client: Client) -> list[dict[str, Any]]:
mock_calls = client.session.request.mock_calls # type: ignore[attr-defined]
posts = []
for call in mock_calls:

@@ -274,7 +274,7 @@ class TestRunnableSequenceParallelTraceNesting:
def before(x: int) -> int:
return x

def after(x: dict) -> int:
def after(x: dict[str, Any]) -> int:
return x["chain_result"]

sequence = before | parallel | after

@@ -1,5 +1,5 @@
import sys
from typing import Callable
from typing import Any, Callable

import pytest

@@ -22,7 +22,7 @@ from langchain_core.runnables.utils import (
(lambda x: x if x > 0 else 0, "lambda x: x if x > 0 else 0"), # noqa: FURB136
],
)
def test_get_lambda_source(func: Callable, expected_source: str) -> None:
def test_get_lambda_source(func: Callable[..., Any], expected_source: str) -> None:
"""Test get_lambda_source function."""
source = get_lambda_source(func)
assert source == expected_source

@@ -1,3 +1,5 @@
from typing import Any

import pytest
from langchain_tests.integration_tests.base_store import (
BaseStoreAsyncTests,

@@ -8,7 +10,7 @@ from langchain_core.stores import InMemoryStore

# Check against standard tests
class TestSyncInMemoryStore(BaseStoreSyncTests):
class TestSyncInMemoryStore(BaseStoreSyncTests[Any]):
@pytest.fixture
def kv_store(self) -> InMemoryStore:
return InMemoryStore()
@@ -39,7 +39,6 @@ from langchain_core.messages import ToolCall, ToolMessage
from langchain_core.messages.tool import ToolOutputMixin
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableLambda,
ensure_config,

@@ -72,7 +71,7 @@ from tests.unit_tests.fake.callbacks import FakeCallbackHandler
from tests.unit_tests.pydantic_utils import _schema

def _get_tool_call_json_schema(tool: BaseTool) -> dict:
def _get_tool_call_json_schema(tool: BaseTool) -> dict[str, Any]:
tool_schema = tool.tool_call_schema
if isinstance(tool_schema, dict):
return tool_schema

@@ -1402,15 +1401,15 @@ class _MockStructuredToolWithRawOutput(BaseTool):
self,
arg1: int,
arg2: bool, # noqa: FBT001
arg3: Optional[dict] = None,
) -> tuple[str, dict]:
arg3: Optional[dict[str, str]] = None,
) -> tuple[str, dict[str, Any]]:
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}

@tool("structured_api", response_format="content_and_artifact")
def _mock_structured_tool_with_artifact(
*, arg1: int, arg2: bool, arg3: Optional[dict] = None
) -> tuple[str, dict]:
*, arg1: int, arg2: bool, arg3: Optional[dict[str, str]] = None
) -> tuple[str, dict[str, Any]]:
"""A Structured Tool."""
return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}

@@ -1419,7 +1418,7 @@ def _mock_structured_tool_with_artifact(
"tool", [_MockStructuredToolWithRawOutput(), _mock_structured_tool_with_artifact]
)
def test_tool_call_input_tool_message_with_artifact(tool: BaseTool) -> None:
tool_call: dict = {
tool_call: dict[str, Any] = {
"name": "structured_api",
"args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
"id": "123",

@@ -1448,7 +1447,7 @@ def test_convert_from_runnable_dict() -> None:
def f(x: Args) -> str:
return str(x["a"] * max(x["b"]))

runnable: Runnable = RunnableLambda(f)
runnable = RunnableLambda(f)
as_tool = runnable.as_tool()
args_schema = as_tool.args_schema
assert args_schema is not None
@@ -1480,14 +1479,14 @@ def test_convert_from_runnable_dict() -> None:
a: int = Field(..., description="Integer")
b: list[int] = Field(..., description="List of ints")

runnable = RunnableLambda(g)
as_tool = runnable.as_tool(GSchema)
as_tool.invoke({"a": 3, "b": [1, 2]})
runnable2 = RunnableLambda(g)
as_tool2 = runnable2.as_tool(GSchema)
as_tool2.invoke({"a": 3, "b": [1, 2]})

# Specify via arg_types:
runnable = RunnableLambda(g)
as_tool = runnable.as_tool(arg_types={"a": int, "b": list[int]})
result = as_tool.invoke({"a": 3, "b": [1, 2]})
runnable3 = RunnableLambda(g)
as_tool3 = runnable3.as_tool(arg_types={"a": int, "b": list[int]})
result = as_tool3.invoke({"a": 3, "b": [1, 2]})
assert result == "6"

# Test with config

@@ -1496,9 +1495,9 @@ def test_convert_from_runnable_dict() -> None:
assert config["configurable"]["foo"] == "not-bar"
return str(x["a"] * max(x["b"]))

runnable = RunnableLambda(h)
as_tool = runnable.as_tool(arg_types={"a": int, "b": list[int]})
result = as_tool.invoke(
runnable4 = RunnableLambda(h)
as_tool4 = runnable4.as_tool(arg_types={"a": int, "b": list[int]})
result = as_tool4.invoke(
{"a": 3, "b": [1, 2]}, config={"configurable": {"foo": "not-bar"}}
)
assert result == "6"

@@ -1512,7 +1511,7 @@ def test_convert_from_runnable_other() -> None:
def g(x: str) -> str:
return x + "z"

runnable: Runnable = RunnableLambda(f) | g
runnable = RunnableLambda(f) | g
as_tool = runnable.as_tool()
args_schema = as_tool.args_schema
assert args_schema is None

@@ -1527,10 +1526,10 @@ def test_convert_from_runnable_other() -> None:
assert config["configurable"]["foo"] == "not-bar"
return x + "a"

runnable = RunnableLambda(h)
as_tool = runnable.as_tool()
result = as_tool.invoke("b", config={"configurable": {"foo": "not-bar"}})
assert result == "ba"
runnable2 = RunnableLambda(h)
as_tool2 = runnable2.as_tool()
result2 = as_tool2.invoke("b", config={"configurable": {"foo": "not-bar"}})
assert result2 == "ba"

@tool("foo", parse_docstring=True)
@@ -1785,7 +1784,7 @@ def test_tool_inherited_injected_arg() -> None:
}

def _get_parametrized_tools() -> list:
def _get_parametrized_tools() -> list[Callable[..., Any]]:
def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str:
"""my_tool."""
return some_tool

@@ -1800,7 +1799,7 @@ def _get_parametrized_tools() -> list:

@pytest.mark.parametrize("tool_", _get_parametrized_tools())
def test_fn_injected_arg_with_schema(tool_: Callable) -> None:
def test_fn_injected_arg_with_schema(tool_: Callable[..., Any]) -> None:
assert convert_to_openai_function(tool_) == {
"name": tool_.__name__,
"description": "my_tool.",

@@ -2528,13 +2527,13 @@ def test_tool_decorator_description() -> None:

assert foo_args_jsons_schema.description == "JSON Schema."
assert (
cast("dict", foo_args_jsons_schema.tool_call_schema)["description"]
cast("dict[str, Any]", foo_args_jsons_schema.tool_call_schema)["description"]
== "JSON Schema."
)

assert foo_args_jsons_schema_with_description.description == "description"
assert (
cast("dict", foo_args_jsons_schema_with_description.tool_call_schema)[
cast("dict[str, Any]", foo_args_jsons_schema_with_description.tool_call_schema)[
"description"
]
== "description"