Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-01 19:12:42 +00:00)
core, openai: support custom token encoders (#20762)
This commit is contained in:
@@ -139,6 +139,7 @@ class Llamafile(LLM):
|
|||||||
"streaming",
|
"streaming",
|
||||||
"tags",
|
"tags",
|
||||||
"verbose",
|
"verbose",
|
||||||
|
"custom_get_token_ids",
|
||||||
]
|
]
|
||||||
attrs = [
|
attrs = [
|
||||||
k for k in get_pydantic_field_names(self.__class__) if k not in ignore_keys
|
k for k in get_pydantic_field_names(self.__class__) if k not in ignore_keys
|
||||||
|
@@ -5,6 +5,7 @@ from functools import lru_cache
|
|||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
@@ -97,6 +98,10 @@ class BaseLanguageModel(
|
|||||||
"""Tags to add to the run trace."""
|
"""Tags to add to the run trace."""
|
||||||
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
|
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
|
||||||
"""Metadata to add to the run trace."""
|
"""Metadata to add to the run trace."""
|
||||||
|
custom_get_token_ids: Optional[Callable[[str], List[int]]] = Field(
|
||||||
|
default=None, exclude=True
|
||||||
|
)
|
||||||
|
"""Optional encoder to use for counting tokens."""
|
||||||
|
|
||||||
@validator("verbose", pre=True, always=True)
|
@validator("verbose", pre=True, always=True)
|
||||||
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
def set_verbose(cls, verbose: Optional[bool]) -> bool:
|
||||||
@@ -310,7 +315,10 @@ class BaseLanguageModel(
|
|||||||
A list of ids corresponding to the tokens in the text, in order they occur
|
A list of ids corresponding to the tokens in the text, in order they occur
|
||||||
in the text.
|
in the text.
|
||||||
"""
|
"""
|
||||||
return _get_token_ids_default_method(text)
|
if self.custom_get_token_ids is not None:
|
||||||
|
return self.custom_get_token_ids(text)
|
||||||
|
else:
|
||||||
|
return _get_token_ids_default_method(text)
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
def get_num_tokens(self, text: str) -> int:
|
||||||
"""Get the number of tokens present in the text.
|
"""Get the number of tokens present in the text.
|
||||||
|
@@ -521,6 +521,8 @@ class BaseOpenAI(BaseLLM):
|
|||||||
|
|
||||||
def get_token_ids(self, text: str) -> List[int]:
|
def get_token_ids(self, text: str) -> List[int]:
|
||||||
"""Get the token IDs using the tiktoken package."""
|
"""Get the token IDs using the tiktoken package."""
|
||||||
|
if self.custom_get_token_ids is not None:
|
||||||
|
return self.custom_get_token_ids(text)
|
||||||
# tiktoken NOT supported for Python < 3.8
|
# tiktoken NOT supported for Python < 3.8
|
||||||
if sys.version_info[1] < 8:
|
if sys.version_info[1] < 8:
|
||||||
return super().get_num_tokens(text)
|
return super().get_num_tokens(text)
|
||||||
|
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -54,3 +55,11 @@ def mock_completion() -> dict:
|
|||||||
def test_get_token_ids(model: str) -> None:
|
def test_get_token_ids(model: str) -> None:
|
||||||
OpenAI(model=model).get_token_ids("foo")
|
OpenAI(model=model).get_token_ids("foo")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_token_counting() -> None:
    """Check that ``get_token_ids`` returns what a user-supplied encoder produces."""
    expected_ids = [1, 2, 3]

    def fixed_encoder(text: str) -> List[int]:
        # The input is ignored on purpose: a constant result makes it obvious
        # that the ids came from this callable.
        return list(expected_ids)

    model = OpenAI(custom_get_token_ids=fixed_encoder)
    assert model.get_token_ids("foo") == expected_ids
|
||||||
|
Reference in New Issue
Block a user