core, openai: support custom token encoders (#20762)

This commit is contained in:
ccurme
2024-04-23 09:57:05 -04:00
committed by GitHub
parent b481b73805
commit 7a922f3e48
4 changed files with 21 additions and 1 deletion

View File

@@ -139,6 +139,7 @@ class Llamafile(LLM):
"streaming",
"tags",
"verbose",
"custom_get_token_ids",
]
attrs = [
k for k in get_pydantic_field_names(self.__class__) if k not in ignore_keys

View File

@@ -5,6 +5,7 @@ from functools import lru_cache
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Mapping,
@@ -97,6 +98,10 @@ class BaseLanguageModel(
"""Tags to add to the run trace."""
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
"""Metadata to add to the run trace."""
custom_get_token_ids: Optional[Callable[[str], List[int]]] = Field(
default=None, exclude=True
)
"""Optional encoder to use for counting tokens."""
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
@@ -310,7 +315,10 @@ class BaseLanguageModel(
A list of ids corresponding to the tokens in the text, in order they occur
in the text.
"""
return _get_token_ids_default_method(text)
if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text)
else:
return _get_token_ids_default_method(text)
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text.

View File

@@ -521,6 +521,8 @@ class BaseOpenAI(BaseLLM):
def get_token_ids(self, text: str) -> List[int]:
"""Get the token IDs using the tiktoken package."""
if self.custom_get_token_ids is not None:
return self.custom_get_token_ids(text)
# tiktoken NOT supported for Python < 3.8
if sys.version_info[1] < 8:
return super().get_num_tokens(text)

View File

@@ -1,4 +1,5 @@
import os
from typing import List
import pytest
@@ -54,3 +55,11 @@ def mock_completion() -> dict:
def test_get_token_ids(model: str) -> None:
    """Smoke-test that token IDs can be computed for the given model name.

    Only checks that the call succeeds; the returned IDs are not inspected.
    """
    # Removed a redundant bare `return` that followed this call — it was dead code.
    OpenAI(model=model).get_token_ids("foo")
def test_custom_token_counting() -> None:
    """Verify that a user-supplied token encoder overrides the default one.

    The stub encoder ignores its input and always yields a fixed ID list,
    so the assertion proves the custom callable was actually invoked.
    """

    def fixed_ids(text: str) -> List[int]:
        return [1, 2, 3]

    model = OpenAI(custom_get_token_ids=fixed_ids)
    assert model.get_token_ids("foo") == [1, 2, 3]