From aab2e421697a4730ee6cfddea3066b8e03c31781 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Fri, 31 Jan 2025 16:06:34 +0100 Subject: [PATCH] core[patch]: Use Blockbuster to detect blocking calls in asyncio during tests (#29043) This PR uses the [blockbuster](https://github.com/cbornet/blockbuster) library in langchain-core to detect blocking calls made in the asyncio event loop during unit tests. Avoiding blocking calls is hard as these can be deeply buried in the code or made in 3rd party libraries. Blockbuster makes it easier to detect them by raising an exception when a call is made to a known blocking function (eg: `time.sleep`). Adding blockbuster allowed to find a blocking call in `aconfig_with_context` (it ends up calling `get_function_nonlocals` which loads function code). **Dependencies:** - blockbuster (test) **Twitter handle:** cbornet_ --- libs/core/poetry.lock | 30 +- libs/core/pyproject.toml | 1 + libs/core/tests/unit_tests/conftest.py | 29 +- .../unit_tests/fake/test_fake_chat_model.py | 7 +- .../language_models/chat_models/test_base.py | 59 ++-- .../chat_models/test_rate_limiting.py | 9 + .../tests/unit_tests/prompts/test_chat.py | 67 ++-- .../unit_tests/runnables/test_context.py | 17 +- .../unit_tests/runnables/test_fallbacks.py | 12 +- .../unit_tests/runnables/test_runnable.py | 307 +++++++++++------ .../runnables/test_runnable_events_v1.py | 10 +- .../runnables/test_runnable_events_v2.py | 25 +- .../runnables/test_tracing_interops.py | 325 ++++++++++-------- libs/core/tests/unit_tests/test_setup.py | 15 + .../unit_tests/tracers/test_memory_stream.py | 13 +- .../unit_tests/vectorstores/test_in_memory.py | 8 +- 16 files changed, 588 insertions(+), 346 deletions(-) create mode 100644 libs/core/tests/unit_tests/test_setup.py diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock index a27fb1e7a9a..e0e5b993195 100644 --- a/libs/core/poetry.lock +++ b/libs/core/poetry.lock @@ -221,6 +221,20 @@ webencodings = "*" [package.extras] css = ["tinycss2 (>=1.1.0,<1.5)"] +[[package]] +name = "blockbuster" +version = "1.5.11" +description = "Utility to detect blocking calls in the async event loop" +optional = false +python-versions = ">=3.8" +files = [ + {file = "blockbuster-1.5.11-py3-none-any.whl", hash = "sha256:34e56b2ff24c73d7b2857dcc20d49abe88f92a9cddd5f56432adeb3ef4aca9b8"}, + {file = "blockbuster-1.5.11.tar.gz", hash = "sha256:c5ed3da13216c80e26b755fb576c3d638d8125f893ba4709e56e28ecf3ee7254"}, +] + +[package.dependencies] +forbiddenfruit = ">=0.1.4" + [[package]] name = "certifi" version = "2024.12.14" @@ -553,6 +567,16 @@ files = [ [package.extras] devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"] +[[package]] +name = "forbiddenfruit" +version = "0.1.4" +description = "Patch python built-in objects" +optional = false +python-versions = "*" +files = [ + {file = "forbiddenfruit-0.1.4.tar.gz", hash = "sha256:e3f7e66561a29ae129aac139a85d610dbf3dd896128187ed5454b6421f624253"}, +] + [[package]] name = "fqdn" version = "1.5.1" @@ -1188,7 +1212,7 @@ files = [ [[package]] name = "langchain-tests" -version = "0.3.9" +version = "0.3.10" description = "Standard tests for LangChain implementations" optional = false python-versions = ">=3.9,<4.0" @@ -1197,7 +1221,7 @@ develop = true [package.dependencies] httpx = ">=0.25.0,<1" -langchain-core = "^0.3.31" +langchain-core = "^0.3.33" numpy = [ {version = ">=1.24.0,<2.0.0", markers = "python_version < \"3.12\""}, {version = ">=1.26.2,<3", markers = "python_version >= \"3.12\""}, @@ -3247,4 +3271,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "2a262498da93ae3991e5eade787affab38f1ddef5212cb745eafdd0a40f0e986" +content-hash = "7a90068bfdba1760bd1e80d65c8f83b7d940f3c49e463cc793d63f1be518eb72" diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 8cc2b15a7bd..6328c7fa87f 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -118,6 +118,7 @@ grandalf = "^0.8" responses = "^0.25.0" pytest-socket = "^0.7.0" pytest-xdist = "^3.6.1" +blockbuster = "~1.5.11" [[tool.poetry.group.test.dependencies.numpy]] version = "^1.24.0" python = "<3.12" diff --git a/libs/core/tests/unit_tests/conftest.py b/libs/core/tests/unit_tests/conftest.py index 29819a80669..6438c303740 100644 --- a/libs/core/tests/unit_tests/conftest.py +++ b/libs/core/tests/unit_tests/conftest.py @@ -1,14 +1,41 @@ """Configuration for unit tests.""" -from collections.abc import Sequence +from collections.abc import Iterator, Sequence from importlib import util from uuid import UUID import pytest +from blockbuster import BlockBuster, blockbuster_ctx from pytest import Config, Function, Parser from pytest_mock import MockerFixture +@pytest.fixture(autouse=True) +def blockbuster() -> Iterator[BlockBuster]: + with blockbuster_ctx("langchain_core") as bb: + for func in ["os.stat", "os.path.abspath"]: + ( + bb.functions[func] + .can_block_in("langchain_core/_api/internal.py", "is_caller_internal") + .can_block_in("langchain_core/runnables/base.py", "__repr__") + .can_block_in( + "langchain_core/beta/runnables/context.py", "aconfig_with_context" + ) + ) + + for func in ["os.stat", "io.TextIOWrapper.read"]: + bb.functions[func].can_block_in( + "langsmith/client.py", "_default_retry_config" + ) + + for bb_function in bb.functions.values(): + bb_function.can_block_in( + "freezegun/api.py", "_get_cached_module_attributes" + ) + + yield bb + + def pytest_addoption(parser: Parser) -> None: """Add custom command line options to pytest.""" parser.addoption( diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py index 7502e17c50f..5d7d4525e3f 100644 --- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py +++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py @@ -191,7 +191,12 @@ async def test_callback_handlers() -> None: model = GenericFakeChatModel(messages=infinite_cycle) tokens: list[str] = [] # New model - results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]})) + results = [ + chunk + async for chunk in model.astream( + "meow", {"callbacks": [MyCustomAsyncHandler(tokens)]} + ) + ] assert results == [ _any_id_ai_message_chunk(content="hello"), _any_id_ai_message_chunk(content=" "), diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index bee3a0783af..420b60cebf3 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -18,6 +18,7 @@ from langchain_core.messages import ( ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs.llm_result import LLMResult +from langchain_core.tracers import LogStreamCallbackHandler from langchain_core.tracers.base import BaseTracer from langchain_core.tracers.context import collect_runs from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler @@ -303,39 +304,48 @@ class StreamingModel(NoStreamingModel): @pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"]) -async def test_disable_streaming( +def test_disable_streaming( disable_streaming: Union[bool, Literal["tool_calling"]], ) -> None: model = StreamingModel(disable_streaming=disable_streaming) assert model.invoke([]).content == "invoke" - assert (await model.ainvoke([])).content == "invoke" expected = "invoke" if disable_streaming is True else "stream" assert next(model.stream([])).content == expected - async for c in model.astream([]): - assert c.content == expected - break + assert ( + model.invoke([], config={"callbacks": [LogStreamCallbackHandler()]}).content + == expected + ) + + expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream" + assert next(model.stream([], tools=[{"type": "function"}])).content == expected assert ( model.invoke( - [], config={"callbacks": [_AstreamEventsCallbackHandler()]} + [], config={"callbacks": [LogStreamCallbackHandler()]}, tools=[{}] ).content == expected ) + + +@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"]) +async def test_disable_streaming_async( + disable_streaming: Union[bool, Literal["tool_calling"]], +) -> None: + model = StreamingModel(disable_streaming=disable_streaming) + assert (await model.ainvoke([])).content == "invoke" + + expected = "invoke" if disable_streaming is True else "stream" + async for c in model.astream([]): + assert c.content == expected + break assert ( await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]}) ).content == expected expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream" - assert next(model.stream([], tools=[{"type": "function"}])).content == expected async for c in model.astream([], tools=[{}]): assert c.content == expected break - assert ( - model.invoke( - [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}] - ).content - == expected - ) assert ( await model.ainvoke( [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}] @@ -344,26 +354,31 @@ async def test_disable_streaming( @pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"]) -async def test_disable_streaming_no_streaming_model( +def test_disable_streaming_no_streaming_model( disable_streaming: Union[bool, Literal["tool_calling"]], ) -> None: model = NoStreamingModel(disable_streaming=disable_streaming) assert model.invoke([]).content == "invoke" - assert (await model.ainvoke([])).content == "invoke" assert next(model.stream([])).content == "invoke" + assert ( + model.invoke([], config={"callbacks": [LogStreamCallbackHandler()]}).content + == "invoke" + ) + assert next(model.stream([], tools=[{}])).content == "invoke" + + +@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"]) +async def test_disable_streaming_no_streaming_model_async( + disable_streaming: Union[bool, Literal["tool_calling"]], +) -> None: + model = NoStreamingModel(disable_streaming=disable_streaming) + assert (await model.ainvoke([])).content == "invoke" async for c in model.astream([]): assert c.content == "invoke" break - assert ( - model.invoke( - [], config={"callbacks": [_AstreamEventsCallbackHandler()]} - ).content - == "invoke" - ) assert ( await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]}) ).content == "invoke" - assert next(model.stream([], tools=[{}])).content == "invoke" async for c in model.astream([], tools=[{}]): assert c.content == "invoke" break diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py b/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py index bd2d960e10e..ee6eefba927 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py @@ -1,11 +1,20 @@ import time from typing import Optional as Optional +import pytest +from blockbuster import BlockBuster + from langchain_core.caches import InMemoryCache from langchain_core.language_models import GenericFakeChatModel from langchain_core.rate_limiters import InMemoryRateLimiter +@pytest.fixture(autouse=True) +def deactivate_blockbuster(blockbuster: BlockBuster) -> None: + # Deactivate BlockBuster to not disturb the rate limiter timings + blockbuster.deactivate() + + def test_rate_limit_invoke() -> None: """Add rate limiter.""" model = GenericFakeChatModel( diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index cad31d03ef9..abc0ac4f326 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -1,5 +1,3 @@ -import base64 -import tempfile import warnings from pathlib import Path from typing import Any, Union, cast @@ -727,44 +725,39 @@ async def test_chat_tmpl_from_messages_multipart_image() -> None: async def test_chat_tmpl_from_messages_multipart_formatting_with_path() -> None: """Verify that we cannot pass `path` for an image as a variable.""" in_mem = "base64mem" - in_file_data = "base64file01" - with tempfile.NamedTemporaryFile(delete=True, suffix=".jpg") as temp_file: - temp_file.write(base64.b64decode(in_file_data)) - temp_file.flush() - - template = ChatPromptTemplate.from_messages( - [ - ("system", "You are an AI assistant named {name}."), - ( - "human", - [ - {"type": "text", "text": "What's in this image?"}, - { - "type": "image_url", - "image_url": "data:image/jpeg;base64,{in_mem}", - }, - { - "type": "image_url", - "image_url": {"path": "{file_path}"}, - }, - ], - ), - ] + template = ChatPromptTemplate.from_messages( + [ + ("system", "You are an AI assistant named {name}."), + ( + "human", + [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": "data:image/jpeg;base64,{in_mem}", + }, + { + "type": "image_url", + "image_url": {"path": "{file_path}"}, + }, + ], + ), + ] + ) + with pytest.raises(ValueError): + template.format_messages( + name="R2D2", + in_mem=in_mem, + file_path="some/path", ) - with pytest.raises(ValueError): - template.format_messages( - name="R2D2", - in_mem=in_mem, - file_path=temp_file.name, - ) - with pytest.raises(ValueError): - await template.aformat_messages( - name="R2D2", - in_mem=in_mem, - file_path=temp_file.name, - ) + with pytest.raises(ValueError): + await template.aformat_messages( + name="R2D2", + in_mem=in_mem, + file_path="some/path", + ) def test_messages_placeholder() -> None: diff --git a/libs/core/tests/unit_tests/runnables/test_context.py b/libs/core/tests/unit_tests/runnables/test_context.py index c00eb999424..cb8de2dd808 100644 --- a/libs/core/tests/unit_tests/runnables/test_context.py +++ b/libs/core/tests/unit_tests/runnables/test_context.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, Callable, NamedTuple, Union import pytest @@ -330,19 +331,26 @@ test_cases = [ @pytest.mark.parametrize("runnable, cases", test_cases) -async def test_context_runnables( +def test_context_runnables( runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase] ) -> None: runnable = runnable if isinstance(runnable, Runnable) else runnable() assert runnable.invoke(cases[0].input) == cases[0].output - assert await runnable.ainvoke(cases[1].input) == cases[1].output assert runnable.batch([case.input for case in cases]) == [ case.output for case in cases ] + assert add(runnable.stream(cases[0].input)) == cases[0].output + + +@pytest.mark.parametrize("runnable, cases", test_cases) +async def test_context_runnables_async( + runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase] +) -> None: + runnable = runnable if isinstance(runnable, Runnable) else runnable() + assert await runnable.ainvoke(cases[1].input) == cases[1].output assert await runnable.abatch([case.input for case in cases]) == [ case.output for case in cases ] - assert add(runnable.stream(cases[0].input)) == cases[0].output assert await aadd(runnable.astream(cases[1].input)) == cases[1].output @@ -390,8 +398,7 @@ async def test_runnable_seq_streaming_chunks() -> None: "prompt": Context.getter("prompt"), } ) - - chunks = list(chain.stream({"foo": "foo", "bar": "bar"})) + chunks = await asyncio.to_thread(list, chain.stream({"foo": "foo", "bar": "bar"})) achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})] for c in chunks: assert c in achunks diff --git a/libs/core/tests/unit_tests/runnables/test_fallbacks.py b/libs/core/tests/unit_tests/runnables/test_fallbacks.py index 1826f883e7b..330adfa2089 100644 --- a/libs/core/tests/unit_tests/runnables/test_fallbacks.py +++ b/libs/core/tests/unit_tests/runnables/test_fallbacks.py @@ -82,17 +82,25 @@ def chain_pass_exceptions() -> Runnable: "runnable", ["llm", "llm_multi", "chain", "chain_pass_exceptions"], ) -async def test_fallbacks( +def test_fallbacks( runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion ) -> None: runnable = request.getfixturevalue(runnable) assert runnable.invoke("hello") == "bar" assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3 assert list(runnable.stream("hello")) == ["bar"] + assert dumps(runnable, pretty=True) == snapshot + + +@pytest.mark.parametrize( + "runnable", + ["llm", "llm_multi", "chain", "chain_pass_exceptions"], +) +async def test_fallbacks_async(runnable: RunnableWithFallbacks, request: Any) -> None: + runnable = request.getfixturevalue(runnable) assert await runnable.ainvoke("hello") == "bar" assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3 assert list(await runnable.ainvoke("hello")) == list("bar") - assert dumps(runnable, pretty=True) == snapshot def _runnable(inputs: dict) -> str: diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 1bd21b9d560..01cd8a026c5 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -1,3 +1,4 @@ +import asyncio import sys import uuid import warnings @@ -1008,6 +1009,73 @@ def test_configurable_fields_example(snapshot: SnapshotAssertion) -> None: ) +def test_passthrough_tap(mocker: MockerFixture) -> None: + fake = FakeRunnable() + mock = mocker.Mock() + + seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock) + + assert seq.invoke("hello", my_kwarg="value") == 5 # type: ignore[call-arg] + assert mock.call_args_list == [ + mocker.call("hello", my_kwarg="value"), + mocker.call(5), + ] + mock.reset_mock() + + assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6] + assert len(mock.call_args_list) == 4 + for call in [ + mocker.call("hello", my_kwarg="value"), + mocker.call("byebye", my_kwarg="value"), + mocker.call(5), + mocker.call(6), + ]: + assert call in mock.call_args_list + mock.reset_mock() + + assert seq.batch(["hello", "byebye"], my_kwarg="value", return_exceptions=True) == [ + 5, + 6, + ] + assert len(mock.call_args_list) == 4 + for call in [ + mocker.call("hello", my_kwarg="value"), + mocker.call("byebye", my_kwarg="value"), + mocker.call(5), + mocker.call(6), + ]: + assert call in mock.call_args_list + mock.reset_mock() + + assert sorted( + a + for a in seq.batch_as_completed( + ["hello", "byebye"], my_kwarg="value", return_exceptions=True + ) + ) == [ + (0, 5), + (1, 6), + ] + assert len(mock.call_args_list) == 4 + for call in [ + mocker.call("hello", my_kwarg="value"), + mocker.call("byebye", my_kwarg="value"), + mocker.call(5), + mocker.call(6), + ]: + assert call in mock.call_args_list + mock.reset_mock() + + assert list( + seq.stream("hello", {"metadata": {"key": "value"}}, my_kwarg="value") + ) == [5] + assert mock.call_args_list == [ + mocker.call("hello", my_kwarg="value"), + mocker.call(5), + ] + mock.reset_mock() + + async def test_passthrough_tap_async(mocker: MockerFixture) -> None: fake = FakeRunnable() mock = mocker.Mock() @@ -1079,67 +1147,6 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None: mocker.call("hello", my_kwarg="value"), mocker.call(5), ] - mock.reset_mock() - - assert seq.invoke("hello", my_kwarg="value") == 5 # type: ignore[call-arg] - assert mock.call_args_list == [ - mocker.call("hello", my_kwarg="value"), - mocker.call(5), - ] - mock.reset_mock() - - assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6] - assert len(mock.call_args_list) == 4 - for call in [ - mocker.call("hello", my_kwarg="value"), - mocker.call("byebye", my_kwarg="value"), - mocker.call(5), - mocker.call(6), - ]: - assert call in mock.call_args_list - mock.reset_mock() - - assert seq.batch(["hello", "byebye"], my_kwarg="value", return_exceptions=True) == [ - 5, - 6, - ] - assert len(mock.call_args_list) == 4 - for call in [ - mocker.call("hello", my_kwarg="value"), - mocker.call("byebye", my_kwarg="value"), - mocker.call(5), - mocker.call(6), - ]: - assert call in mock.call_args_list - mock.reset_mock() - - assert sorted( - a - for a in seq.batch_as_completed( - ["hello", "byebye"], my_kwarg="value", return_exceptions=True - ) - ) == [ - (0, 5), - (1, 6), - ] - assert len(mock.call_args_list) == 4 - for call in [ - mocker.call("hello", my_kwarg="value"), - mocker.call("byebye", my_kwarg="value"), - mocker.call(5), - mocker.call(6), - ]: - assert call in mock.call_args_list - mock.reset_mock() - - assert list( - seq.stream("hello", {"metadata": {"key": "value"}}, my_kwarg="value") - ) == [5] - assert mock.call_args_list == [ - mocker.call("hello", my_kwarg="value"), - mocker.call(5), - ] - mock.reset_mock() async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None: @@ -1170,7 +1177,7 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None: spy.reset_mock() -async def test_with_config(mocker: MockerFixture) -> None: +def test_with_config(mocker: MockerFixture) -> None: fake = FakeRunnable() spy = mocker.spy(fake, "invoke") @@ -1276,7 +1283,11 @@ async def test_with_config(mocker: MockerFixture) -> None: for i, call in enumerate(spy.call_args_list): assert call.args[0] == ("hello" if i == 0 else "wooorld") assert call.args[1].get("tags") == ["a-tag"] - spy.reset_mock() + + +async def test_with_config_async(mocker: MockerFixture) -> None: + fake = FakeRunnable() + spy = mocker.spy(fake, "invoke") handler = ConsoleCallbackHandler() assert ( @@ -1372,7 +1383,7 @@ async def test_with_config(mocker: MockerFixture) -> None: ) -async def test_default_method_implementations(mocker: MockerFixture) -> None: +def test_default_method_implementations(mocker: MockerFixture) -> None: fake = FakeRunnable() spy = mocker.spy(fake, "invoke") @@ -1413,7 +1424,11 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: for call in spy.call_args_list: assert call.args[1].get("tags") == ["a-tag"] assert call.args[1].get("metadata") == {} - spy.reset_mock() + + +async def test_default_method_implementations_async(mocker: MockerFixture) -> None: + fake = FakeRunnable() + spy = mocker.spy(fake, "invoke") assert await fake.ainvoke("hello", config={"callbacks": []}) == 5 assert spy.call_args_list == [ @@ -1442,7 +1457,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: } -async def test_prompt() -> None: +def test_prompt() -> None: prompt = ChatPromptTemplate.from_messages( messages=[ SystemMessage(content="You are a nice assistant."), @@ -1475,6 +1490,21 @@ async def test_prompt() -> None: assert [*prompt.stream({"question": "What is your name?"})] == [expected] + +async def test_prompt_async() -> None: + prompt = ChatPromptTemplate.from_messages( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessagePromptTemplate.from_template("{question}"), + ] + ) + expected = ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your name?"), + ] + ) + assert await prompt.ainvoke({"question": "What is your name?"}) == expected assert await prompt.abatch( @@ -2770,9 +2800,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) -> @freeze_time("2023-01-01") -async def test_router_runnable( - mocker: MockerFixture, snapshot: SnapshotAssertion -) -> None: +def test_router_runnable(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None: chain1: Runnable = ChatPromptTemplate.from_template( "You are a math genius. Answer the question: {question}" ) | FakeListLLM(responses=["4"]) @@ -2797,17 +2825,6 @@ async def test_router_runnable( ) assert result2 == ["4", "2"] - result = await chain.ainvoke({"key": "math", "question": "2 + 2"}) - assert result == "4" - - result2 = await chain.abatch( - [ - {"key": "math", "question": "2 + 2"}, - {"key": "english", "question": "2 + 2"}, - ] - ) - assert result2 == ["4", "2"] - # Test invoke router_spy = mocker.spy(router.__class__, "invoke") tracer = FakeTracer() @@ -2827,8 +2844,33 @@ async def test_router_runnable( assert len(router_run.child_runs) == 2 +async def test_router_runnable_async() -> None: + chain1: Runnable = ChatPromptTemplate.from_template( + "You are a math genius. Answer the question: {question}" + ) | FakeListLLM(responses=["4"]) + chain2: Runnable = ChatPromptTemplate.from_template( + "You are an english major. Answer the question: {question}" + ) | FakeListLLM(responses=["2"]) + router: Runnable = RouterRunnable({"math": chain1, "english": chain2}) + chain: Runnable = { + "key": lambda x: x["key"], + "input": {"question": lambda x: x["question"]}, + } | router + + result = await chain.ainvoke({"key": "math", "question": "2 + 2"}) + assert result == "4" + + result2 = await chain.abatch( + [ + {"key": "math", "question": "2 + 2"}, + {"key": "english", "question": "2 + 2"}, + ] + ) + assert result2 == ["4", "2"] + + @freeze_time("2023-01-01") -async def test_higher_order_lambda_runnable( +def test_higher_order_lambda_runnable( mocker: MockerFixture, snapshot: SnapshotAssertion ) -> None: math_chain: Runnable = ChatPromptTemplate.from_template( @@ -2865,17 +2907,6 @@ async def test_higher_order_lambda_runnable( ) assert result2 == ["4", "2"] - result = await chain.ainvoke({"key": "math", "question": "2 + 2"}) - assert result == "4" - - result2 = await chain.abatch( - [ - {"key": "math", "question": "2 + 2"}, - {"key": "english", "question": "2 + 2"}, - ] - ) - assert result2 == ["4", "2"] - # Test invoke math_spy = mocker.spy(math_chain.__class__, "invoke") tracer = FakeTracer() @@ -2897,6 +2928,41 @@ async def test_higher_order_lambda_runnable( assert math_run.name == "RunnableSequence" assert len(math_run.child_runs) == 3 + +async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None: + math_chain: Runnable = ChatPromptTemplate.from_template( + "You are a math genius. Answer the question: {question}" + ) | FakeListLLM(responses=["4"]) + english_chain: Runnable = ChatPromptTemplate.from_template( + "You are an english major. Answer the question: {question}" + ) | FakeListLLM(responses=["2"]) + input_map: Runnable = RunnableParallel( + key=lambda x: x["key"], + input={"question": lambda x: x["question"]}, + ) + + def router(input: dict[str, Any]) -> Runnable: + if input["key"] == "math": + return itemgetter("input") | math_chain + elif input["key"] == "english": + return itemgetter("input") | english_chain + else: + msg = f"Unknown key: {input['key']}" + raise ValueError(msg) + + chain: Runnable = input_map | router + + result = await chain.ainvoke({"key": "math", "question": "2 + 2"}) + assert result == "4" + + result2 = await chain.abatch( + [ + {"key": "math", "question": "2 + 2"}, + {"key": "english", "question": "2 + 2"}, + ] + ) + assert result2 == ["4", "2"] + # Test ainvoke async def arouter(input: dict[str, Any]) -> Runnable: if input["key"] == "math": @@ -4651,7 +4717,7 @@ async def test_tool_from_runnable() -> None: } -async def test_runnable_gen() -> None: +def test_runnable_gen() -> None: """Test that a generator can be used as a runnable.""" def gen(input: Iterator[Any]) -> Iterator[int]: @@ -4671,6 +4737,10 @@ async def test_runnable_gen() -> None: assert list(runnable.stream(None)) == [1, 2, 3] assert runnable.batch([None, None]) == [6, 6] + +async def test_runnable_gen_async() -> None: + """Test that a generator can be used as a runnable.""" + async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]: yield 1 yield 2 @@ -4693,14 +4763,14 @@ async def test_runnable_gen() -> None: assert [p async for p in arunnablecallable.astream(None)] == [1, 2, 3] assert await arunnablecallable.abatch([None, None]) == [6, 6] with pytest.raises(NotImplementedError): - arunnablecallable.invoke(None) + await asyncio.to_thread(arunnablecallable.invoke, None) with pytest.raises(NotImplementedError): - arunnablecallable.stream(None) + await asyncio.to_thread(arunnablecallable.stream, None) with pytest.raises(NotImplementedError): - arunnablecallable.batch([None, None]) + await asyncio.to_thread(arunnablecallable.batch, [None, None]) -async def test_runnable_gen_context_config() -> None: +def test_runnable_gen_context_config() -> None: """Test that a generator can call other runnables with config propagated from the context. """ @@ -4769,9 +4839,16 @@ async def test_runnable_gen_context_config() -> None: assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"] assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] - if sys.version_info < (3, 11): - # Python 3.10 and below don't support running async tasks in a specific context - return + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Python 3.10 and below don't support running async tasks in a specific context", +) +async def test_runnable_gen_context_config_async() -> None: + """Test that a generator can call other runnables with config + propagated from the context.""" + + fake = RunnableLambda(len) async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]: yield await fake.ainvoke("a") @@ -4835,7 +4912,7 @@ async def test_runnable_gen_context_config() -> None: assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] -async def test_runnable_iter_context_config() -> None: +def test_runnable_iter_context_config() -> None: """Test that a generator can call other runnables with config propagated from the context. """ @@ -4888,9 +4965,16 @@ async def test_runnable_iter_context_config() -> None: assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"] assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] - if sys.version_info < (3, 11): - # Python 3.10 and below don't support running async tasks in a specific context - return + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Python 3.10 and below don't support running async tasks in a specific context", +) +async def test_runnable_iter_context_config_async() -> None: + """Test that a generator can call other runnables with config + propagated from the context.""" + + fake = RunnableLambda(len) @chain async def agen(input: str) -> AsyncIterator[int]: @@ -4952,7 +5036,7 @@ async def test_runnable_iter_context_config() -> None: assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] -async def test_runnable_lambda_context_config() -> None: +def test_runnable_lambda_context_config() -> None: """Test that a function can call other runnables with config propagated from the context. """ @@ -5003,9 +5087,16 @@ async def test_runnable_lambda_context_config() -> None: assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"] assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] - if sys.version_info < (3, 11): - # Python 3.10 and below don't support running async tasks in a specific context - return + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Python 3.10 and below don't support running async tasks in a specific context", +) +async def test_runnable_lambda_context_config_async() -> None: + """Test that a function can call other runnables with config + propagated from the context.""" + + fake = RunnableLambda(len) @chain async def afun(input: str) -> int: diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py index 59c5e765e23..9aab4eef34f 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py @@ -1,5 +1,6 @@ """Module that contains tests for runnable.astream_events API.""" +import asyncio import sys from collections.abc import AsyncIterator, Sequence from itertools import cycle @@ -1957,9 +1958,12 @@ async def test_runnable_with_message_history() -> None: ] } - with_message_history.with_config( - {"configurable": {"session_id": "session-123"}} - ).invoke({"question": "meow"}) + await asyncio.to_thread( + with_message_history.with_config( + {"configurable": {"session_id": "session-123"}} + ).invoke, + {"question": "meow"}, + ) assert store == { "session-123": [ HumanMessage(content="hello"), diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index 698a4c4ddab..c1efd9be503 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -13,6 +13,7 @@ from typing import ( ) import pytest +from blockbuster import BlockBuster from pydantic import BaseModel from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks @@ -38,7 +39,9 @@ from langchain_core.runnables import ( chain, ensure_config, ) -from langchain_core.runnables.config import get_callback_manager_for_config +from langchain_core.runnables.config import ( + get_async_callback_manager_for_config, +) from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.schema import StreamEvent from langchain_core.runnables.utils import Input, Output @@ -1923,9 +1926,12 @@ async def test_runnable_with_message_history() -> None: ] } - with_message_history.with_config( - {"configurable": {"session_id": "session-123"}} - ).invoke({"question": "meow"}) + await asyncio.to_thread( + with_message_history.with_config( + {"configurable": {"session_id": "session-123"}} + ).invoke, + {"question": "meow"}, + ) assert store == { "session-123": [ HumanMessage(content="hello"), @@ -1995,8 +2001,9 @@ EXPECTED_EVENTS = [ ] -async def test_sync_in_async_stream_lambdas() -> None: +async def test_sync_in_async_stream_lambdas(blockbuster: BlockBuster) -> None: """Test invoking nested runnable lambda.""" + blockbuster.deactivate() def add_one(x: int) -> int: return x + 1 @@ -2085,8 +2092,8 @@ class StreamingRunnable(Runnable[Input, Output]): **kwargs: Optional[Any], ) -> AsyncIterator[Output]: config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) - run_manager = callback_manager.on_chain_start( + callback_manager = get_async_callback_manager_for_config(config) + run_manager = await callback_manager.on_chain_start( None, input, name=config.get("run_name", self.get_name()), @@ -2109,9 +2116,9 @@ class StreamingRunnable(Runnable[Input, Output]): final_output = element # set final channel values as run output - run_manager.on_chain_end(final_output) + await run_manager.on_chain_end(final_output) except BaseException as e: - run_manager.on_chain_error(e) + await run_manager.on_chain_error(e) raise diff --git a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py index 3409d04f234..77922a5616c 100644 --- a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py +++ b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import json import sys import uuid -from collections.abc import AsyncGenerator, Generator -from typing import Any +from collections.abc import AsyncGenerator, Coroutine, Generator +from inspect import isasyncgenfunction +from typing import Any, Callable, Optional from unittest.mock import MagicMock, patch import pytest @@ -12,6 +15,7 @@ from langsmith.run_trees import RunTree from langsmith.utils import get_env_var from typing_extensions import Literal +from langchain_core.callbacks import BaseCallbackHandler from langchain_core.runnables.base import RunnableLambda, RunnableParallel from langchain_core.tracers.langchain import LangChainTracer @@ -35,6 +39,17 @@ def _get_posts(client: Client) -> list: return posts +def _create_tracer_with_mocked_client( + project_name: Optional[str] = None, + tags: Optional[list[str]] = None, +) -> LangChainTracer: + mock_session = MagicMock() + mock_client_ = Client( + session=mock_session, api_key="test", auto_batch_tracing=False + ) + return LangChainTracer(client=mock_client_, project_name=project_name, tags=tags) + + def test_tracing_context() -> None: mock_session = MagicMock() mock_client_ = Client( @@ -56,12 +71,8 @@ def test_tracing_context() -> None: def test_config_traceable_handoff() -> None: get_env_var.cache_clear() - mock_session = MagicMock() - mock_client_ = Client( - session=mock_session, api_key="test", auto_batch_tracing=False - ) - tracer = LangChainTracer( - client=mock_client_, project_name="another-flippin-project", tags=["such-a-tag"] + tracer = _create_tracer_with_mocked_client( + project_name="another-flippin-project", tags=["such-a-tag"] ) @traceable @@ -100,7 +111,7 @@ def test_config_traceable_handoff() -> None: my_parent_runnable = RunnableLambda(my_parent_function) assert my_parent_runnable.invoke(1, {"callbacks": [tracer]}) == 6 - posts = _get_posts(mock_client_) + posts = _get_posts(tracer.client) assert all(post["session_name"] == "another-flippin-project" for post in posts) # There should have been 6 runs created, # one for each function invocation @@ -143,11 +154,7 @@ def test_config_traceable_handoff() -> None: sys.version_info < (3, 11), reason="Asyncio context vars require Python 3.11+" ) async def test_config_traceable_async_handoff() -> None: - mock_session = MagicMock() - mock_client_ = Client( - session=mock_session, api_key="test", auto_batch_tracing=False - ) - tracer = LangChainTracer(client=mock_client_) + tracer = _create_tracer_with_mocked_client() @traceable def my_great_great_grandchild_function(a: int) -> int: @@ -175,7 +182,7 @@ async def test_config_traceable_async_handoff() -> None: my_parent_runnable = RunnableLambda(my_parent_function) # type: ignore result = await my_parent_runnable.ainvoke(1, {"callbacks": [tracer]}) assert result == 6 - posts = _get_posts(mock_client_) + posts = _get_posts(tracer.client) # There should have been 6 runs created, # one for each function invocation assert len(posts) == 6 @@ -245,144 +252,172 @@ def test_tracing_enable_disable( assert not mock_posts -@pytest.mark.parametrize( - "method", ["invoke", "stream", "batch", "ainvoke", "astream", "abatch"] -) -async def test_runnable_sequence_parallel_trace_nesting(method: str) -> None: - if method.startswith("a") and sys.version_info < (3, 11): - pytest.skip("Asyncio context vars require Python 3.11+") - mock_session = MagicMock() - mock_client_ = Client( - session=mock_session, api_key="test", auto_batch_tracing=False +class TestRunnableSequenceParallelTraceNesting: + @pytest.fixture(autouse=True) + def _setup(self) -> None: + self.tracer = _create_tracer_with_mocked_client() + + @staticmethod + def _create_parent( + other_thing: Callable[ + [int], Generator[int, None, None] | AsyncGenerator[int, None] + ], + ) -> RunnableLambda: + @RunnableLambda + def my_child_function(a: int) -> int: + return a + 2 + + parallel = RunnableParallel( + chain_result=my_child_function.with_config(tags=["atag"]), + other_thing=other_thing, + ) + + def before(x: int) -> int: + return x + + def after(x: dict) -> int: + return x["chain_result"] + + sequence = before | parallel | after + if isasyncgenfunction(other_thing): + + @RunnableLambda # type: ignore + async def parent(a: int) -> int: + return await sequence.ainvoke(a) + + else: + + @RunnableLambda + def parent(a: int) -> int: + return sequence.invoke(a) + + return parent + + def _check_posts(self) -> None: + posts = _get_posts(self.tracer.client) + name_order = [ + "parent", + "RunnableSequence", + "before", + "RunnableParallel", + ["my_child_function", "other_thing"], + "after", + ] + expected_parents = { + "parent": None, + "RunnableSequence": "parent", + "before": "RunnableSequence", + "RunnableParallel": "RunnableSequence", + "my_child_function": "RunnableParallel", + "other_thing": "RunnableParallel", + "after": "RunnableSequence", + } + assert len(posts) == sum( + 1 if isinstance(n, str) else len(n) for n in name_order + ) + prev_dotted_order = None + dotted_order_map = {} + id_map = {} + parent_id_map = {} + i = 0 + for name in name_order: + if isinstance(name, list): + for n in name: + matching_post = next( + p for p in posts[i : i + len(name)] if p["name"] == n + ) + assert matching_post + dotted_order = matching_post["dotted_order"] + if prev_dotted_order is not None: + assert dotted_order > prev_dotted_order + dotted_order_map[n] = dotted_order + id_map[n] = matching_post["id"] + parent_id_map[n] = matching_post.get("parent_run_id") + i += len(name) + continue + else: + assert posts[i]["name"] == name + dotted_order = posts[i]["dotted_order"] + if prev_dotted_order is not None and not str( + expected_parents[name] + ).startswith("RunnableParallel"): + assert dotted_order > prev_dotted_order, ( + f"{name} not after {name_order[i - 1]}" + ) + prev_dotted_order = dotted_order + if name in dotted_order_map: + msg = f"Duplicate name {name}" + raise ValueError(msg) + dotted_order_map[name] = dotted_order + id_map[name] = posts[i]["id"] + parent_id_map[name] = posts[i].get("parent_run_id") + i += 1 + + # Now check the dotted orders + for name, parent_ in expected_parents.items(): + dotted_order = dotted_order_map[name] + if parent_ is not None: + parent_dotted_order = dotted_order_map[parent_] + assert dotted_order.startswith(parent_dotted_order), ( + f"{name}, {parent_dotted_order} not in {dotted_order}" + ) + assert str(parent_id_map[name]) == str(id_map[parent_]) + else: + assert dotted_order.split(".")[0] == dotted_order + + @pytest.mark.parametrize( + "method", + [ + lambda parent, cb: parent.invoke(1, {"callbacks": cb}), + lambda parent, cb: list(parent.stream(1, {"callbacks": cb}))[-1], + lambda parent, cb: parent.batch([1], {"callbacks": cb})[0], + ], + ids=["invoke", "stream", "batch"], ) - tracer = LangChainTracer(client=mock_client_) - - @RunnableLambda - def my_child_function(a: int) -> int: - return a + 2 - - if method.startswith("a"): - - async def other_thing(a: int) -> AsyncGenerator[int, None]: - yield 1 - - else: - + def test_sync( + self, method: Callable[[RunnableLambda, list[BaseCallbackHandler]], int] + ) -> None: def other_thing(a: int) -> Generator[int, None, None]: # type: ignore yield 1 - parallel = RunnableParallel( - chain_result=my_child_function.with_config(tags=["atag"]), - other_thing=other_thing, + parent = self._create_parent(other_thing) + + # Now run the chain and check the resulting posts + assert method(parent, [self.tracer]) == 3 + + self._check_posts() + + @staticmethod + async def ainvoke(parent: RunnableLambda, cb: list[BaseCallbackHandler]) -> int: + return await parent.ainvoke(1, {"callbacks": cb}) + + @staticmethod + async def astream(parent: RunnableLambda, cb: list[BaseCallbackHandler]) -> int: + return [res async for res in parent.astream(1, {"callbacks": cb})][-1] + + @staticmethod + async def abatch(parent: RunnableLambda, cb: list[BaseCallbackHandler]) -> int: + return (await parent.abatch([1], {"callbacks": cb}))[0] + + @pytest.mark.skipif( + sys.version_info < (3, 11), reason="Asyncio context vars require Python 3.11+" ) + @pytest.mark.parametrize("method", [ainvoke, astream, abatch]) + async def test_async( + self, + method: Callable[ + [RunnableLambda, list[BaseCallbackHandler]], Coroutine[Any, Any, int] + ], + ) -> None: + async def other_thing(a: int) -> AsyncGenerator[int, None]: + yield 1 - def before(x: int) -> int: - return x + parent = self._create_parent(other_thing) - def after(x: dict) -> int: - return x["chain_result"] + # Now run the chain and check the resulting posts + assert await method(parent, [self.tracer]) == 3 - sequence = before | parallel | after - if method.startswith("a"): - - @RunnableLambda # type: ignore - async def parent(a: int) -> int: - return await sequence.ainvoke(a) - - else: - - @RunnableLambda - def parent(a: int) -> int: - return sequence.invoke(a) - - # Now run the chain and check the resulting posts - cb = [tracer] - if method == "invoke": - res: Any = parent.invoke(1, {"callbacks": cb}) # type: ignore - elif method == "ainvoke": - res = await parent.ainvoke(1, {"callbacks": cb}) # type: ignore - elif method == "stream": - results = list(parent.stream(1, {"callbacks": cb})) # type: ignore - res = results[-1] - elif method == "astream": - results = [res async for res in parent.astream(1, {"callbacks": cb})] # type: ignore - res = results[-1] - elif method == "batch": - res = parent.batch([1], {"callbacks": cb})[0] # type: ignore - elif method == "abatch": - res = (await parent.abatch([1], {"callbacks": cb}))[0] # type: ignore - else: - msg = f"Unknown method {method}" - raise ValueError(msg) - assert res == 3 - posts = _get_posts(mock_client_) - name_order = [ - "parent", - "RunnableSequence", - "before", - "RunnableParallel", - ["my_child_function", "other_thing"], - "after", - ] - expected_parents = { - "parent": None, - "RunnableSequence": "parent", - "before": "RunnableSequence", - "RunnableParallel": "RunnableSequence", - "my_child_function": "RunnableParallel", - "other_thing": "RunnableParallel", - "after": "RunnableSequence", - } - assert len(posts) == sum(1 if isinstance(n, str) else len(n) for n in name_order) - prev_dotted_order = None - dotted_order_map = {} - id_map = {} - parent_id_map = {} - i = 0 - for name in name_order: - if isinstance(name, list): - for n in name: - matching_post = next( - p for p in posts[i : i + len(name)] if p["name"] == n - ) - assert matching_post - dotted_order = matching_post["dotted_order"] - if prev_dotted_order is not None: - assert dotted_order > prev_dotted_order - dotted_order_map[n] = dotted_order - id_map[n] = matching_post["id"] - parent_id_map[n] = matching_post.get("parent_run_id") - i += len(name) - continue - else: - assert posts[i]["name"] == name - dotted_order = posts[i]["dotted_order"] - if prev_dotted_order is not None and not str( - expected_parents[name] - ).startswith("RunnableParallel"): - assert dotted_order > prev_dotted_order, ( - f"{name} not after {name_order[i - 1]}" - ) - prev_dotted_order = dotted_order - if name in dotted_order_map: - msg = f"Duplicate name {name}" - raise ValueError(msg) - dotted_order_map[name] = dotted_order - id_map[name] = posts[i]["id"] - parent_id_map[name] = posts[i].get("parent_run_id") - i += 1 - - # Now check the dotted orders - for name, parent_ in expected_parents.items(): - dotted_order = dotted_order_map[name] - if parent_ is not None: - parent_dotted_order = dotted_order_map[parent_] - assert dotted_order.startswith(parent_dotted_order), ( - f"{name}, {parent_dotted_order} not in {dotted_order}" - ) - assert str(parent_id_map[name]) == str(id_map[parent_]) - else: - assert dotted_order.split(".")[0] == dotted_order + self._check_posts() @pytest.mark.parametrize("parent_type", ("ls", "lc")) diff --git a/libs/core/tests/unit_tests/test_setup.py b/libs/core/tests/unit_tests/test_setup.py new file mode 100644 index 00000000000..1df3c73a252 --- /dev/null +++ b/libs/core/tests/unit_tests/test_setup.py @@ -0,0 +1,15 @@ +import time + +import pytest +from blockbuster import BlockingError + +from langchain_core import sys_info + + +async def test_blockbuster_setup() -> None: + """Check if blockbuster is correctly setup.""" + # Blocking call outside of langchain_core is allowed. + time.sleep(0.01) # noqa: ASYNC251 + with pytest.raises(BlockingError): + # Blocking call from langchain_core raises BlockingError. + sys_info.print_sys_info() diff --git a/libs/core/tests/unit_tests/tracers/test_memory_stream.py b/libs/core/tests/unit_tests/tracers/test_memory_stream.py index 74541f42adb..0371d7ddcdb 100644 --- a/libs/core/tests/unit_tests/tracers/test_memory_stream.py +++ b/libs/core/tests/unit_tests/tracers/test_memory_stream.py @@ -2,7 +2,6 @@ import asyncio import math import time from collections.abc import AsyncIterator -from concurrent.futures import ThreadPoolExecutor from langchain_core.tracers.memory_stream import _MemoryStream @@ -70,7 +69,7 @@ async def test_queue_for_streaming_via_sync_call() -> None: """Produce items with slight delay.""" tic = time.time() for i in range(3): - await asyncio.sleep(0.10) + await asyncio.sleep(0.2) toc = time.time() await writer.send( { @@ -93,9 +92,11 @@ async def test_queue_for_streaming_via_sync_call() -> None: **item, } - with ThreadPoolExecutor() as executor: - executor.submit(sync_call) - items = [item async for item in consumer()] + task = asyncio.create_task(asyncio.to_thread(sync_call)) + items = [item async for item in consumer()] + await task + + assert len(items) == 3 for item in items: delta_time = item["receive_time"] - item["produce_time"] @@ -107,7 +108,7 @@ async def test_queue_for_streaming_via_sync_call() -> None: # To verify that the producer and consumer are running in parallel, we # expect the delta_time to be smaller than the sleep delay in the producer # * # of items = 30 ms - assert math.isclose(delta_time, 0, abs_tol=0.010) is True, ( + assert math.isclose(delta_time, 0, abs_tol=0.020) is True, ( f"delta_time: {delta_time}" ) diff --git a/libs/core/tests/unit_tests/vectorstores/test_in_memory.py b/libs/core/tests/unit_tests/vectorstores/test_in_memory.py index 67ede1c1508..5ebaf8633d9 100644 --- a/libs/core/tests/unit_tests/vectorstores/test_in_memory.py +++ b/libs/core/tests/unit_tests/vectorstores/test_in_memory.py @@ -83,17 +83,17 @@ async def test_inmemory_mmr() -> None: assert output[1] == _any_id_document(page_content="fou") -async def test_inmemory_dump_load(tmp_path: Path) -> None: +def test_inmemory_dump_load(tmp_path: Path) -> None: """Test end to end construction and search.""" embedding = DeterministicFakeEmbedding(size=6) - store = await InMemoryVectorStore.afrom_texts(["foo", "bar", "baz"], embedding) - output = await store.asimilarity_search("foo", k=1) + store = InMemoryVectorStore.from_texts(["foo", "bar", "baz"], embedding) + output = store.similarity_search("foo", k=1) test_file = str(tmp_path / "test.json") store.dump(test_file) loaded_store = InMemoryVectorStore.load(test_file, embedding) - loaded_output = await loaded_store.asimilarity_search("foo", k=1) + loaded_output = loaded_store.similarity_search("foo", k=1) assert output == loaded_output