openai[patch]: fix get_num_tokens for function calls (#25785)

Closes https://github.com/langchain-ai/langchain/issues/25784

See additional discussion in commit 0a4ee864e9 (review comment r145147380).
ccurme 2024-08-27 16:18:19 -04:00 committed by GitHub
parent 2aa35d80a0
commit 2e5c379632
2 changed files with 6 additions and 3 deletions


@@ -947,7 +947,7 @@ class BaseChatOpenAI(BaseChatModel):
                 else:
                     # Cast str(value) in case the message value is not a string
                     # This occurs with function messages
-                    num_tokens += len(encoding.encode(value))
+                    num_tokens += len(encoding.encode(str(value)))
                 if key == "name":
                     num_tokens += tokens_per_name
         # every reply is primed with <im_start>assistant
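
For context on the one-line fix above: tiktoken's `Encoding.encode` accepts only strings, while the `function_call` entry of a converted message dict is itself a dict, so the unpatched line raised a `TypeError`. A minimal sketch of the failure mode, assuming `tiktoken` is installed (the `value` literal is illustrative, not taken from the commit):

```python
import tiktoken

# The same BPE family the OpenAI token counters use for recent models.
encoding = tiktoken.get_encoding("cl100k_base")

# A message value as it appears for function calls: a dict, not a str.
value = {"arguments": '{"arg1": "arg1"}', "name": "fun"}

# encoding.encode(value)  # raises TypeError: encode() accepts only str
num_tokens = len(encoding.encode(str(value)))  # the patched line casts first
print(num_tokens)
```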


@@ -677,7 +677,10 @@ def test_get_num_tokens_from_messages() -> None:
         AIMessage(
             "",
             additional_kwargs={
-                "function_call": json.dumps({"arguments": "old", "name": "fun"})
+                "function_call": {
+                    "arguments": json.dumps({"arg1": "arg1"}),
+                    "name": "fun",
+                }
             },
         ),
         AIMessage(
@@ -688,6 +691,6 @@ def test_get_num_tokens_from_messages() -> None:
         ),
         ToolMessage("foobar", tool_call_id="foo"),
     ]
-    expected = 170
+    expected = 176
     actual = llm.get_num_tokens_from_messages(messages)
     assert expected == actual
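
Since `get_num_tokens_from_messages` counts tokens locally via tiktoken and makes no API request, the bug this test now covers can be reproduced with a sketch like the one below; the model name and placeholder key are assumptions, not from the commit:

```python
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI

# Token counting uses tiktoken locally; the key is never sent anywhere.
llm = ChatOpenAI(model="gpt-4o-mini", api_key="placeholder")

messages = [
    AIMessage(
        "",
        additional_kwargs={
            "function_call": {
                "arguments": '{"arg1": "arg1"}',
                "name": "fun",
            }
        },
    )
]

# Before this fix: TypeError from tiktoken when it receives the dict.
# After: an integer token count.
print(llm.get_num_tokens_from_messages(messages))
```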