diff --git a/libs/core/poetry.lock b/libs/core/poetry.lock index a27fb1e7a9a..e0e5b993195 100644 --- a/libs/core/poetry.lock +++ b/libs/core/poetry.lock @@ -221,6 +221,20 @@ webencodings = "*" [package.extras] css = ["tinycss2 (>=1.1.0,<1.5)"] +[[package]] +name = "blockbuster" +version = "1.5.11" +description = "Utility to detect blocking calls in the async event loop" +optional = false +python-versions = ">=3.8" +files = [ + {file = "blockbuster-1.5.11-py3-none-any.whl", hash = "sha256:34e56b2ff24c73d7b2857dcc20d49abe88f92a9cddd5f56432adeb3ef4aca9b8"}, + {file = "blockbuster-1.5.11.tar.gz", hash = "sha256:c5ed3da13216c80e26b755fb576c3d638d8125f893ba4709e56e28ecf3ee7254"}, +] + +[package.dependencies] +forbiddenfruit = ">=0.1.4" + [[package]] name = "certifi" version = "2024.12.14" @@ -553,6 +567,16 @@ files = [ [package.extras] devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"] +[[package]] +name = "forbiddenfruit" +version = "0.1.4" +description = "Patch python built-in objects" +optional = false +python-versions = "*" +files = [ + {file = "forbiddenfruit-0.1.4.tar.gz", hash = "sha256:e3f7e66561a29ae129aac139a85d610dbf3dd896128187ed5454b6421f624253"}, +] + [[package]] name = "fqdn" version = "1.5.1" @@ -1188,7 +1212,7 @@ files = [ [[package]] name = "langchain-tests" -version = "0.3.9" +version = "0.3.10" description = "Standard tests for LangChain implementations" optional = false python-versions = ">=3.9,<4.0" @@ -1197,7 +1221,7 @@ develop = true [package.dependencies] httpx = ">=0.25.0,<1" -langchain-core = "^0.3.31" +langchain-core = "^0.3.33" numpy = [ {version = ">=1.24.0,<2.0.0", markers = "python_version < \"3.12\""}, {version = ">=1.26.2,<3", markers = "python_version >= \"3.12\""}, @@ -3247,4 +3271,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "2a262498da93ae3991e5eade787affab38f1ddef5212cb745eafdd0a40f0e986" +content-hash = "7a90068bfdba1760bd1e80d65c8f83b7d940f3c49e463cc793d63f1be518eb72" diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 8cc2b15a7bd..6328c7fa87f 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -118,6 +118,7 @@ grandalf = "^0.8" responses = "^0.25.0" pytest-socket = "^0.7.0" pytest-xdist = "^3.6.1" +blockbuster = "~1.5.11" [[tool.poetry.group.test.dependencies.numpy]] version = "^1.24.0" python = "<3.12" diff --git a/libs/core/tests/unit_tests/conftest.py b/libs/core/tests/unit_tests/conftest.py index 29819a80669..6438c303740 100644 --- a/libs/core/tests/unit_tests/conftest.py +++ b/libs/core/tests/unit_tests/conftest.py @@ -1,14 +1,41 @@ """Configuration for unit tests.""" -from collections.abc import Sequence +from collections.abc import Iterator, Sequence from importlib import util from uuid import UUID import pytest +from blockbuster import BlockBuster, blockbuster_ctx from pytest import Config, Function, Parser from pytest_mock import MockerFixture +@pytest.fixture(autouse=True) +def blockbuster() -> Iterator[BlockBuster]: + with blockbuster_ctx("langchain_core") as bb: + for func in ["os.stat", "os.path.abspath"]: + ( + bb.functions[func] + .can_block_in("langchain_core/_api/internal.py", "is_caller_internal") + .can_block_in("langchain_core/runnables/base.py", "__repr__") + .can_block_in( + "langchain_core/beta/runnables/context.py", "aconfig_with_context" + ) + ) + + for func in ["os.stat", "io.TextIOWrapper.read"]: + bb.functions[func].can_block_in( + "langsmith/client.py", "_default_retry_config" + ) + + for bb_function in bb.functions.values(): + bb_function.can_block_in( + "freezegun/api.py", "_get_cached_module_attributes" + ) + + yield bb + + def pytest_addoption(parser: Parser) -> None: """Add custom command line options to pytest.""" parser.addoption( diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py index 7502e17c50f..5d7d4525e3f 100644 --- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py +++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py @@ -191,7 +191,12 @@ async def test_callback_handlers() -> None: model = GenericFakeChatModel(messages=infinite_cycle) tokens: list[str] = [] # New model - results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]})) + results = [ + chunk + async for chunk in model.astream( + "meow", {"callbacks": [MyCustomAsyncHandler(tokens)]} + ) + ] assert results == [ _any_id_ai_message_chunk(content="hello"), _any_id_ai_message_chunk(content=" "), diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index bee3a0783af..420b60cebf3 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -18,6 +18,7 @@ from langchain_core.messages import ( ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs.llm_result import LLMResult +from langchain_core.tracers import LogStreamCallbackHandler from langchain_core.tracers.base import BaseTracer from langchain_core.tracers.context import collect_runs from langchain_core.tracers.event_stream import _AstreamEventsCallbackHandler @@ -303,39 +304,48 @@ class StreamingModel(NoStreamingModel): @pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"]) -async def test_disable_streaming( +def test_disable_streaming( disable_streaming: Union[bool, Literal["tool_calling"]], ) -> None: model = StreamingModel(disable_streaming=disable_streaming) assert model.invoke([]).content == "invoke" - assert (await model.ainvoke([])).content == "invoke" expected = "invoke" if disable_streaming is True else "stream" assert next(model.stream([])).content == expected - async for c in model.astream([]): - assert c.content == expected - break + assert ( + model.invoke([], config={"callbacks": [LogStreamCallbackHandler()]}).content + == expected + ) + + expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream" + assert next(model.stream([], tools=[{"type": "function"}])).content == expected assert ( model.invoke( - [], config={"callbacks": [_AstreamEventsCallbackHandler()]} + [], config={"callbacks": [LogStreamCallbackHandler()]}, tools=[{}] ).content == expected ) + + +@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"]) +async def test_disable_streaming_async( + disable_streaming: Union[bool, Literal["tool_calling"]], +) -> None: + model = StreamingModel(disable_streaming=disable_streaming) + assert (await model.ainvoke([])).content == "invoke" + + expected = "invoke" if disable_streaming is True else "stream" + async for c in model.astream([]): + assert c.content == expected + break assert ( await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]}) ).content == expected expected = "invoke" if disable_streaming in ("tool_calling", True) else "stream" - assert next(model.stream([], tools=[{"type": "function"}])).content == expected async for c in model.astream([], tools=[{}]): assert c.content == expected break - assert ( - model.invoke( - [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}] - ).content - == expected - ) assert ( await model.ainvoke( [], config={"callbacks": [_AstreamEventsCallbackHandler()]}, tools=[{}] @@ -344,26 +354,31 @@ async def test_disable_streaming( @pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"]) -async def test_disable_streaming_no_streaming_model( +def test_disable_streaming_no_streaming_model( disable_streaming: Union[bool, Literal["tool_calling"]], ) -> None: model = NoStreamingModel(disable_streaming=disable_streaming) assert model.invoke([]).content == "invoke" - assert (await model.ainvoke([])).content == "invoke" assert next(model.stream([])).content == "invoke" + assert ( + model.invoke([], config={"callbacks": [LogStreamCallbackHandler()]}).content + == "invoke" + ) + assert next(model.stream([], tools=[{}])).content == "invoke" + + +@pytest.mark.parametrize("disable_streaming", [True, False, "tool_calling"]) +async def test_disable_streaming_no_streaming_model_async( + disable_streaming: Union[bool, Literal["tool_calling"]], +) -> None: + model = NoStreamingModel(disable_streaming=disable_streaming) + assert (await model.ainvoke([])).content == "invoke" async for c in model.astream([]): assert c.content == "invoke" break - assert ( - model.invoke( - [], config={"callbacks": [_AstreamEventsCallbackHandler()]} - ).content - == "invoke" - ) assert ( await model.ainvoke([], config={"callbacks": [_AstreamEventsCallbackHandler()]}) ).content == "invoke" - assert next(model.stream([], tools=[{}])).content == "invoke" async for c in model.astream([], tools=[{}]): assert c.content == "invoke" break diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py b/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py index bd2d960e10e..ee6eefba927 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_rate_limiting.py @@ -1,11 +1,20 @@ import time from typing import Optional as Optional +import pytest +from blockbuster import BlockBuster + from langchain_core.caches import InMemoryCache from langchain_core.language_models import GenericFakeChatModel from langchain_core.rate_limiters import InMemoryRateLimiter +@pytest.fixture(autouse=True) +def deactivate_blockbuster(blockbuster: BlockBuster) -> None: + # Deactivate BlockBuster to not disturb the rate limiter timings + blockbuster.deactivate() + + def test_rate_limit_invoke() -> None: """Add rate limiter.""" model = GenericFakeChatModel( diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index cad31d03ef9..abc0ac4f326 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -1,5 +1,3 @@ -import base64 -import tempfile import warnings from pathlib import Path from typing import Any, Union, cast @@ -727,44 +725,39 @@ async def test_chat_tmpl_from_messages_multipart_image() -> None: async def test_chat_tmpl_from_messages_multipart_formatting_with_path() -> None: """Verify that we cannot pass `path` for an image as a variable.""" in_mem = "base64mem" - in_file_data = "base64file01" - with tempfile.NamedTemporaryFile(delete=True, suffix=".jpg") as temp_file: - temp_file.write(base64.b64decode(in_file_data)) - temp_file.flush() - - template = ChatPromptTemplate.from_messages( - [ - ("system", "You are an AI assistant named {name}."), - ( - "human", - [ - {"type": "text", "text": "What's in this image?"}, - { - "type": "image_url", - "image_url": "data:image/jpeg;base64,{in_mem}", - }, - { - "type": "image_url", - "image_url": {"path": "{file_path}"}, - }, - ], - ), - ] + template = ChatPromptTemplate.from_messages( + [ + ("system", "You are an AI assistant named {name}."), + ( + "human", + [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": "data:image/jpeg;base64,{in_mem}", + }, + { + "type": "image_url", + "image_url": {"path": "{file_path}"}, + }, + ], + ), + ] + ) + with pytest.raises(ValueError): + template.format_messages( + name="R2D2", + in_mem=in_mem, + file_path="some/path", ) - with pytest.raises(ValueError): - template.format_messages( - name="R2D2", - in_mem=in_mem, - file_path=temp_file.name, - ) - with pytest.raises(ValueError): - await template.aformat_messages( - name="R2D2", - in_mem=in_mem, - file_path=temp_file.name, - ) + with pytest.raises(ValueError): + await template.aformat_messages( + name="R2D2", + in_mem=in_mem, + file_path="some/path", + ) def test_messages_placeholder() -> None: diff --git a/libs/core/tests/unit_tests/runnables/test_context.py b/libs/core/tests/unit_tests/runnables/test_context.py index c00eb999424..cb8de2dd808 100644 --- a/libs/core/tests/unit_tests/runnables/test_context.py +++ b/libs/core/tests/unit_tests/runnables/test_context.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, Callable, NamedTuple, Union import pytest @@ -330,19 +331,26 @@ test_cases = [ @pytest.mark.parametrize("runnable, cases", test_cases) -async def test_context_runnables( +def test_context_runnables( runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase] ) -> None: runnable = runnable if isinstance(runnable, Runnable) else runnable() assert runnable.invoke(cases[0].input) == cases[0].output - assert await runnable.ainvoke(cases[1].input) == cases[1].output assert runnable.batch([case.input for case in cases]) == [ case.output for case in cases ] + assert add(runnable.stream(cases[0].input)) == cases[0].output + + +@pytest.mark.parametrize("runnable, cases", test_cases) +async def test_context_runnables_async( + runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase] +) -> None: + runnable = runnable if isinstance(runnable, Runnable) else runnable() + assert await runnable.ainvoke(cases[1].input) == cases[1].output assert await runnable.abatch([case.input for case in cases]) == [ case.output for case in cases ] - assert add(runnable.stream(cases[0].input)) == cases[0].output assert await aadd(runnable.astream(cases[1].input)) == cases[1].output @@ -390,8 +398,7 @@ async def test_runnable_seq_streaming_chunks() -> None: "prompt": Context.getter("prompt"), } ) - - chunks = list(chain.stream({"foo": "foo", "bar": "bar"})) + chunks = await asyncio.to_thread(list, chain.stream({"foo": "foo", "bar": "bar"})) achunks = [c async for c in chain.astream({"foo": "foo", "bar": "bar"})] for c in chunks: assert c in achunks diff --git a/libs/core/tests/unit_tests/runnables/test_fallbacks.py b/libs/core/tests/unit_tests/runnables/test_fallbacks.py index 1826f883e7b..330adfa2089 100644 --- a/libs/core/tests/unit_tests/runnables/test_fallbacks.py +++ b/libs/core/tests/unit_tests/runnables/test_fallbacks.py @@ -82,17 +82,25 @@ def chain_pass_exceptions() -> Runnable: "runnable", ["llm", "llm_multi", "chain", "chain_pass_exceptions"], ) -async def test_fallbacks( +def test_fallbacks( runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion ) -> None: runnable = request.getfixturevalue(runnable) assert runnable.invoke("hello") == "bar" assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3 assert list(runnable.stream("hello")) == ["bar"] + assert dumps(runnable, pretty=True) == snapshot + + +@pytest.mark.parametrize( + "runnable", + ["llm", "llm_multi", "chain", "chain_pass_exceptions"], +) +async def test_fallbacks_async(runnable: RunnableWithFallbacks, request: Any) -> None: + runnable = request.getfixturevalue(runnable) assert await runnable.ainvoke("hello") == "bar" assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3 assert list(await runnable.ainvoke("hello")) == list("bar") - assert dumps(runnable, pretty=True) == snapshot def _runnable(inputs: dict) -> str: diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 1bd21b9d560..01cd8a026c5 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -1,3 +1,4 @@ +import asyncio import sys import uuid import warnings @@ -1008,6 +1009,73 @@ def test_configurable_fields_example(snapshot: SnapshotAssertion) -> None: ) +def test_passthrough_tap(mocker: MockerFixture) -> None: + fake = FakeRunnable() + mock = mocker.Mock() + + seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock) + + assert seq.invoke("hello", my_kwarg="value") == 5 # type: ignore[call-arg] + assert mock.call_args_list == [ + mocker.call("hello", my_kwarg="value"), + mocker.call(5), + ] + mock.reset_mock() + + assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6] + assert len(mock.call_args_list) == 4 + for call in [ + mocker.call("hello", my_kwarg="value"), + mocker.call("byebye", my_kwarg="value"), + mocker.call(5), + mocker.call(6), + ]: + assert call in mock.call_args_list + mock.reset_mock() + + assert seq.batch(["hello", "byebye"], my_kwarg="value", return_exceptions=True) == [ + 5, + 6, + ] + assert len(mock.call_args_list) == 4 + for call in [ + mocker.call("hello", my_kwarg="value"), + mocker.call("byebye", my_kwarg="value"), + mocker.call(5), + mocker.call(6), + ]: + assert call in mock.call_args_list + mock.reset_mock() + + assert sorted( + a + for a in seq.batch_as_completed( + ["hello", "byebye"], my_kwarg="value", return_exceptions=True + ) + ) == [ + (0, 5), + (1, 6), + ] + assert len(mock.call_args_list) == 4 + for call in [ + mocker.call("hello", my_kwarg="value"), + mocker.call("byebye", my_kwarg="value"), + mocker.call(5), + mocker.call(6), + ]: + assert call in mock.call_args_list + mock.reset_mock() + + assert list( + seq.stream("hello", {"metadata": {"key": "value"}}, my_kwarg="value") + ) == [5] + assert mock.call_args_list == [ + mocker.call("hello", my_kwarg="value"), + mocker.call(5), + ] + mock.reset_mock() + + async def test_passthrough_tap_async(mocker: MockerFixture) -> None: fake = FakeRunnable() mock = mocker.Mock() @@ -1079,67 +1147,6 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None: mocker.call("hello", my_kwarg="value"), mocker.call(5), ] - mock.reset_mock() - - assert seq.invoke("hello", my_kwarg="value") == 5 # type: ignore[call-arg] - assert mock.call_args_list == [ - mocker.call("hello", my_kwarg="value"), - mocker.call(5), - ] - mock.reset_mock() - - assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6] - assert len(mock.call_args_list) == 4 - for call in [ - mocker.call("hello", my_kwarg="value"), - mocker.call("byebye", my_kwarg="value"), - mocker.call(5), - mocker.call(6), - ]: - assert call in mock.call_args_list - mock.reset_mock() - - assert seq.batch(["hello", "byebye"], my_kwarg="value", return_exceptions=True) == [ - 5, - 6, - ] - assert len(mock.call_args_list) == 4 - for call in [ - mocker.call("hello", my_kwarg="value"), - mocker.call("byebye", my_kwarg="value"), - mocker.call(5), - mocker.call(6), - ]: - assert call in mock.call_args_list - mock.reset_mock() - - assert sorted( - a - for a in seq.batch_as_completed( - ["hello", "byebye"], my_kwarg="value", return_exceptions=True - ) - ) == [ - (0, 5), - (1, 6), - ] - assert len(mock.call_args_list) == 4 - for call in [ - mocker.call("hello", my_kwarg="value"), - mocker.call("byebye", my_kwarg="value"), - mocker.call(5), - mocker.call(6), - ]: - assert call in mock.call_args_list - mock.reset_mock() - - assert list( - seq.stream("hello", {"metadata": {"key": "value"}}, my_kwarg="value") - ) == [5] - assert mock.call_args_list == [ - mocker.call("hello", my_kwarg="value"), - mocker.call(5), - ] - mock.reset_mock() async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None: @@ -1170,7 +1177,7 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None: spy.reset_mock() -async def test_with_config(mocker: MockerFixture) -> None: +def test_with_config(mocker: MockerFixture) -> None: fake = FakeRunnable() spy = mocker.spy(fake, "invoke") @@ -1276,7 +1283,11 @@ async def test_with_config(mocker: MockerFixture) -> None: for i, call in enumerate(spy.call_args_list): assert call.args[0] == ("hello" if i == 0 else "wooorld") assert call.args[1].get("tags") == ["a-tag"] - spy.reset_mock() + + +async def test_with_config_async(mocker: MockerFixture) -> None: + fake = FakeRunnable() + spy = mocker.spy(fake, "invoke") handler = ConsoleCallbackHandler() assert ( @@ -1372,7 +1383,7 @@ async def test_with_config(mocker: MockerFixture) -> None: ) -async def test_default_method_implementations(mocker: MockerFixture) -> None: +def test_default_method_implementations(mocker: MockerFixture) -> None: fake = FakeRunnable() spy = mocker.spy(fake, "invoke") @@ -1413,7 +1424,11 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: for call in spy.call_args_list: assert call.args[1].get("tags") == ["a-tag"] assert call.args[1].get("metadata") == {} - spy.reset_mock() + + +async def test_default_method_implementations_async(mocker: MockerFixture) -> None: + fake = FakeRunnable() + spy = mocker.spy(fake, "invoke") assert await fake.ainvoke("hello", config={"callbacks": []}) == 5 assert spy.call_args_list == [ @@ -1442,7 +1457,7 @@ async def test_default_method_implementations(mocker: MockerFixture) -> None: } -async def test_prompt() -> None: +def test_prompt() -> None: prompt = ChatPromptTemplate.from_messages( messages=[ SystemMessage(content="You are a nice assistant."), @@ -1475,6 +1490,21 @@ async def test_prompt() -> None: assert [*prompt.stream({"question": "What is your name?"})] == [expected] + +async def test_prompt_async() -> None: + prompt = ChatPromptTemplate.from_messages( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessagePromptTemplate.from_template("{question}"), + ] + ) + expected = ChatPromptValue( + messages=[ + SystemMessage(content="You are a nice assistant."), + HumanMessage(content="What is your name?"), + ] + ) + assert await prompt.ainvoke({"question": "What is your name?"}) == expected assert await prompt.abatch( @@ -2770,9 +2800,7 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) -> @freeze_time("2023-01-01") -async def test_router_runnable( - mocker: MockerFixture, snapshot: SnapshotAssertion -) -> None: +def test_router_runnable(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None: chain1: Runnable = ChatPromptTemplate.from_template( "You are a math genius. Answer the question: {question}" ) | FakeListLLM(responses=["4"]) @@ -2797,17 +2825,6 @@ async def test_router_runnable( ) assert result2 == ["4", "2"] - result = await chain.ainvoke({"key": "math", "question": "2 + 2"}) - assert result == "4" - - result2 = await chain.abatch( - [ - {"key": "math", "question": "2 + 2"}, - {"key": "english", "question": "2 + 2"}, - ] - ) - assert result2 == ["4", "2"] - # Test invoke router_spy = mocker.spy(router.__class__, "invoke") tracer = FakeTracer() @@ -2827,8 +2844,33 @@ async def test_router_runnable( assert len(router_run.child_runs) == 2 +async def test_router_runnable_async() -> None: + chain1: Runnable = ChatPromptTemplate.from_template( + "You are a math genius. Answer the question: {question}" + ) | FakeListLLM(responses=["4"]) + chain2: Runnable = ChatPromptTemplate.from_template( + "You are an english major. Answer the question: {question}" + ) | FakeListLLM(responses=["2"]) + router: Runnable = RouterRunnable({"math": chain1, "english": chain2}) + chain: Runnable = { + "key": lambda x: x["key"], + "input": {"question": lambda x: x["question"]}, + } | router + + result = await chain.ainvoke({"key": "math", "question": "2 + 2"}) + assert result == "4" + + result2 = await chain.abatch( + [ + {"key": "math", "question": "2 + 2"}, + {"key": "english", "question": "2 + 2"}, + ] + ) + assert result2 == ["4", "2"] + + @freeze_time("2023-01-01") -async def test_higher_order_lambda_runnable( +def test_higher_order_lambda_runnable( mocker: MockerFixture, snapshot: SnapshotAssertion ) -> None: math_chain: Runnable = ChatPromptTemplate.from_template( @@ -2865,17 +2907,6 @@ async def test_higher_order_lambda_runnable( ) assert result2 == ["4", "2"] - result = await chain.ainvoke({"key": "math", "question": "2 + 2"}) - assert result == "4" - - result2 = await chain.abatch( - [ - {"key": "math", "question": "2 + 2"}, - {"key": "english", "question": "2 + 2"}, - ] - ) - assert result2 == ["4", "2"] - # Test invoke math_spy = mocker.spy(math_chain.__class__, "invoke") tracer = FakeTracer() @@ -2897,6 +2928,41 @@ async def test_higher_order_lambda_runnable( assert math_run.name == "RunnableSequence" assert len(math_run.child_runs) == 3 + +async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None: + math_chain: Runnable = ChatPromptTemplate.from_template( + "You are a math genius. Answer the question: {question}" + ) | FakeListLLM(responses=["4"]) + english_chain: Runnable = ChatPromptTemplate.from_template( + "You are an english major. Answer the question: {question}" + ) | FakeListLLM(responses=["2"]) + input_map: Runnable = RunnableParallel( + key=lambda x: x["key"], + input={"question": lambda x: x["question"]}, + ) + + def router(input: dict[str, Any]) -> Runnable: + if input["key"] == "math": + return itemgetter("input") | math_chain + elif input["key"] == "english": + return itemgetter("input") | english_chain + else: + msg = f"Unknown key: {input['key']}" + raise ValueError(msg) + + chain: Runnable = input_map | router + + result = await chain.ainvoke({"key": "math", "question": "2 + 2"}) + assert result == "4" + + result2 = await chain.abatch( + [ + {"key": "math", "question": "2 + 2"}, + {"key": "english", "question": "2 + 2"}, + ] + ) + assert result2 == ["4", "2"] + # Test ainvoke async def arouter(input: dict[str, Any]) -> Runnable: if input["key"] == "math": @@ -4651,7 +4717,7 @@ async def test_tool_from_runnable() -> None: } -async def test_runnable_gen() -> None: +def test_runnable_gen() -> None: """Test that a generator can be used as a runnable.""" def gen(input: Iterator[Any]) -> Iterator[int]: @@ -4671,6 +4737,10 @@ async def test_runnable_gen() -> None: assert list(runnable.stream(None)) == [1, 2, 3] assert runnable.batch([None, None]) == [6, 6] + +async def test_runnable_gen_async() -> None: + """Test that a generator can be used as a runnable.""" + async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]: yield 1 yield 2 @@ -4693,14 +4763,14 @@ async def test_runnable_gen() -> None: assert [p async for p in arunnablecallable.astream(None)] == [1, 2, 3] assert await arunnablecallable.abatch([None, None]) == [6, 6] with pytest.raises(NotImplementedError): - arunnablecallable.invoke(None) + await asyncio.to_thread(arunnablecallable.invoke, None) with pytest.raises(NotImplementedError): - arunnablecallable.stream(None) + await asyncio.to_thread(arunnablecallable.stream, None) with pytest.raises(NotImplementedError): - arunnablecallable.batch([None, None]) + await asyncio.to_thread(arunnablecallable.batch, [None, None]) -async def test_runnable_gen_context_config() -> None: +def test_runnable_gen_context_config() -> None: """Test that a generator can call other runnables with config propagated from the context. """ @@ -4769,9 +4839,16 @@ async def test_runnable_gen_context_config() -> None: assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"] assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] - if sys.version_info < (3, 11): - # Python 3.10 and below don't support running async tasks in a specific context - return + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Python 3.10 and below don't support running async tasks in a specific context", +) +async def test_runnable_gen_context_config_async() -> None: + """Test that a generator can call other runnables with config + propagated from the context.""" + + fake = RunnableLambda(len) async def agen(input: AsyncIterator[Any]) -> AsyncIterator[int]: yield await fake.ainvoke("a") @@ -4835,7 +4912,7 @@ async def test_runnable_gen_context_config() -> None: assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] -async def test_runnable_iter_context_config() -> None: +def test_runnable_iter_context_config() -> None: """Test that a generator can call other runnables with config propagated from the context. """ @@ -4888,9 +4965,16 @@ async def test_runnable_iter_context_config() -> None: assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"] assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] - if sys.version_info < (3, 11): - # Python 3.10 and below don't support running async tasks in a specific context - return + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Python 3.10 and below don't support running async tasks in a specific context", +) +async def test_runnable_iter_context_config_async() -> None: + """Test that a generator can call other runnables with config + propagated from the context.""" + + fake = RunnableLambda(len) @chain async def agen(input: str) -> AsyncIterator[int]: @@ -4952,7 +5036,7 @@ async def test_runnable_iter_context_config() -> None: assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] -async def test_runnable_lambda_context_config() -> None: +def test_runnable_lambda_context_config() -> None: """Test that a function can call other runnables with config propagated from the context. """ @@ -5003,9 +5087,16 @@ async def test_runnable_lambda_context_config() -> None: assert [r.inputs["input"] for r in tracer.runs[1].child_runs] == ["a", "aa", "aaa"] assert [(r.outputs or {})["output"] for r in tracer.runs[1].child_runs] == [1, 2, 3] - if sys.version_info < (3, 11): - # Python 3.10 and below don't support running async tasks in a specific context - return + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="Python 3.10 and below don't support running async tasks in a specific context", +) +async def test_runnable_lambda_context_config_async() -> None: + """Test that a function can call other runnables with config + propagated from the context.""" + + fake = RunnableLambda(len) @chain async def afun(input: str) -> int: diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py index 59c5e765e23..9aab4eef34f 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py @@ -1,5 +1,6 @@ """Module that contains tests for runnable.astream_events API.""" +import asyncio import sys from collections.abc import AsyncIterator, Sequence from itertools import cycle @@ -1957,9 +1958,12 @@ async def test_runnable_with_message_history() -> None: ] } - with_message_history.with_config( - {"configurable": {"session_id": "session-123"}} - ).invoke({"question": "meow"}) + await asyncio.to_thread( + with_message_history.with_config( + {"configurable": {"session_id": "session-123"}} + ).invoke, + {"question": "meow"}, + ) assert store == { "session-123": [ HumanMessage(content="hello"), diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index 698a4c4ddab..c1efd9be503 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -13,6 +13,7 @@ from typing import ( ) import pytest +from blockbuster import BlockBuster from pydantic import BaseModel from langchain_core.callbacks import CallbackManagerForRetrieverRun, Callbacks @@ -38,7 +39,9 @@ from langchain_core.runnables import ( chain, ensure_config, ) -from langchain_core.runnables.config import get_callback_manager_for_config +from langchain_core.runnables.config import ( + get_async_callback_manager_for_config, +) from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.schema import StreamEvent from langchain_core.runnables.utils import Input, Output @@ -1923,9 +1926,12 @@ async def test_runnable_with_message_history() -> None: ] } - with_message_history.with_config( - {"configurable": {"session_id": "session-123"}} - ).invoke({"question": "meow"}) + await asyncio.to_thread( + with_message_history.with_config( + {"configurable": {"session_id": "session-123"}} + ).invoke, + {"question": "meow"}, + ) assert store == { "session-123": [ HumanMessage(content="hello"), @@ -1995,8 +2001,9 @@ EXPECTED_EVENTS = [ ] -async def test_sync_in_async_stream_lambdas() -> None: +async def test_sync_in_async_stream_lambdas(blockbuster: BlockBuster) -> None: """Test invoking nested runnable lambda.""" + blockbuster.deactivate() def add_one(x: int) -> int: return x + 1 @@ -2085,8 +2092,8 @@ class StreamingRunnable(Runnable[Input, Output]): **kwargs: Optional[Any], ) -> AsyncIterator[Output]: config = ensure_config(config) - callback_manager = get_callback_manager_for_config(config) - run_manager = callback_manager.on_chain_start( + callback_manager = get_async_callback_manager_for_config(config) + run_manager = await callback_manager.on_chain_start( None, input, name=config.get("run_name", self.get_name()), @@ -2109,9 +2116,9 @@ class StreamingRunnable(Runnable[Input, Output]): final_output = element # set final channel values as run output - run_manager.on_chain_end(final_output) + await run_manager.on_chain_end(final_output) except BaseException as e: - run_manager.on_chain_error(e) + await run_manager.on_chain_error(e) raise diff --git a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py index 3409d04f234..77922a5616c 100644 --- a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py +++ b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import json import sys import uuid -from collections.abc import AsyncGenerator, Generator -from typing import Any +from collections.abc import AsyncGenerator, Coroutine, Generator +from inspect import isasyncgenfunction +from typing import Any, Callable, Optional from unittest.mock import MagicMock, patch import pytest @@ -12,6 +15,7 @@ from langsmith.run_trees import RunTree from langsmith.utils import get_env_var from typing_extensions import Literal +from langchain_core.callbacks import BaseCallbackHandler from langchain_core.runnables.base import RunnableLambda, RunnableParallel from langchain_core.tracers.langchain import LangChainTracer @@ -35,6 +39,17 @@ def _get_posts(client: Client) -> list: return posts +def _create_tracer_with_mocked_client( + project_name: Optional[str] = None, + tags: Optional[list[str]] = None, +) -> LangChainTracer: + mock_session = MagicMock() + mock_client_ = Client( + session=mock_session, api_key="test", auto_batch_tracing=False + ) + return LangChainTracer(client=mock_client_, project_name=project_name, tags=tags) + + def test_tracing_context() -> None: mock_session = MagicMock() mock_client_ = Client( @@ -56,12 +71,8 @@ def test_tracing_context() -> None: def test_config_traceable_handoff() -> None: get_env_var.cache_clear() - mock_session = MagicMock() - mock_client_ = Client( - session=mock_session, api_key="test", auto_batch_tracing=False - ) - tracer = LangChainTracer( - client=mock_client_, project_name="another-flippin-project", tags=["such-a-tag"] + tracer = _create_tracer_with_mocked_client( + project_name="another-flippin-project", tags=["such-a-tag"] ) @traceable @@ -100,7 +111,7 @@ def test_config_traceable_handoff() -> None: my_parent_runnable = RunnableLambda(my_parent_function) assert my_parent_runnable.invoke(1, {"callbacks": [tracer]}) == 6 - posts = _get_posts(mock_client_) + posts = _get_posts(tracer.client) assert all(post["session_name"] == "another-flippin-project" for post in posts) # There should have been 6 runs created, # one for each function invocation @@ -143,11 +154,7 @@ def test_config_traceable_handoff() -> None: sys.version_info < (3, 11), reason="Asyncio context vars require Python 3.11+" ) async def test_config_traceable_async_handoff() -> None: - mock_session = MagicMock() - mock_client_ = Client( - session=mock_session, api_key="test", auto_batch_tracing=False - ) - tracer = LangChainTracer(client=mock_client_) + tracer = _create_tracer_with_mocked_client() @traceable def my_great_great_grandchild_function(a: int) -> int: @@ -175,7 +182,7 @@ async def test_config_traceable_async_handoff() -> None: my_parent_runnable = RunnableLambda(my_parent_function) # type: ignore result = await my_parent_runnable.ainvoke(1, {"callbacks": [tracer]}) assert result == 6 - posts = _get_posts(mock_client_) + posts = _get_posts(tracer.client) # There should have been 6 runs created, # one for each function invocation assert len(posts) == 6 @@ -245,144 +252,172 @@ def test_tracing_enable_disable( assert not mock_posts -@pytest.mark.parametrize( - "method", ["invoke", "stream", "batch", "ainvoke", "astream", "abatch"] -) -async def test_runnable_sequence_parallel_trace_nesting(method: str) -> None: - if method.startswith("a") and sys.version_info < (3, 11): - pytest.skip("Asyncio context vars require Python 3.11+") - mock_session = MagicMock() - mock_client_ = Client( - session=mock_session, api_key="test", auto_batch_tracing=False +class TestRunnableSequenceParallelTraceNesting: + @pytest.fixture(autouse=True) + def _setup(self) -> None: + self.tracer = _create_tracer_with_mocked_client() + + @staticmethod + def _create_parent( + other_thing: Callable[ + [int], Generator[int, None, None] | AsyncGenerator[int, None] + ], + ) -> RunnableLambda: + @RunnableLambda + def my_child_function(a: int) -> int: + return a + 2 + + parallel = RunnableParallel( + chain_result=my_child_function.with_config(tags=["atag"]), + other_thing=other_thing, + ) + + def before(x: int) -> int: + return x + + def after(x: dict) -> int: + return x["chain_result"] + + sequence = before | parallel | after + if isasyncgenfunction(other_thing): + + @RunnableLambda # type: ignore + async def parent(a: int) -> int: + return await sequence.ainvoke(a) + + else: + + @RunnableLambda + def parent(a: int) -> int: + return sequence.invoke(a) + + return parent + + def _check_posts(self) -> None: + posts = _get_posts(self.tracer.client) + name_order = [ + "parent", + "RunnableSequence", + "before", + "RunnableParallel", + ["my_child_function", "other_thing"], + "after", + ] + expected_parents = { + "parent": None, + "RunnableSequence": "parent", + "before": "RunnableSequence", + "RunnableParallel": "RunnableSequence", + "my_child_function": "RunnableParallel", + "other_thing": "RunnableParallel", + "after": "RunnableSequence", + } + assert len(posts) == sum( + 1 if isinstance(n, str) else len(n) for n in name_order + ) + prev_dotted_order = None + dotted_order_map = {} + id_map = {} + parent_id_map = {} + i = 0 + for name in name_order: + if isinstance(name, list): + for n in name: + matching_post = next( + p for p in posts[i : i + len(name)] if p["name"] == n + ) + assert matching_post + dotted_order = matching_post["dotted_order"] + if prev_dotted_order is not None: + assert dotted_order > prev_dotted_order + dotted_order_map[n] = dotted_order + id_map[n] = matching_post["id"] + parent_id_map[n] = matching_post.get("parent_run_id") + i += len(name) + continue + else: + assert posts[i]["name"] == name + dotted_order = posts[i]["dotted_order"] + if prev_dotted_order is not None and not str( + expected_parents[name] + ).startswith("RunnableParallel"): + assert dotted_order > prev_dotted_order, ( + f"{name} not after {name_order[i - 1]}" + ) + prev_dotted_order = dotted_order + if name in dotted_order_map: + msg = f"Duplicate name {name}" + raise ValueError(msg) + dotted_order_map[name] = dotted_order + id_map[name] = posts[i]["id"] + parent_id_map[name] = posts[i].get("parent_run_id") + i += 1 + + # Now check the dotted orders + for name, parent_ in expected_parents.items(): + dotted_order = dotted_order_map[name] + if parent_ is not None: + parent_dotted_order = dotted_order_map[parent_] + assert dotted_order.startswith(parent_dotted_order), ( + f"{name}, {parent_dotted_order} not in {dotted_order}" + ) + assert str(parent_id_map[name]) == str(id_map[parent_]) + else: + assert dotted_order.split(".")[0] == dotted_order + + @pytest.mark.parametrize( + "method", + [ + lambda parent, cb: parent.invoke(1, {"callbacks": cb}), + lambda parent, cb: list(parent.stream(1, {"callbacks": cb}))[-1], + lambda parent, cb: parent.batch([1], {"callbacks": cb})[0], + ], + ids=["invoke", "stream", "batch"], ) - tracer = LangChainTracer(client=mock_client_) - - @RunnableLambda - def my_child_function(a: int) -> int: - return a + 2 - - if method.startswith("a"): - - async def other_thing(a: int) -> AsyncGenerator[int, None]: - yield 1 - - else: - + def test_sync( + self, method: Callable[[RunnableLambda, list[BaseCallbackHandler]], int] + ) -> None: def other_thing(a: int) -> Generator[int, None, None]: # type: ignore yield 1 - parallel = RunnableParallel( - chain_result=my_child_function.with_config(tags=["atag"]), - other_thing=other_thing, + parent = self._create_parent(other_thing) + + # Now run the chain and check the resulting posts + assert method(parent, [self.tracer]) == 3 + + self._check_posts() + + @staticmethod + async def ainvoke(parent: RunnableLambda, cb: list[BaseCallbackHandler]) -> int: + return await parent.ainvoke(1, {"callbacks": cb}) + + @staticmethod + async def astream(parent: RunnableLambda, cb: list[BaseCallbackHandler]) -> int: + return [res async for res in parent.astream(1, {"callbacks": cb})][-1] + + @staticmethod + async def abatch(parent: RunnableLambda, cb: list[BaseCallbackHandler]) -> int: + return (await parent.abatch([1], {"callbacks": cb}))[0] + + @pytest.mark.skipif( + sys.version_info < (3, 11), reason="Asyncio context vars require Python 3.11+" ) + @pytest.mark.parametrize("method", [ainvoke, astream, abatch]) + async def test_async( + self, + method: Callable[ + [RunnableLambda, list[BaseCallbackHandler]], Coroutine[Any, Any, int] + ], + ) -> None: + async def other_thing(a: int) -> AsyncGenerator[int, None]: + yield 1 - def before(x: int) -> int: - return x + parent = self._create_parent(other_thing) - def after(x: dict) -> int: - return x["chain_result"] + # Now run the chain and check the resulting posts + assert await method(parent, [self.tracer]) == 3 - sequence = before | parallel | after - if method.startswith("a"): - - @RunnableLambda # type: ignore - async def parent(a: int) -> int: - return await sequence.ainvoke(a) - - else: - - @RunnableLambda - def parent(a: int) -> int: - return sequence.invoke(a) - - # Now run the chain and check the resulting posts - cb = [tracer] - if method == "invoke": - res: Any = parent.invoke(1, {"callbacks": cb}) # type: ignore - elif method == "ainvoke": - res = await parent.ainvoke(1, {"callbacks": cb}) # type: ignore - elif method == "stream": - results = list(parent.stream(1, {"callbacks": cb})) # type: ignore - res = results[-1] - elif method == "astream": - results = [res async for res in parent.astream(1, {"callbacks": cb})] # type: ignore - res = results[-1] - elif method == "batch": - res = parent.batch([1], {"callbacks": cb})[0] # type: ignore - elif method == "abatch": - res = (await parent.abatch([1], {"callbacks": cb}))[0] # type: ignore - else: - msg = f"Unknown method {method}" - raise ValueError(msg) - assert res == 3 - posts = _get_posts(mock_client_) - name_order = [ - "parent", - "RunnableSequence", - "before", - "RunnableParallel", - ["my_child_function", "other_thing"], - "after", - ] - expected_parents = { - "parent": None, - "RunnableSequence": "parent", - "before": "RunnableSequence", - "RunnableParallel": "RunnableSequence", - "my_child_function": "RunnableParallel", - "other_thing": "RunnableParallel", - "after": "RunnableSequence", - } - assert len(posts) == sum(1 if isinstance(n, str) else len(n) for n in name_order) - prev_dotted_order = None - dotted_order_map = {} - id_map = {} - parent_id_map = {} - i = 0 - for name in name_order: - if isinstance(name, list): - for n in name: - matching_post = next( - p for p in posts[i : i + len(name)] if p["name"] == n - ) - assert matching_post - dotted_order = matching_post["dotted_order"] - if prev_dotted_order is not None: - assert dotted_order > prev_dotted_order - dotted_order_map[n] = dotted_order - id_map[n] = matching_post["id"] - parent_id_map[n] = matching_post.get("parent_run_id") - i += len(name) - continue - else: - assert posts[i]["name"] == name - dotted_order = posts[i]["dotted_order"] - if prev_dotted_order is not None and not str( - expected_parents[name] - ).startswith("RunnableParallel"): - assert dotted_order > prev_dotted_order, ( - f"{name} not after {name_order[i - 1]}" - ) - prev_dotted_order = dotted_order - if name in dotted_order_map: - msg = f"Duplicate name {name}" - raise ValueError(msg) - dotted_order_map[name] = dotted_order - id_map[name] = posts[i]["id"] - parent_id_map[name] = posts[i].get("parent_run_id") - i += 1 - - # Now check the dotted orders - for name, parent_ in expected_parents.items(): - dotted_order = dotted_order_map[name] - if parent_ is not None: - parent_dotted_order = dotted_order_map[parent_] - assert dotted_order.startswith(parent_dotted_order), ( - f"{name}, {parent_dotted_order} not in {dotted_order}" - ) - assert str(parent_id_map[name]) == str(id_map[parent_]) - else: - assert dotted_order.split(".")[0] == dotted_order + self._check_posts() @pytest.mark.parametrize("parent_type", ("ls", "lc")) diff --git a/libs/core/tests/unit_tests/test_setup.py b/libs/core/tests/unit_tests/test_setup.py new file mode 100644 index 00000000000..1df3c73a252 --- /dev/null +++ b/libs/core/tests/unit_tests/test_setup.py @@ -0,0 +1,15 @@ +import time + +import pytest +from blockbuster import BlockingError + +from langchain_core import sys_info + + +async def test_blockbuster_setup() -> None: + """Check if blockbuster is correctly setup.""" + # Blocking call outside of langchain_core is allowed. + time.sleep(0.01) # noqa: ASYNC251 + with pytest.raises(BlockingError): + # Blocking call from langchain_core raises BlockingError. + sys_info.print_sys_info() diff --git a/libs/core/tests/unit_tests/tracers/test_memory_stream.py b/libs/core/tests/unit_tests/tracers/test_memory_stream.py index 74541f42adb..0371d7ddcdb 100644 --- a/libs/core/tests/unit_tests/tracers/test_memory_stream.py +++ b/libs/core/tests/unit_tests/tracers/test_memory_stream.py @@ -2,7 +2,6 @@ import asyncio import math import time from collections.abc import AsyncIterator -from concurrent.futures import ThreadPoolExecutor from langchain_core.tracers.memory_stream import _MemoryStream @@ -70,7 +69,7 @@ async def test_queue_for_streaming_via_sync_call() -> None: """Produce items with slight delay.""" tic = time.time() for i in range(3): - await asyncio.sleep(0.10) + await asyncio.sleep(0.2) toc = time.time() await writer.send( { @@ -93,9 +92,11 @@ async def test_queue_for_streaming_via_sync_call() -> None: **item, } - with ThreadPoolExecutor() as executor: - executor.submit(sync_call) - items = [item async for item in consumer()] + task = asyncio.create_task(asyncio.to_thread(sync_call)) + items = [item async for item in consumer()] + await task + + assert len(items) == 3 for item in items: delta_time = item["receive_time"] - item["produce_time"] @@ -107,7 +108,7 @@ async def test_queue_for_streaming_via_sync_call() -> None: # To verify that the producer and consumer are running in parallel, we # expect the delta_time to be smaller than the sleep delay in the producer # * # of items = 30 ms - assert math.isclose(delta_time, 0, abs_tol=0.010) is True, ( + assert math.isclose(delta_time, 0, abs_tol=0.020) is True, ( f"delta_time: {delta_time}" ) diff --git a/libs/core/tests/unit_tests/vectorstores/test_in_memory.py b/libs/core/tests/unit_tests/vectorstores/test_in_memory.py index 67ede1c1508..5ebaf8633d9 100644 --- a/libs/core/tests/unit_tests/vectorstores/test_in_memory.py +++ b/libs/core/tests/unit_tests/vectorstores/test_in_memory.py @@ -83,17 +83,17 @@ async def test_inmemory_mmr() -> None: assert output[1] == _any_id_document(page_content="fou") -async def test_inmemory_dump_load(tmp_path: Path) -> None: +def test_inmemory_dump_load(tmp_path: Path) -> None: """Test end to end construction and search.""" embedding = DeterministicFakeEmbedding(size=6) - store = await InMemoryVectorStore.afrom_texts(["foo", "bar", "baz"], embedding) - output = await store.asimilarity_search("foo", k=1) + store = InMemoryVectorStore.from_texts(["foo", "bar", "baz"], embedding) + output = store.similarity_search("foo", k=1) test_file = str(tmp_path / "test.json") store.dump(test_file) loaded_store = InMemoryVectorStore.load(test_file, embedding) - loaded_output = await loaded_store.asimilarity_search("foo", k=1) + loaded_output = loaded_store.similarity_search("foo", k=1) assert output == loaded_output