mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-31 10:23:18 +00:00
core, openai: support custom token encoders (#20762)
This commit is contained in:
@@ -139,6 +139,7 @@ class Llamafile(LLM):
|
||||
"streaming",
|
||||
"tags",
|
||||
"verbose",
|
||||
"custom_get_token_ids",
|
||||
]
|
||||
attrs = [
|
||||
k for k in get_pydantic_field_names(self.__class__) if k not in ignore_keys
|
||||
|
@@ -5,6 +5,7 @@ from functools import lru_cache
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
@@ -97,6 +98,10 @@ class BaseLanguageModel(
|
||||
"""Tags to add to the run trace."""
|
||||
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
|
||||
"""Metadata to add to the run trace."""
|
||||
custom_get_token_ids: Optional[Callable[[str], List[int]]] = Field(
|
||||
default=None, exclude=True
|
||||
)
|
||||
"""Optional encoder to use for counting tokens."""
|
||||
|
||||
@validator("verbose", pre=True, always=True)
|
||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||
@@ -310,7 +315,10 @@ class BaseLanguageModel(
|
||||
A list of ids corresponding to the tokens in the text, in order they occur
|
||||
in the text.
|
||||
"""
|
||||
return _get_token_ids_default_method(text)
|
||||
if self.custom_get_token_ids is not None:
|
||||
return self.custom_get_token_ids(text)
|
||||
else:
|
||||
return _get_token_ids_default_method(text)
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Get the number of tokens present in the text.
|
||||
|
@@ -521,6 +521,8 @@ class BaseOpenAI(BaseLLM):
|
||||
|
||||
def get_token_ids(self, text: str) -> List[int]:
|
||||
"""Get the token IDs using the tiktoken package."""
|
||||
if self.custom_get_token_ids is not None:
|
||||
return self.custom_get_token_ids(text)
|
||||
# tiktoken NOT supported for Python < 3.8
|
||||
if sys.version_info[1] < 8:
|
||||
return super().get_num_tokens(text)
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -54,3 +55,11 @@ def mock_completion() -> dict:
|
||||
def test_get_token_ids(model: str) -> None:
|
||||
OpenAI(model=model).get_token_ids("foo")
|
||||
return
|
||||
|
||||
|
||||
def test_custom_token_counting() -> None:
|
||||
def token_encoder(text: str) -> List[int]:
|
||||
return [1, 2, 3]
|
||||
|
||||
llm = OpenAI(custom_get_token_ids=token_encoder)
|
||||
assert llm.get_token_ids("foo") == [1, 2, 3]
|
||||
|
Reference in New Issue
Block a user