mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-02 01:23:07 +00:00
[langchain] fix OpenAIAssistantRunnable.create_assistant (#19081)
- **Description:** OpenAI assistants support some pre-built tools (e.g., `"retrieval"` and `"code_interpreter"`) and expect these as `{"type": "code_interpreter"}`. This may have been upset by https://github.com/langchain-ai/langchain/pull/18935 - **Issue:** https://github.com/langchain-ai/langchain/issues/19057
This commit is contained in:
parent
b40c80007f
commit
8a2528c34a
@ -3,12 +3,23 @@ from __future__ import annotations
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from time import sleep
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.pydantic_v1 import Field, root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.function_calling import convert_to_openai_tool
|
||||
@ -76,6 +87,32 @@ def _get_openai_async_client() -> openai.AsyncOpenAI:
|
||||
) from e
|
||||
|
||||
|
||||
def _is_assistants_builtin_tool(
|
||||
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
|
||||
) -> bool:
|
||||
"""Determine if tool corresponds to OpenAI Assistants built-in."""
|
||||
assistants_builtin_tools = ("code_interpreter", "retrieval")
|
||||
return (
|
||||
isinstance(tool, dict)
|
||||
and ("type" in tool)
|
||||
and (tool["type"] in assistants_builtin_tools)
|
||||
)
|
||||
|
||||
|
||||
def _get_assistants_tool(
|
||||
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert a raw function/class to an OpenAI tool.
|
||||
|
||||
Note that OpenAI assistants supports several built-in tools,
|
||||
such as "code_interpreter" and "retrieval."
|
||||
"""
|
||||
if _is_assistants_builtin_tool(tool):
|
||||
return tool # type: ignore
|
||||
else:
|
||||
return convert_to_openai_tool(tool)
|
||||
|
||||
|
||||
OutputType = Union[
|
||||
List[OpenAIAssistantAction],
|
||||
OpenAIAssistantFinish,
|
||||
@ -210,7 +247,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
||||
assistant = client.beta.assistants.create(
|
||||
name=name,
|
||||
instructions=instructions,
|
||||
tools=[convert_to_openai_tool(tool) for tool in tools], # type: ignore
|
||||
tools=[_get_assistants_tool(tool) for tool in tools], # type: ignore
|
||||
model=model,
|
||||
file_ids=kwargs.get("file_ids"),
|
||||
)
|
||||
@ -328,7 +365,7 @@ class OpenAIAssistantRunnable(RunnableSerializable[Dict, OutputType]):
|
||||
AsyncOpenAIAssistantRunnable configured to run using the created assistant.
|
||||
"""
|
||||
async_client = async_client or _get_openai_async_client()
|
||||
openai_tools = [convert_to_openai_tool(tool) for tool in tools]
|
||||
openai_tools = [_get_assistants_tool(tool) for tool in tools]
|
||||
assistant = await async_client.beta.assistants.create(
|
||||
name=name,
|
||||
instructions=instructions,
|
||||
|
@ -1,8 +1,20 @@
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.agents.openai_assistant import OpenAIAssistantRunnable
|
||||
|
||||
|
||||
def _create_mock_client(*args: Any, use_async: bool = False, **kwargs: Any) -> Any:
|
||||
client = AsyncMock() if use_async else MagicMock()
|
||||
mock_assistant = MagicMock()
|
||||
mock_assistant.id = "abc123"
|
||||
client.beta.assistants.create.return_value = mock_assistant # type: ignore
|
||||
return client
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_user_supplied_client() -> None:
|
||||
import openai
|
||||
@ -19,3 +31,34 @@ def test_user_supplied_client() -> None:
|
||||
)
|
||||
|
||||
assert assistant.client == client
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
@patch(
|
||||
"langchain.agents.openai_assistant.base._get_openai_client",
|
||||
new=partial(_create_mock_client, use_async=False),
|
||||
)
|
||||
def test_create_assistant() -> None:
|
||||
assistant = OpenAIAssistantRunnable.create_assistant(
|
||||
name="name",
|
||||
instructions="instructions",
|
||||
tools=[{"type": "code_interpreter"}],
|
||||
model="",
|
||||
)
|
||||
assert isinstance(assistant, OpenAIAssistantRunnable)
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
@patch(
|
||||
"langchain.agents.openai_assistant.base._get_openai_async_client",
|
||||
new=partial(_create_mock_client, use_async=True),
|
||||
)
|
||||
async def test_acreate_assistant() -> None:
|
||||
assistant = await OpenAIAssistantRunnable.acreate_assistant(
|
||||
name="name",
|
||||
instructions="instructions",
|
||||
tools=[{"type": "code_interpreter"}],
|
||||
model="",
|
||||
client=_create_mock_client(),
|
||||
)
|
||||
assert isinstance(assistant, OpenAIAssistantRunnable)
|
||||
|
Loading…
Reference in New Issue
Block a user