From 0901f11b0f7808bb6302d0208d4a6a5e411059ec Mon Sep 17 00:00:00 2001
From: LuisMSotamba <46969419+LuisMSotamba@users.noreply.github.com>
Date: Wed, 27 Nov 2024 10:53:53 -0500
Subject: [PATCH] community: add truncation params when an openai assistant's
 run is created (#28158)

**Description:** When an OpenAI assistant is invoked, it creates a run
by default, allowing users to set only a few request fields. The
truncation strategy is set to auto, which includes previous messages in
the thread along with the current question until the context length is
reached. This causes token usage to grow with every run:
consumed_tokens = previous_consumed_tokens + current_consumed_tokens.
This PR adds support for user-defined truncation strategies, giving
users better control over token consumption.

**Issue:** High token consumption.
---
 .../agents/openai_assistant/base.py | 15 ++++--
 .../agents/test_openai_assistant.py | 39 +++++++++++++++++++
 2 files changed, 49 insertions(+), 5 deletions(-)
 create mode 100644 libs/community/tests/unit_tests/agents/test_openai_assistant.py

diff --git a/libs/community/langchain_community/agents/openai_assistant/base.py b/libs/community/langchain_community/agents/openai_assistant/base.py
index 971e8ba8381..f9006f66ca5 100644
--- a/libs/community/langchain_community/agents/openai_assistant/base.py
+++ b/libs/community/langchain_community/agents/openai_assistant/base.py
@@ -543,11 +543,16 @@ class OpenAIAssistantV2Runnable(OpenAIAssistantRunnable):
         Returns:
             Any: The created run object.
         """
-        params = {
-            k: v
-            for k, v in input.items()
-            if k in ("instructions", "model", "tools", "tool_resources", "run_metadata")
-        }
+        allowed_assistant_params = (
+            "instructions",
+            "model",
+            "tools",
+            "tool_resources",
+            "run_metadata",
+            "truncation_strategy",
+            "max_prompt_tokens",
+        )
+        params = {k: v for k, v in input.items() if k in allowed_assistant_params}
         return self.client.beta.threads.runs.create(
             input["thread_id"],
             assistant_id=self.assistant_id,
diff --git a/libs/community/tests/unit_tests/agents/test_openai_assistant.py b/libs/community/tests/unit_tests/agents/test_openai_assistant.py
new file mode 100644
index 00000000000..ea99ab5ccd7
--- /dev/null
+++ b/libs/community/tests/unit_tests/agents/test_openai_assistant.py
@@ -0,0 +1,39 @@
+from typing import Any
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from langchain_community.agents.openai_assistant import OpenAIAssistantV2Runnable
+
+
+def _create_mock_client(*args: Any, use_async: bool = False, **kwargs: Any) -> Any:
+    client = AsyncMock() if use_async else MagicMock()
+    client.beta.threads.runs.create = MagicMock(return_value=None)  # type: ignore
+    return client
+
+
+@pytest.mark.requires("openai")
+def test_set_run_truncation_params() -> None:
+    client = _create_mock_client()
+
+    assistant = OpenAIAssistantV2Runnable(assistant_id="assistant_xyz", client=client)
+    input = {
+        "content": "AI question",
+        "thread_id": "thread_xyz",
+        "instructions": "You're a helpful assistant; answer questions as best you can.",
+        "model": "gpt-4o",
+        "max_prompt_tokens": 2000,
+        "truncation_strategy": {"type": "last_messages", "last_messages": 10},
+    }
+    expected_response = {
+        "assistant_id": "assistant_xyz",
+        "instructions": "You're a helpful assistant; answer questions as best you can.",
+        "model": "gpt-4o",
+        "max_prompt_tokens": 2000,
+        "truncation_strategy": {"type": "last_messages", "last_messages": 10},
+    }
+
+    assistant._create_run(input=input)
+    _, kwargs = client.beta.threads.runs.create.call_args
+
+    assert kwargs == expected_response
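
For reference, a minimal usage sketch of the new fields (not part of the patch): it assumes `OPENAI_API_KEY` is set in the environment, and `asst_abc123` / `thread_abc123` are placeholder IDs for an existing assistant and thread.

```python
from langchain_community.agents.openai_assistant import OpenAIAssistantV2Runnable

# Builds its own OpenAI client from OPENAI_API_KEY when none is passed in.
assistant = OpenAIAssistantV2Runnable(assistant_id="asst_abc123", as_agent=False)

response = assistant.invoke(
    {
        "content": "Summarize the discussion so far.",
        "thread_id": "thread_abc123",  # placeholder: an existing thread
        # Fields forwarded to runs.create by this patch: replay only the last
        # 10 thread messages and cap the prompt at 2000 tokens, instead of the
        # default "auto" strategy that replays the whole thread.
        "truncation_strategy": {"type": "last_messages", "last_messages": 10},
        "max_prompt_tokens": 2000,
    }
)
```

With a `last_messages` strategy, each run replays only the most recent N thread messages, so consumed_tokens stays roughly flat across invocations instead of accumulating.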