mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-08 20:41:52 +00:00
openai[patch]: fix tool calling token counting (#23408)
Resolves https://github.com/langchain-ai/langchain/issues/23388
This commit is contained in:
parent
22fa32e164
commit
9d145b9630
@ -726,6 +726,7 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
_, encoding_model = self._get_encoding_model()
|
_, encoding_model = self._get_encoding_model()
|
||||||
return encoding_model.encode(text)
|
return encoding_model.encode(text)
|
||||||
|
|
||||||
|
# TODO: Count bound tools as part of input.
|
||||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||||
|
|
||||||
@ -760,7 +761,13 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
for message in messages_dict:
|
for message in messages_dict:
|
||||||
num_tokens += tokens_per_message
|
num_tokens += tokens_per_message
|
||||||
for key, value in message.items():
|
for key, value in message.items():
|
||||||
|
# This is an inferred approximation. OpenAI does not document how to
|
||||||
|
# count tool message tokens.
|
||||||
|
if key == "tool_call_id":
|
||||||
|
num_tokens += 3
|
||||||
|
continue
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
|
# content or tool calls
|
||||||
for val in value:
|
for val in value:
|
||||||
if isinstance(val, str) or val["type"] == "text":
|
if isinstance(val, str) or val["type"] == "text":
|
||||||
text = val["text"] if isinstance(val, dict) else val
|
text = val["text"] if isinstance(val, dict) else val
|
||||||
@ -773,10 +780,19 @@ class BaseChatOpenAI(BaseChatModel):
|
|||||||
if not image_size:
|
if not image_size:
|
||||||
continue
|
continue
|
||||||
num_tokens += _count_image_tokens(*image_size)
|
num_tokens += _count_image_tokens(*image_size)
|
||||||
|
# Tool/function call token counting is not documented by OpenAI.
|
||||||
|
# This is an approximation.
|
||||||
|
elif val["type"] == "function":
|
||||||
|
num_tokens += len(
|
||||||
|
encoding.encode(val["function"]["arguments"])
|
||||||
|
)
|
||||||
|
num_tokens += len(encoding.encode(val["function"]["name"]))
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unrecognized content block type\n\n{val}"
|
f"Unrecognized content block type\n\n{val}"
|
||||||
)
|
)
|
||||||
|
elif not value:
|
||||||
|
continue
|
||||||
else:
|
else:
|
||||||
# Cast str(value) in case the message value is not a string
|
# Cast str(value) in case the message value is not a string
|
||||||
# This occurs with function messages
|
# This occurs with function messages
|
||||||
|
@ -334,3 +334,38 @@ def test_with_structured_output(schema: Union[Type[BaseModel], dict]) -> None:
|
|||||||
"""Test passing in manually construct tool call message."""
|
"""Test passing in manually construct tool call message."""
|
||||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||||
llm.with_structured_output(schema)
|
llm.with_structured_output(schema)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_num_tokens_from_messages() -> None:
|
||||||
|
llm = ChatOpenAI(model="gpt-4o")
|
||||||
|
messages = [
|
||||||
|
SystemMessage("you're a good assistant"),
|
||||||
|
HumanMessage("how are you"),
|
||||||
|
HumanMessage(
|
||||||
|
[
|
||||||
|
{"type": "text", "text": "what's in this image"},
|
||||||
|
{"type": "image_url", "image_url": {"url": "https://foobar.com"}},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": "https://foobar.com", "detail": "low"},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
),
|
||||||
|
AIMessage("a nice bird"),
|
||||||
|
AIMessage(
|
||||||
|
"", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})]
|
||||||
|
),
|
||||||
|
AIMessage(
|
||||||
|
"",
|
||||||
|
additional_kwargs={
|
||||||
|
"function_call": json.dumps({"arguments": "old", "name": "fun"})
|
||||||
|
},
|
||||||
|
),
|
||||||
|
AIMessage(
|
||||||
|
"text", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})]
|
||||||
|
),
|
||||||
|
ToolMessage("foobar", tool_call_id="foo"),
|
||||||
|
]
|
||||||
|
expected = 170
|
||||||
|
actual = llm.get_num_tokens_from_messages(messages)
|
||||||
|
assert expected == actual
|
||||||
|
Loading…
Reference in New Issue
Block a user