mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
community: add truncation params when an openai assistant's run is created (#28158)
**Description:** When an OpenAI assistant is invoked, it creates a run by default, allowing users to set only a few request fields. The truncation strategy is set to auto, which includes previous messages in the thread along with the current question until the context length is reached. This causes token usage to grow incrementally: consumed_tokens = previous_consumed_tokens + current_consumed_tokens. This PR adds support for user-defined truncation strategies, giving better control over token consumption. **Issue:** High token consumption.
This commit is contained in:
parent
c09000f20e
commit
0901f11b0f
@ -543,11 +543,16 @@ class OpenAIAssistantV2Runnable(OpenAIAssistantRunnable):
|
||||
Returns:
|
||||
Any: The created run object.
|
||||
"""
|
||||
params = {
|
||||
k: v
|
||||
for k, v in input.items()
|
||||
if k in ("instructions", "model", "tools", "tool_resources", "run_metadata")
|
||||
}
|
||||
allowed_assistant_params = (
|
||||
"instructions",
|
||||
"model",
|
||||
"tools",
|
||||
"tool_resources",
|
||||
"run_metadata",
|
||||
"truncation_strategy",
|
||||
"max_prompt_tokens",
|
||||
)
|
||||
params = {k: v for k, v in input.items() if k in allowed_assistant_params}
|
||||
return self.client.beta.threads.runs.create(
|
||||
input["thread_id"],
|
||||
assistant_id=self.assistant_id,
|
||||
|
@ -0,0 +1,39 @@
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_community.agents.openai_assistant import OpenAIAssistantV2Runnable
|
||||
|
||||
|
||||
def _create_mock_client(*args: Any, use_async: bool = False, **kwargs: Any) -> Any:
|
||||
client = AsyncMock() if use_async else MagicMock()
|
||||
client.beta.threads.runs.create = MagicMock(return_value=None) # type: ignore
|
||||
return client
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
def test_set_run_truncation_params() -> None:
    """Verify truncation fields are forwarded to ``runs.create``.

    ``max_prompt_tokens`` and ``truncation_strategy`` supplied in the
    invocation input must reach the OpenAI run-creation call, while
    non-run fields (``content``, ``thread_id``) must be filtered out of
    the keyword arguments.
    """
    mock_client = _create_mock_client()
    assistant = OpenAIAssistantV2Runnable(assistant_id="assistant_xyz", client=mock_client)

    run_input = {
        "content": "AI question",
        "thread_id": "thread_xyz",
        "instructions": "You're a helpful assistant; answer questions as best you can.",
        "model": "gpt-4o",
        "max_prompt_tokens": 2000,
        "truncation_strategy": {"type": "last_messages", "last_messages": 10},
    }

    assistant._create_run(input=run_input)
    _, create_kwargs = mock_client.beta.threads.runs.create.call_args

    expected_kwargs = {
        "assistant_id": "assistant_xyz",
        "instructions": "You're a helpful assistant; answer questions as best you can.",
        "model": "gpt-4o",
        "max_prompt_tokens": 2000,
        "truncation_strategy": {"type": "last_messages", "last_messages": 10},
    }
    assert create_kwargs == expected_kwargs
|
Loading…
Reference in New Issue
Block a user