mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-03 18:24:10 +00:00)
[langchain] fix OpenAIAssistantRunnable.create_assistant (#19081)
- **Description:** OpenAI assistants support some pre-built tools (e.g., `"retrieval"` and `"code_interpreter"`) and expect these to be passed as plain dicts such as `{"type": "code_interpreter"}`. This behavior may have been broken by https://github.com/langchain-ai/langchain/pull/18935
- **Issue:** https://github.com/langchain-ai/langchain/issues/19057
This commit is contained in:
parent b40c80007f
commit 8a2528c34a
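For context, a minimal sketch of the calling pattern this fix restores: built-in Assistants tools are given as plain `{"type": ...}` dicts and forwarded unchanged, while ordinary functions are still converted to OpenAI function tools. This is an illustration, not part of the commit; `get_weather` and the model name are made-up examples, and running it requires the `openai` package plus an OpenAI API key.

```python
# Illustration only: mixing an Assistants built-in tool with a regular function.
# Requires the `openai` package and an OPENAI_API_KEY in the environment.
from langchain.agents.openai_assistant import OpenAIAssistantRunnable


def get_weather(city: str) -> str:
    """Return a toy weather report for a city (hypothetical example tool)."""
    return f"It is sunny in {city}."


assistant = OpenAIAssistantRunnable.create_assistant(
    name="demo assistant",
    instructions="You are a helpful assistant.",
    tools=[
        {"type": "code_interpreter"},  # built-in tool: passed through as-is
        get_weather,  # plain callable: converted to an OpenAI function tool
    ],
    model="gpt-4-1106-preview",  # example model name, swap in whatever you use
)
```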
@@ -3,12 +3,23 @@ from __future__ import annotations
 import json
 from json import JSONDecodeError
 from time import sleep
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
 
 from langchain_core.agents import AgentAction, AgentFinish
 from langchain_core.callbacks import CallbackManager
 from langchain_core.load import dumpd
-from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
 from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config
 from langchain_core.tools import BaseTool
 from langchain_core.utils.function_calling import convert_to_openai_tool
@@ -76,6 +87,32 @@ def _get_openai_async_client() -> openai.AsyncOpenAI:
         ) from e
 
 
+def _is_assistants_builtin_tool(
+    tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
+) -> bool:
+    """Determine if tool corresponds to OpenAI Assistants built-in."""
+    assistants_builtin_tools = ("code_interpreter", "retrieval")
+    return (
+        isinstance(tool, dict)
+        and ("type" in tool)
+        and (tool["type"] in assistants_builtin_tools)
+    )
+
+
+def _get_assistants_tool(
+    tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
+) -> Dict[str, Any]:
+    """Convert a raw function/class to an OpenAI tool.
+
+    Note that OpenAI assistants supports several built-in tools,
+    such as "code_interpreter" and "retrieval."
+    """
+    if _is_assistants_builtin_tool(tool):
+        return tool  # type: ignore
+    else:
+        return convert_to_openai_tool(tool)
+
+
 OutputType = Union[
     List[OpenAIAssistantAction],
     OpenAIAssistantFinish,
@@ -210,7 +247,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
         assistant = client.beta.assistants.create(
             name=name,
             instructions=instructions,
-            tools=[convert_to_openai_tool(tool) for tool in tools],  # type: ignore
+            tools=[_get_assistants_tool(tool) for tool in tools],  # type: ignore
             model=model,
             file_ids=kwargs.get("file_ids"),
         )
@@ -328,7 +365,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
             AsyncOpenAIAssistantRunnable configured to run using the created assistant.
         """
         async_client = async_client or _get_openai_async_client()
-        openai_tools = [convert_to_openai_tool(tool) for tool in tools]
+        openai_tools = [_get_assistants_tool(tool) for tool in tools]
         assistant = await async_client.beta.assistants.create(
             name=name,
             instructions=instructions,
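To make the new helpers above concrete, here is a small sketch (not part of the commit): `multiply` is a made-up function, and importing the private helper is for demonstration only. A built-in tool dict short-circuits `_get_assistants_tool`, while anything else still goes through `convert_to_openai_tool`.

```python
# Illustration only; assumes a langchain build that includes this commit.
from langchain.agents.openai_assistant.base import _get_assistants_tool


def multiply(a: int, b: int) -> int:
    """Multiply two integers (hypothetical example function)."""
    return a * b


# Built-in Assistants tools are returned untouched:
print(_get_assistants_tool({"type": "retrieval"}))
# {'type': 'retrieval'}

# Ordinary callables are still converted to OpenAI function tools, roughly:
print(_get_assistants_tool(multiply))
# {'type': 'function', 'function': {'name': 'multiply', 'description': '...', 'parameters': {...}}}
```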
@@ -1,8 +1,20 @@
+from functools import partial
+from typing import Any
+from unittest.mock import AsyncMock, MagicMock, patch
+
 import pytest
 
 from langchain.agents.openai_assistant import OpenAIAssistantRunnable
 
 
+def _create_mock_client(*args: Any, use_async: bool = False, **kwargs: Any) -> Any:
+    client = AsyncMock() if use_async else MagicMock()
+    mock_assistant = MagicMock()
+    mock_assistant.id = "abc123"
+    client.beta.assistants.create.return_value = mock_assistant  # type: ignore
+    return client
+
+
 @pytest.mark.requires("openai")
 def test_user_supplied_client() -> None:
     import openai
@@ -19,3 +31,34 @@ def test_user_supplied_client() -> None:
     )
 
     assert assistant.client == client
+
+
+@pytest.mark.requires("openai")
+@patch(
+    "langchain.agents.openai_assistant.base._get_openai_client",
+    new=partial(_create_mock_client, use_async=False),
+)
+def test_create_assistant() -> None:
+    assistant = OpenAIAssistantRunnable.create_assistant(
+        name="name",
+        instructions="instructions",
+        tools=[{"type": "code_interpreter"}],
+        model="",
+    )
+    assert isinstance(assistant, OpenAIAssistantRunnable)
+
+
+@pytest.mark.requires("openai")
+@patch(
+    "langchain.agents.openai_assistant.base._get_openai_async_client",
+    new=partial(_create_mock_client, use_async=True),
+)
+async def test_acreate_assistant() -> None:
+    assistant = await OpenAIAssistantRunnable.acreate_assistant(
+        name="name",
+        instructions="instructions",
+        tools=[{"type": "code_interpreter"}],
+        model="",
+        client=_create_mock_client(),
+    )
+    assert isinstance(assistant, OpenAIAssistantRunnable)
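The new tests above only assert that a runnable is returned. If one also wanted to check the pass-through behavior itself, the same mock pattern could be extended along these lines. This is a sketch, not part of the commit; `check_builtin_tool_passthrough` is a made-up name, and like the upstream tests it assumes the `openai` package is installed.

```python
# Sketch only: assert that a built-in tool dict reaches the OpenAI client untouched.
from unittest.mock import MagicMock, patch

from langchain.agents.openai_assistant import OpenAIAssistantRunnable


def check_builtin_tool_passthrough() -> None:
    client = MagicMock()
    client.beta.assistants.create.return_value = MagicMock(id="abc123")
    with patch(
        "langchain.agents.openai_assistant.base._get_openai_client",
        return_value=client,
    ):
        OpenAIAssistantRunnable.create_assistant(
            name="name",
            instructions="instructions",
            tools=[{"type": "code_interpreter"}],
            model="",
        )
    # The built-in tool dict should be forwarded to the API call unchanged.
    sent_tools = client.beta.assistants.create.call_args.kwargs["tools"]
    assert sent_tools == [{"type": "code_interpreter"}]
```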