mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-07 20:15:40 +00:00
openai[patch]: fix tool calling token counting (#23408)
Resolves https://github.com/langchain-ai/langchain/issues/23388
This commit is contained in:
parent
22fa32e164
commit
9d145b9630
@ -726,6 +726,7 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
_, encoding_model = self._get_encoding_model()
|
||||
return encoding_model.encode(text)
|
||||
|
||||
# TODO: Count bound tools as part of input.
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
|
||||
|
||||
@ -760,7 +761,13 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
# This is an inferred approximation. OpenAI does not document how to
|
||||
# count tool message tokens.
|
||||
if key == "tool_call_id":
|
||||
num_tokens += 3
|
||||
continue
|
||||
if isinstance(value, list):
|
||||
# content or tool calls
|
||||
for val in value:
|
||||
if isinstance(val, str) or val["type"] == "text":
|
||||
text = val["text"] if isinstance(val, dict) else val
|
||||
@ -773,10 +780,19 @@ class BaseChatOpenAI(BaseChatModel):
|
||||
if not image_size:
|
||||
continue
|
||||
num_tokens += _count_image_tokens(*image_size)
|
||||
# Tool/function call token counting is not documented by OpenAI.
|
||||
# This is an approximation.
|
||||
elif val["type"] == "function":
|
||||
num_tokens += len(
|
||||
encoding.encode(val["function"]["arguments"])
|
||||
)
|
||||
num_tokens += len(encoding.encode(val["function"]["name"]))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unrecognized content block type\n\n{val}"
|
||||
)
|
||||
elif not value:
|
||||
continue
|
||||
else:
|
||||
# Cast str(value) in case the message value is not a string
|
||||
# This occurs with function messages
|
||||
|
@ -334,3 +334,38 @@ def test_with_structured_output(schema: Union[Type[BaseModel], dict]) -> None:
|
||||
"""Test passing in manually construct tool call message."""
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
|
||||
llm.with_structured_output(schema)
|
||||
|
||||
|
||||
def test_get_num_tokens_from_messages() -> None:
    """Pin the token count for a fixed conversation covering every message shape.

    Exercises the counter's special cases in one pass: plain-text system/human
    messages, a multimodal content list with image blocks (default and "low"
    detail), AI messages carrying ``tool_calls``, a legacy ``function_call`` in
    ``additional_kwargs``, and a ``ToolMessage`` with a ``tool_call_id``.
    """
    llm = ChatOpenAI(model="gpt-4o")
    messages = [
        SystemMessage("you're a good assistant"),
        HumanMessage("how are you"),
        # Content-block list: one text block plus two image_url blocks,
        # the second with an explicit detail level.
        HumanMessage(
            [
                {"type": "text", "text": "what's in this image"},
                {"type": "image_url", "image_url": {"url": "https://foobar.com"}},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://foobar.com", "detail": "low"},
                },
            ]
        ),
        AIMessage("a nice bird"),
        # Tool call with empty text content.
        AIMessage(
            "", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})]
        ),
        # Legacy function_call carried in additional_kwargs.
        AIMessage(
            "",
            additional_kwargs={
                "function_call": json.dumps({"arguments": "old", "name": "fun"})
            },
        ),
        # Tool call alongside non-empty text content.
        AIMessage(
            "text", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})]
        ),
        ToolMessage("foobar", tool_call_id="foo"),
    ]
    # Regression pin: this exact conversation must count to 170 tokens
    # under the gpt-4o encoding.
    assert llm.get_num_tokens_from_messages(messages) == 170
|
||||
|
Loading…
Reference in New Issue
Block a user