openai[patch]: fix tool calling token counting (#23408)

Resolves https://github.com/langchain-ai/langchain/issues/23388
This commit is contained in:
Bagatur 2024-06-25 10:34:25 -07:00 committed by GitHub
parent 22fa32e164
commit 9d145b9630
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 51 additions and 0 deletions

View File

@ -726,6 +726,7 @@ class BaseChatOpenAI(BaseChatModel):
_, encoding_model = self._get_encoding_model()
return encoding_model.encode(text)
# TODO: Count bound tools as part of input.
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
@ -760,7 +761,13 @@ class BaseChatOpenAI(BaseChatModel):
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# This is an inferred approximation. OpenAI does not document how to
# count tool message tokens.
if key == "tool_call_id":
num_tokens += 3
continue
if isinstance(value, list):
# content or tool calls
for val in value:
if isinstance(val, str) or val["type"] == "text":
text = val["text"] if isinstance(val, dict) else val
@ -773,10 +780,19 @@ class BaseChatOpenAI(BaseChatModel):
if not image_size:
continue
num_tokens += _count_image_tokens(*image_size)
# Tool/function call token counting is not documented by OpenAI.
# This is an approximation.
elif val["type"] == "function":
num_tokens += len(
encoding.encode(val["function"]["arguments"])
)
num_tokens += len(encoding.encode(val["function"]["name"]))
else:
raise ValueError(
f"Unrecognized content block type\n\n{val}"
)
elif not value:
continue
else:
# Cast str(value) in case the message value is not a string
# This occurs with function messages

View File

@ -334,3 +334,38 @@ def test_with_structured_output(schema: Union[Type[BaseModel], dict]) -> None:
"""Test passing in manually construct tool call message."""
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
llm.with_structured_output(schema)
def test_get_num_tokens_from_messages() -> None:
    """Token counting covers text, image, tool-call, function-call, and tool messages.

    Exercises every message shape the counter special-cases: plain system/human
    text, multimodal human content (auto- and low-detail images), AI messages
    with tool calls (with and without accompanying text), a legacy
    ``function_call`` in ``additional_kwargs``, and a ``ToolMessage`` carrying a
    ``tool_call_id``. The expected total (170) was computed once against the
    gpt-4o tiktoken encoding and pins the approximation in place.
    """
    model = ChatOpenAI(model="gpt-4o")
    conversation = [
        # Plain text system / human turns.
        SystemMessage("you're a good assistant"),
        HumanMessage("how are you"),
        # Multimodal content: one text block plus two image blocks
        # (default detail and explicit low detail).
        HumanMessage(
            [
                {"type": "text", "text": "what's in this image"},
                {"type": "image_url", "image_url": {"url": "https://foobar.com"}},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://foobar.com", "detail": "low"},
                },
            ]
        ),
        # Plain-text assistant turn.
        AIMessage("a nice bird"),
        # Tool call with no accompanying text content.
        AIMessage(
            "", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})]
        ),
        # Legacy function_call passed through additional_kwargs.
        AIMessage(
            "",
            additional_kwargs={
                "function_call": json.dumps({"arguments": "old", "name": "fun"})
            },
        ),
        # Tool call alongside text content.
        AIMessage(
            "text", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})]
        ),
        # Tool result message; its tool_call_id gets a fixed per-key cost.
        ToolMessage("foobar", tool_call_id="foo"),
    ]
    assert model.get_num_tokens_from_messages(conversation) == 170