diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index bfa293b446e..5631eb1190d 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -726,6 +726,7 @@ class BaseChatOpenAI(BaseChatModel): _, encoding_model = self._get_encoding_model() return encoding_model.encode(text) + # TODO: Count bound tools as part of input. def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. @@ -760,7 +761,13 @@ class BaseChatOpenAI(BaseChatModel): for message in messages_dict: num_tokens += tokens_per_message for key, value in message.items(): + # This is an inferred approximation. OpenAI does not document how to + # count tool message tokens. + if key == "tool_call_id": + num_tokens += 3 + continue if isinstance(value, list): + # content or tool calls for val in value: if isinstance(val, str) or val["type"] == "text": text = val["text"] if isinstance(val, dict) else val @@ -773,10 +780,19 @@ class BaseChatOpenAI(BaseChatModel): if not image_size: continue num_tokens += _count_image_tokens(*image_size) + # Tool/function call token counting is not documented by OpenAI. + # This is an approximation. + elif val["type"] == "function": + num_tokens += len( + encoding.encode(val["function"]["arguments"]) + ) + num_tokens += len(encoding.encode(val["function"]["name"])) else: raise ValueError( f"Unrecognized content block type\n\n{val}" ) + elif not value: + continue else: # Cast str(value) in case the message value is not a string # This occurs with function messages diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 1250e02e0a9..f523aacc2c5 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -334,3 +334,38 @@ def test_with_structured_output(schema: Union[Type[BaseModel], dict]) -> None: """Test passing in manually construct tool call message.""" llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) llm.with_structured_output(schema) + + +def test_get_num_tokens_from_messages() -> None: + llm = ChatOpenAI(model="gpt-4o") + messages = [ + SystemMessage("you're a good assistant"), + HumanMessage("how are you"), + HumanMessage( + [ + {"type": "text", "text": "what's in this image"}, + {"type": "image_url", "image_url": {"url": "https://foobar.com"}}, + { + "type": "image_url", + "image_url": {"url": "https://foobar.com", "detail": "low"}, + }, + ] + ), + AIMessage("a nice bird"), + AIMessage( + "", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})] + ), + AIMessage( + "", + additional_kwargs={ + "function_call": json.dumps({"arguments": "old", "name": "fun"}) + }, + ), + AIMessage( + "text", tool_calls=[ToolCall(id="foo", name="bar", args={"arg1": "arg1"})] + ), + ToolMessage("foobar", tool_call_id="foo"), + ] + expected = 170 + actual = llm.get_num_tokens_from_messages(messages) + assert expected == actual