diff --git a/libs/langchain/langchain/llms/base.py b/libs/langchain/langchain/llms/base.py
index 8ccaf0deb64..7da494de78b 100644
--- a/libs/langchain/langchain/llms/base.py
+++ b/libs/langchain/langchain/llms/base.py
@@ -91,7 +91,11 @@ def create_base_retry_decorator(
             if isinstance(run_manager, AsyncCallbackManagerForLLMRun):
                 coro = run_manager.on_retry(retry_state)
                 try:
-                    asyncio.run(coro)
+                    loop = asyncio.get_event_loop()
+                    if loop.is_running():
+                        loop.create_task(coro)
+                    else:
+                        asyncio.run(coro)
                 except Exception as e:
                     _log_error_once(f"Error in on_retry: {e}")
             else:
diff --git a/libs/langchain/tests/integration_tests/llms/test_openai.py b/libs/langchain/tests/integration_tests/llms/test_openai.py
index 6b584ae154b..ca8911078a4 100644
--- a/libs/langchain/tests/integration_tests/llms/test_openai.py
+++ b/libs/langchain/tests/integration_tests/llms/test_openai.py
@@ -1,7 +1,6 @@
 """Test OpenAI API wrapper."""
 
 from pathlib import Path
-from typing import Any, Generator
-from unittest.mock import MagicMock, patch
+from typing import Generator
 
 import pytest
@@ -11,7 +10,6 @@ from langchain.llms.loading import load_llm
 from langchain.llms.openai import OpenAI, OpenAIChat
 from langchain.schema import LLMResult
 from tests.unit_tests.callbacks.fake_callback_handler import (
-    FakeAsyncCallbackHandler,
     FakeCallbackHandler,
 )
 
@@ -351,63 +349,3 @@ def mock_completion() -> dict:
         ],
         "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3},
     }
-
-
-@pytest.mark.requires("openai")
-def test_openai_retries(mock_completion: dict) -> None:
-    llm = OpenAI()
-    mock_client = MagicMock()
-    completed = False
-    raised = False
-    import openai
-
-    def raise_once(*args: Any, **kwargs: Any) -> Any:
-        nonlocal completed, raised
-        if not raised:
-            raised = True
-            raise openai.error.APIError
-        completed = True
-        return mock_completion
-
-    mock_client.create = raise_once
-    callback_handler = FakeCallbackHandler()
-    with patch.object(
-        llm,
-        "client",
-        mock_client,
-    ):
-        res = llm.predict("bar", callbacks=[callback_handler])
-    assert res == "Bar Baz"
-    assert completed
-    assert raised
-    assert callback_handler.retries == 1
-
-
-@pytest.mark.requires("openai")
-async def test_openai_async_retries(mock_completion: dict) -> None:
-    llm = OpenAI()
-    mock_client = MagicMock()
-    completed = False
-    raised = False
-    import openai
-
-    def raise_once(*args: Any, **kwargs: Any) -> Any:
-        nonlocal completed, raised
-        if not raised:
-            raised = True
-            raise openai.error.APIError
-        completed = True
-        return mock_completion
-
-    mock_client.create = raise_once
-    callback_handler = FakeAsyncCallbackHandler()
-    with patch.object(
-        llm,
-        "client",
-        mock_client,
-    ):
-        res = llm.apredict("bar", callbacks=[callback_handler])
-    assert res == "Bar Baz"
-    assert completed
-    assert raised
-    assert callback_handler.retries == 1
diff --git a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py
index 87b56a9bff2..f4819c6930e 100644
--- a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py
+++ b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py
@@ -290,6 +290,13 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
         """Whether to ignore agent callbacks."""
         return self.ignore_agent_
 
+    async def on_retry(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
+        self.on_retry_common()
+
     async def on_llm_start(
         self,
         *args: Any,
diff --git a/libs/langchain/tests/unit_tests/llms/test_openai.py b/libs/langchain/tests/unit_tests/llms/test_openai.py
index cc0fc74c1f7..54750a95921 100644
--- a/libs/langchain/tests/unit_tests/llms/test_openai.py
+++ b/libs/langchain/tests/unit_tests/llms/test_openai.py
@@ -1,3 +1,4 @@
+import asyncio
 import os
 from typing import Any
 from unittest.mock import MagicMock, patch
@@ -5,6 +6,10 @@ from unittest.mock import MagicMock, patch
 import pytest
 
 from langchain.llms.openai import OpenAI
+from tests.unit_tests.callbacks.fake_callback_handler import (
+    FakeAsyncCallbackHandler,
+    FakeCallbackHandler,
+)
 
 os.environ["OPENAI_API_KEY"] = "foo"
 
@@ -45,44 +50,62 @@ def mock_completion() -> dict:
 
 
 @pytest.mark.requires("openai")
-def test_openai_calls(mock_completion: dict) -> None:
+def test_openai_retries(mock_completion: dict) -> None:
     llm = OpenAI()
     mock_client = MagicMock()
     completed = False
+    raised = False
+    import openai
 
     def raise_once(*args: Any, **kwargs: Any) -> Any:
-        nonlocal completed
+        nonlocal completed, raised
+        if not raised:
+            raised = True
+            raise openai.error.APIError
         completed = True
         return mock_completion
 
     mock_client.create = raise_once
+    callback_handler = FakeCallbackHandler()
     with patch.object(
         llm,
         "client",
         mock_client,
    ):
-        res = llm.predict("bar")
+        res = llm.predict("bar", callbacks=[callback_handler])
     assert res == "Bar Baz"
     assert completed
+    assert raised
+    assert callback_handler.retries == 1
 
 
 @pytest.mark.requires("openai")
+@pytest.mark.asyncio
 async def test_openai_async_retries(mock_completion: dict) -> None:
     llm = OpenAI()
     mock_client = MagicMock()
     completed = False
+    raised = False
+    import openai
 
-    def raise_once(*args: Any, **kwargs: Any) -> Any:
-        nonlocal completed
+    async def araise_once(*args: Any, **kwargs: Any) -> Any:
+        nonlocal completed, raised
+        if not raised:
+            raised = True
+            raise openai.error.APIError
+        await asyncio.sleep(0)
         completed = True
         return mock_completion
 
-    mock_client.create = raise_once
+    mock_client.acreate = araise_once
+    callback_handler = FakeAsyncCallbackHandler()
     with patch.object(
         llm,
         "client",
         mock_client,
     ):
-        res = llm.apredict("bar")
+        res = await llm.apredict("bar", callbacks=[callback_handler])
     assert res == "Bar Baz"
     assert completed
+    assert raised
+    assert callback_handler.retries == 1
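For context on the base.py change above: the `on_retry` dispatch in `create_base_retry_decorator` runs in synchronous code, and `asyncio.run()` raises a RuntimeError when an event loop is already running (as it is during the async test added below), so the patch schedules the coroutine on the running loop instead. The sketch below shows the same dispatch pattern in isolation; it is not LangChain code, the helper and callback names are made up for illustration, and it uses `asyncio.get_running_loop()` where the patch uses `asyncio.get_event_loop().is_running()`.

```python
import asyncio
from typing import Any, Coroutine, Optional


def run_coro_from_sync(coro: Coroutine[Any, Any, None]) -> None:
    """Dispatch a coroutine from synchronous code (illustrative helper).

    If an event loop is already running in this thread (e.g. inside an async
    test), schedule the coroutine as a task on it; otherwise run it to
    completion with asyncio.run().
    """
    try:
        loop: Optional[asyncio.AbstractEventLoop] = asyncio.get_running_loop()
    except RuntimeError:  # no event loop running in this thread
        loop = None
    if loop is not None:
        loop.create_task(coro)  # fire-and-forget on the active loop
    else:
        asyncio.run(coro)


async def on_retry_callback(attempt: int) -> None:
    print(f"retry callback fired, attempt {attempt}")


if __name__ == "__main__":
    # No loop running: falls back to asyncio.run().
    run_coro_from_sync(on_retry_callback(1))

    async def main() -> None:
        # Loop already running: the coroutine is scheduled as a task.
        run_coro_from_sync(on_retry_callback(2))
        await asyncio.sleep(0)  # let the scheduled task execute

    asyncio.run(main())
```

One trade-off worth noting: the `create_task` branch is fire-and-forget, so an exception raised inside the scheduled callback is reported by the loop's exception handler rather than being caught by the surrounding `try/except` the way the `asyncio.run()` branch is.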