diff --git a/libs/langchain/langchain/agents/openai_assistant/base.py b/libs/langchain/langchain/agents/openai_assistant/base.py index 37ef8c554d5..c0c7138466d 100644 --- a/libs/langchain/langchain/agents/openai_assistant/base.py +++ b/libs/langchain/langchain/agents/openai_assistant/base.py @@ -375,6 +375,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): run_manager.on_chain_error(e) raise try: + # Use sync response handler in sync invoke response = self._get_response(run) except BaseException as e: run_manager.on_chain_error(e, metadata=run.dict()) @@ -511,7 +512,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]): run_manager.on_chain_error(e) raise try: - response = self._get_response(run) + # Use async response handler in async ainvoke + response = await self._aget_response(run) except BaseException as e: run_manager.on_chain_error(e, metadata=run.dict()) raise diff --git a/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py b/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py index 8304295ac91..61a79e8a02d 100644 --- a/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py +++ b/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py @@ -17,7 +17,7 @@ def _create_mock_client(*_: Any, use_async: bool = False, **__: Any) -> Any: @pytest.mark.requires("openai") def test_user_supplied_client() -> None: - import openai + openai = pytest.importorskip("openai") client = openai.AzureOpenAI( azure_endpoint="azure_endpoint", @@ -48,6 +48,85 @@ def test_create_assistant() -> None: assert isinstance(assistant, OpenAIAssistantRunnable) +@pytest.mark.requires("openai") +@patch( + "langchain.agents.openai_assistant.base._get_openai_async_client", + new=partial(_create_mock_client, use_async=True), +) +async def test_ainvoke_uses_async_response_completed() -> None: + # Arrange a runner with mocked async client and a completed run + assistant = OpenAIAssistantRunnable( + assistant_id="assistant_id", + client=_create_mock_client(), + async_client=_create_mock_client(use_async=True), + as_agent=False, + ) + mock_run = MagicMock() + mock_run.id = "run-id" + mock_run.thread_id = "thread-id" + mock_run.status = "completed" + + # await_for_run returns a completed run + await_for_run_mock = AsyncMock(return_value=mock_run) + # async messages list returns messages belonging to run + msg = MagicMock() + msg.run_id = "run-id" + msg.content = [] + list_mock = AsyncMock(return_value=[msg]) + + with ( + patch.object(assistant, "_await_for_run", await_for_run_mock), + patch.object( + assistant.async_client.beta.threads.messages, + "list", + list_mock, + ), + ): + # Act + result = await assistant.ainvoke({"content": "hi"}) + + # Assert: returns messages list (non-agent path) and did not block + assert isinstance(result, list) + list_mock.assert_awaited() + + +@pytest.mark.requires("openai") +@patch( + "langchain.agents.openai_assistant.base._get_openai_async_client", + new=partial(_create_mock_client, use_async=True), +) +async def test_ainvoke_uses_async_response_requires_action_agent() -> None: + # Arrange a runner with mocked async client and requires_action run + assistant = OpenAIAssistantRunnable( + assistant_id="assistant_id", + client=_create_mock_client(), + async_client=_create_mock_client(use_async=True), + as_agent=True, + ) + mock_run = MagicMock() + mock_run.id = "run-id" + mock_run.thread_id = "thread-id" + mock_run.status = "requires_action" + + # Fake tool call structure + tool_call = MagicMock() + tool_call.id = "tool-id" + tool_call.function.name = "foo" + tool_call.function.arguments = '{\n "x": 1\n}' + mock_run.required_action.submit_tool_outputs.tool_calls = [tool_call] + + await_for_run_mock = AsyncMock(return_value=mock_run) + + # Act + with patch.object(assistant, "_await_for_run", await_for_run_mock): + result = await assistant.ainvoke({"content": "hi"}) + + # Assert: returns list of OpenAIAssistantAction + assert isinstance(result, list) + assert result + assert getattr(result[0], "tool", None) == "foo" + + @pytest.mark.requires("openai") @patch( "langchain.agents.openai_assistant.base._get_openai_async_client",