diff --git a/libs/langchain/langchain/agents/openai_assistant/base.py b/libs/langchain/langchain/agents/openai_assistant/base.py index e7cb100ae50..76195fec352 100644 --- a/libs/langchain/langchain/agents/openai_assistant/base.py +++ b/libs/langchain/langchain/agents/openai_assistant/base.py @@ -3,12 +3,23 @@ from __future__ import annotations import json from json import JSONDecodeError from time import sleep -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) from langchain_core.agents import AgentAction, AgentFinish from langchain_core.callbacks import CallbackManager from langchain_core.load import dumpd -from langchain_core.pydantic_v1 import Field, root_validator +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config from langchain_core.tools import BaseTool from langchain_core.utils.function_calling import convert_to_openai_tool @@ -76,6 +87,32 @@ def _get_openai_async_client() -> openai.AsyncOpenAI: ) from e +def _is_assistants_builtin_tool( + tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], +) -> bool: + """Determine if tool corresponds to OpenAI Assistants built-in.""" + assistants_builtin_tools = ("code_interpreter", "retrieval") + return ( + isinstance(tool, dict) + and ("type" in tool) + and (tool["type"] in assistants_builtin_tools) + ) + + +def _get_assistants_tool( + tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], +) -> Dict[str, Any]: + """Convert a raw function/class to an OpenAI tool. + + Note that OpenAI assistants supports several built-in tools, + such as "code_interpreter" and "retrieval." + """ + if _is_assistants_builtin_tool(tool): + return tool # type: ignore + else: + return convert_to_openai_tool(tool) + + OutputType = Union[ List[OpenAIAssistantAction], OpenAIAssistantFinish, @@ -210,7 +247,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]): assistant = client.beta.assistants.create( name=name, instructions=instructions, - tools=[convert_to_openai_tool(tool) for tool in tools], # type: ignore + tools=[_get_assistants_tool(tool) for tool in tools], # type: ignore model=model, file_ids=kwargs.get("file_ids"), ) @@ -328,7 +365,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]): AsyncOpenAIAssistantRunnable configured to run using the created assistant. """ async_client = async_client or _get_openai_async_client() - openai_tools = [convert_to_openai_tool(tool) for tool in tools] + openai_tools = [_get_assistants_tool(tool) for tool in tools] assistant = await async_client.beta.assistants.create( name=name, instructions=instructions, diff --git a/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py b/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py index aaa4ba48d1d..45fcea4ad2a 100644 --- a/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py +++ b/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py @@ -1,8 +1,20 @@ +from functools import partial +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + import pytest from langchain.agents.openai_assistant import OpenAIAssistantRunnable +def _create_mock_client(*args: Any, use_async: bool = False, **kwargs: Any) -> Any: + client = AsyncMock() if use_async else MagicMock() + mock_assistant = MagicMock() + mock_assistant.id = "abc123" + client.beta.assistants.create.return_value = mock_assistant # type: ignore + return client + + @pytest.mark.requires("openai") def test_user_supplied_client() -> None: import openai @@ -19,3 +31,34 @@ def test_user_supplied_client() -> None: ) assert assistant.client == client + + +@pytest.mark.requires("openai") +@patch( + "langchain.agents.openai_assistant.base._get_openai_client", + new=partial(_create_mock_client, use_async=False), +) +def test_create_assistant() -> None: + assistant = OpenAIAssistantRunnable.create_assistant( + name="name", + instructions="instructions", + tools=[{"type": "code_interpreter"}], + model="", + ) + assert isinstance(assistant, OpenAIAssistantRunnable) + + +@pytest.mark.requires("openai") +@patch( + "langchain.agents.openai_assistant.base._get_openai_async_client", + new=partial(_create_mock_client, use_async=True), +) +async def test_acreate_assistant() -> None: + assistant = await OpenAIAssistantRunnable.acreate_assistant( + name="name", + instructions="instructions", + tools=[{"type": "code_interpreter"}], + model="", + client=_create_mock_client(), + ) + assert isinstance(assistant, OpenAIAssistantRunnable)