Patch summary (free text before the first "diff --git" header):
  * ChatOpenAI.get_token_ids now returns self.custom_get_token_ids(text)
    when that attribute is not None, ahead of the tiktoken code path.
  * Adds test_custom_token_counting to the MistralAI and OpenAI chat-model
    unit tests; each passes a stub encoder and expects [1, 2, 3] back.

diff --git a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py
index ab70d02d45b..63713f22731 100644
--- a/libs/partners/mistralai/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/mistralai/tests/unit_tests/test_chat_models.py
@@ -1,7 +1,7 @@
 """Test MistralAI Chat API wrapper."""
 
 import os
-from typing import Any, AsyncGenerator, Dict, Generator, cast
+from typing import Any, AsyncGenerator, Dict, Generator, List, cast
 from unittest.mock import patch
 
 import pytest
@@ -190,3 +190,11 @@ def test__convert_dict_to_message_tool_call() -> None:
     )
     assert result == expected_output
     assert _convert_message_to_mistral_chat_message(expected_output) == message
+
+
+def test_custom_token_counting() -> None:
+    def token_encoder(text: str) -> List[int]:
+        return [1, 2, 3]
+
+    llm = ChatMistralAI(custom_get_token_ids=token_encoder)
+    assert llm.get_token_ids("foo") == [1, 2, 3]
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index e68377c2a85..334aa96885e 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -703,6 +703,8 @@ class ChatOpenAI(BaseChatModel):
 
     def get_token_ids(self, text: str) -> List[int]:
         """Get the tokens present in the text with tiktoken package."""
+        if self.custom_get_token_ids is not None:
+            return self.custom_get_token_ids(text)
         # tiktoken NOT supported for Python 3.7 or below
         if sys.version_info[1] <= 7:
             return super().get_token_ids(text)
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
index 9665af8f644..e4cc2cfddf8 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
@@ -1,7 +1,7 @@
 """Test OpenAI Chat API wrapper."""
 
 import json
-from typing import Any
+from typing import Any, List
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
@@ -279,3 +279,11 @@ def test_openai_invoke_name(mock_completion: dict) -> None:
     # check return type has name
     assert res.content == "Bar Baz"
     assert res.name == "Erick"
+
+
+def test_custom_token_counting() -> None:
+    def token_encoder(text: str) -> List[int]:
+        return [1, 2, 3]
+
+    llm = ChatOpenAI(custom_get_token_ids=token_encoder)
+    assert llm.get_token_ids("foo") == [1, 2, 3]