openai[patch]: fix tool calling token counting (#23408)

Resolves https://github.com/langchain-ai/langchain/issues/23388
This commit is contained in:
Bagatur
2024-06-25 10:34:25 -07:00
committed by GitHub
parent 22fa32e164
commit 9d145b9630
2 changed files with 51 additions and 0 deletions

View File

@@ -726,6 +726,7 @@ class BaseChatOpenAI(BaseChatModel):
_, encoding_model = self._get_encoding_model()
return encoding_model.encode(text)
# TODO: Count bound tools as part of input.
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
@@ -760,7 +761,13 @@ class BaseChatOpenAI(BaseChatModel):
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# This is an inferred approximation. OpenAI does not document how to
# count tool message tokens.
if key == "tool_call_id":
num_tokens += 3
continue
if isinstance(value, list):
# A list value holds either content blocks or tool calls.
for val in value:
if isinstance(val, str) or val["type"] == "text":
text = val["text"] if isinstance(val, dict) else val
@@ -773,10 +780,19 @@ class BaseChatOpenAI(BaseChatModel):
if not image_size:
continue
num_tokens += _count_image_tokens(*image_size)
# Tool/function call token counting is not documented by OpenAI.
# This is an approximation.
elif val["type"] == "function":
num_tokens += len(
encoding.encode(val["function"]["arguments"])
)
num_tokens += len(encoding.encode(val["function"]["name"]))
else:
raise ValueError(
f"Unrecognized content block type\n\n{val}"
)
elif not value:
continue
else:
# Cast the value to str in case the message value is not a string.
# This occurs with function messages.