mistral, openai: support custom tokenizers in chat models (#20901)

This commit is contained in:
ccurme
2024-04-25 15:23:29 -04:00
committed by GitHub
parent 6986e44959
commit fdabd3cdf5
3 changed files with 20 additions and 2 deletions

View File

@@ -1,7 +1,7 @@
"""Test MistralAI Chat API wrapper."""
import os
from typing import Any, AsyncGenerator, Dict, Generator, cast
from typing import Any, AsyncGenerator, Dict, Generator, List, cast
from unittest.mock import patch
import pytest
@@ -190,3 +190,11 @@ def test__convert_dict_to_message_tool_call() -> None:
)
assert result == expected_output
assert _convert_message_to_mistral_chat_message(expected_output) == message
def test_custom_token_counting() -> None:
    """A user-supplied ``custom_get_token_ids`` callback must take precedence
    over the model's built-in tokenizer."""

    def fake_tokenizer(_text: str) -> List[int]:
        # Constant, input-independent ids make the override trivially observable.
        return [1, 2, 3]

    model = ChatMistralAI(custom_get_token_ids=fake_tokenizer)
    assert model.get_token_ids("foo") == [1, 2, 3]

View File

@@ -703,6 +703,8 @@ class ChatOpenAI(BaseChatModel):
def get_token_ids(self, text: str) -> List[int]:
"""Get the tokens present in the text with tiktoken package."""
if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text)
# tiktoken NOT supported for Python 3.7 or below
if sys.version_info[1] <= 7:
return super().get_token_ids(text)

View File

@@ -1,7 +1,7 @@
"""Test OpenAI Chat API wrapper."""
import json
from typing import Any
from typing import Any, List
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -279,3 +279,11 @@ def test_openai_invoke_name(mock_completion: dict) -> None:
# check return type has name
assert res.content == "Bar Baz"
assert res.name == "Erick"
def test_custom_token_counting() -> None:
    """A user-supplied ``custom_get_token_ids`` callback must take precedence
    over the default tiktoken-based tokenizer."""

    def fake_tokenizer(_text: str) -> List[int]:
        # Constant, input-independent ids make the override trivially observable.
        return [1, 2, 3]

    model = ChatOpenAI(custom_get_token_ids=fake_tokenizer)
    assert model.get_token_ids("foo") == [1, 2, 3]