mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-14 14:05:37 +00:00
fix(openai): ainvoke uses async _aget_response
; add async tests (#32459)
This commit is contained in:
committed by
GitHub
parent
2fed177d0b
commit
0abf82a45a
@@ -375,6 +375,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
|
|||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
try:
|
try:
|
||||||
|
# Use sync response handler in sync invoke
|
||||||
response = self._get_response(run)
|
response = self._get_response(run)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e, metadata=run.dict())
|
run_manager.on_chain_error(e, metadata=run.dict())
|
||||||
@@ -511,7 +512,8 @@ class OpenAIAssistantRunnable(RunnableSerializable[dict, OutputType]):
|
|||||||
run_manager.on_chain_error(e)
|
run_manager.on_chain_error(e)
|
||||||
raise
|
raise
|
||||||
try:
|
try:
|
||||||
response = self._get_response(run)
|
# Use async response handler in async ainvoke
|
||||||
|
response = await self._aget_response(run)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
run_manager.on_chain_error(e, metadata=run.dict())
|
run_manager.on_chain_error(e, metadata=run.dict())
|
||||||
raise
|
raise
|
||||||
|
@@ -17,7 +17,7 @@ def _create_mock_client(*_: Any, use_async: bool = False, **__: Any) -> Any:
|
|||||||
|
|
||||||
@pytest.mark.requires("openai")
|
@pytest.mark.requires("openai")
|
||||||
def test_user_supplied_client() -> None:
|
def test_user_supplied_client() -> None:
|
||||||
import openai
|
openai = pytest.importorskip("openai")
|
||||||
|
|
||||||
client = openai.AzureOpenAI(
|
client = openai.AzureOpenAI(
|
||||||
azure_endpoint="azure_endpoint",
|
azure_endpoint="azure_endpoint",
|
||||||
@@ -48,6 +48,85 @@ def test_create_assistant() -> None:
|
|||||||
assert isinstance(assistant, OpenAIAssistantRunnable)
|
assert isinstance(assistant, OpenAIAssistantRunnable)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
@patch(
|
||||||
|
"langchain.agents.openai_assistant.base._get_openai_async_client",
|
||||||
|
new=partial(_create_mock_client, use_async=True),
|
||||||
|
)
|
||||||
|
async def test_ainvoke_uses_async_response_completed() -> None:
|
||||||
|
# Arrange a runner with mocked async client and a completed run
|
||||||
|
assistant = OpenAIAssistantRunnable(
|
||||||
|
assistant_id="assistant_id",
|
||||||
|
client=_create_mock_client(),
|
||||||
|
async_client=_create_mock_client(use_async=True),
|
||||||
|
as_agent=False,
|
||||||
|
)
|
||||||
|
mock_run = MagicMock()
|
||||||
|
mock_run.id = "run-id"
|
||||||
|
mock_run.thread_id = "thread-id"
|
||||||
|
mock_run.status = "completed"
|
||||||
|
|
||||||
|
# await_for_run returns a completed run
|
||||||
|
await_for_run_mock = AsyncMock(return_value=mock_run)
|
||||||
|
# async messages list returns messages belonging to run
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.run_id = "run-id"
|
||||||
|
msg.content = []
|
||||||
|
list_mock = AsyncMock(return_value=[msg])
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(assistant, "_await_for_run", await_for_run_mock),
|
||||||
|
patch.object(
|
||||||
|
assistant.async_client.beta.threads.messages,
|
||||||
|
"list",
|
||||||
|
list_mock,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
# Act
|
||||||
|
result = await assistant.ainvoke({"content": "hi"})
|
||||||
|
|
||||||
|
# Assert: returns messages list (non-agent path) and did not block
|
||||||
|
assert isinstance(result, list)
|
||||||
|
list_mock.assert_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
@patch(
|
||||||
|
"langchain.agents.openai_assistant.base._get_openai_async_client",
|
||||||
|
new=partial(_create_mock_client, use_async=True),
|
||||||
|
)
|
||||||
|
async def test_ainvoke_uses_async_response_requires_action_agent() -> None:
|
||||||
|
# Arrange a runner with mocked async client and requires_action run
|
||||||
|
assistant = OpenAIAssistantRunnable(
|
||||||
|
assistant_id="assistant_id",
|
||||||
|
client=_create_mock_client(),
|
||||||
|
async_client=_create_mock_client(use_async=True),
|
||||||
|
as_agent=True,
|
||||||
|
)
|
||||||
|
mock_run = MagicMock()
|
||||||
|
mock_run.id = "run-id"
|
||||||
|
mock_run.thread_id = "thread-id"
|
||||||
|
mock_run.status = "requires_action"
|
||||||
|
|
||||||
|
# Fake tool call structure
|
||||||
|
tool_call = MagicMock()
|
||||||
|
tool_call.id = "tool-id"
|
||||||
|
tool_call.function.name = "foo"
|
||||||
|
tool_call.function.arguments = '{\n "x": 1\n}'
|
||||||
|
mock_run.required_action.submit_tool_outputs.tool_calls = [tool_call]
|
||||||
|
|
||||||
|
await_for_run_mock = AsyncMock(return_value=mock_run)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
with patch.object(assistant, "_await_for_run", await_for_run_mock):
|
||||||
|
result = await assistant.ainvoke({"content": "hi"})
|
||||||
|
|
||||||
|
# Assert: returns list of OpenAIAssistantAction
|
||||||
|
assert isinstance(result, list)
|
||||||
|
assert result
|
||||||
|
assert getattr(result[0], "tool", None) == "foo"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("openai")
|
@pytest.mark.requires("openai")
|
||||||
@patch(
|
@patch(
|
||||||
"langchain.agents.openai_assistant.base._get_openai_async_client",
|
"langchain.agents.openai_assistant.base._get_openai_async_client",
|
||||||
|
Reference in New Issue
Block a user