community: add truncation params when an openai assistant's run is created (#28158)

**Description:** When an OpenAI assistant is invoked, it creates a run
by default, allowing users to set only a few request fields. The
truncation strategy is set to auto, which includes previous messages in
the thread along with the current question until the context length is
reached. This causes token usage to grow incrementally:
consumed_tokens = previous_consumed_tokens + current_consumed_tokens.

This PR adds support for user-defined truncation strategies, giving
better control over token consumption.

**Issue:** High token consumption.
This commit is contained in:
LuisMSotamba 2024-11-27 10:53:53 -05:00 committed by GitHub
parent c09000f20e
commit 0901f11b0f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 49 additions and 5 deletions

View File

@ -543,11 +543,16 @@ class OpenAIAssistantV2Runnable(OpenAIAssistantRunnable):
Returns:
Any: The created run object.
"""
params = {
k: v
for k, v in input.items()
if k in ("instructions", "model", "tools", "tool_resources", "run_metadata")
}
allowed_assistant_params = (
"instructions",
"model",
"tools",
"tool_resources",
"run_metadata",
"truncation_strategy",
"max_prompt_tokens",
)
params = {k: v for k, v in input.items() if k in allowed_assistant_params}
return self.client.beta.threads.runs.create(
input["thread_id"],
assistant_id=self.assistant_id,

View File

@ -0,0 +1,39 @@
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from langchain_community.agents.openai_assistant import OpenAIAssistantV2Runnable
def _create_mock_client(*args: Any, use_async: bool = False, **kwargs: Any) -> Any:
client = AsyncMock() if use_async else MagicMock()
client.beta.threads.runs.create = MagicMock(return_value=None) # type: ignore
return client
@pytest.mark.requires("openai")
def test_set_run_truncation_params() -> None:
    """Truncation fields supplied in the input must be forwarded to runs.create.

    Invokes ``_create_run`` with ``max_prompt_tokens`` and a
    ``truncation_strategy`` and checks that exactly the allowed assistant
    parameters (plus ``assistant_id``) reach the mocked
    ``beta.threads.runs.create`` endpoint as keyword arguments.
    """
    mock_client = _create_mock_client()
    runnable = OpenAIAssistantV2Runnable(assistant_id="assistant_xyz", client=mock_client)

    # Shared literals so the input and the expectation cannot drift apart.
    instructions = "You're a helpful assistant; answer questions as best you can."
    truncation = {"type": "last_messages", "last_messages": 10}

    run_input = {
        "content": "AI question",
        "thread_id": "thread_xyz",
        "instructions": instructions,
        "model": "gpt-4o",
        "max_prompt_tokens": 2000,
        "truncation_strategy": truncation,
    }

    runnable._create_run(input=run_input)

    # "content" and "thread_id" are consumed elsewhere and must not be passed on.
    _, call_kwargs = mock_client.beta.threads.runs.create.call_args
    assert call_kwargs == {
        "assistant_id": "assistant_xyz",
        "instructions": instructions,
        "model": "gpt-4o",
        "max_prompt_tokens": 2000,
        "truncation_strategy": truncation,
    }