mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-17 15:35:14 +00:00
mistral, openai: support custom tokenizers in chat models (#20901)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""Test MistralAI Chat API wrapper."""
|
||||
|
||||
import os
|
||||
from typing import Any, AsyncGenerator, Dict, Generator, cast
|
||||
from typing import Any, AsyncGenerator, Dict, Generator, List, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -190,3 +190,11 @@ def test__convert_dict_to_message_tool_call() -> None:
|
||||
)
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_mistral_chat_message(expected_output) == message
|
||||
|
||||
|
||||
def test_custom_token_counting() -> None:
    """Check that ``get_token_ids`` returns what ``custom_get_token_ids`` yields.

    The stub tokenizer ignores its input and always produces the same ids, so
    a matching result shows the model's ``get_token_ids`` output equals the
    output of the callable passed at construction time.
    """

    def fixed_tokenizer(text: str) -> List[int]:
        # Constant output regardless of ``text``.
        return [1, 2, 3]

    chat_model = ChatMistralAI(custom_get_token_ids=fixed_tokenizer)
    token_ids = chat_model.get_token_ids("foo")
    assert token_ids == [1, 2, 3]
|
||||
|
@@ -703,6 +703,8 @@ class ChatOpenAI(BaseChatModel):
|
||||
|
||||
def get_token_ids(self, text: str) -> List[int]:
|
||||
"""Get the tokens present in the text with tiktoken package."""
|
||||
if self.custom_get_token_ids is not None:
|
||||
return self.custom_get_token_ids(text)
|
||||
# tiktoken NOT supported for Python 3.7 or below
|
||||
if sys.version_info[1] <= 7:
|
||||
return super().get_token_ids(text)
|
||||
|
@@ -1,7 +1,7 @@
|
||||
"""Test OpenAI Chat API wrapper."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import Any, List
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -279,3 +279,11 @@ def test_openai_invoke_name(mock_completion: dict) -> None:
|
||||
# check return type has name
|
||||
assert res.content == "Bar Baz"
|
||||
assert res.name == "Erick"
|
||||
|
||||
|
||||
def test_custom_token_counting() -> None:
    """Check that ``get_token_ids`` returns what ``custom_get_token_ids`` yields.

    The stub tokenizer ignores its input and always produces the same ids, so
    a matching result shows the model's ``get_token_ids`` output equals the
    output of the callable passed at construction time.
    """

    def fixed_tokenizer(text: str) -> List[int]:
        # Constant output regardless of ``text``.
        return [1, 2, 3]

    chat_model = ChatOpenAI(custom_get_token_ids=fixed_tokenizer)
    token_ids = chat_model.get_token_ids("foo")
    assert token_ids == [1, 2, 3]
|
||||
|
Reference in New Issue
Block a user