mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-17 02:03:44 +00:00
allow tokentextsplitters to use model name to select encoder (#2963)
Fixes a bug I was seeing when the `TokenTextSplitter` was correctly splitting text under the gpt3.5-turbo token limit, but when firing the prompt off to OpenAI, it'd come back with an error that we were over the context limit. gpt3.5-turbo and gpt-4 use the `cl100k_base` tokenizer, and so the counts are just always off with the default `gpt-2` encoder. It's possible to pass along the encoding to the `TokenTextSplitter`, but it's much simpler to pass the model name of the LLM. No more concern about keeping the tokenizer and llm model in sync :)
This commit is contained in:
parent
706ebd8f9c
commit
51894ddd98
@ -139,6 +139,7 @@ class TextSplitter(ABC):
|
|||||||
def from_tiktoken_encoder(
|
def from_tiktoken_encoder(
|
||||||
cls,
|
cls,
|
||||||
encoding_name: str = "gpt2",
|
encoding_name: str = "gpt2",
|
||||||
|
model_name: Optional[str] = None,
|
||||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
|
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
|
||||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -153,8 +154,10 @@ class TextSplitter(ABC):
|
|||||||
"Please install it with `pip install tiktoken`."
|
"Please install it with `pip install tiktoken`."
|
||||||
)
|
)
|
||||||
|
|
||||||
# create a GPT-3 encoder instance
|
if model_name is not None:
|
||||||
enc = tiktoken.get_encoding(encoding_name)
|
enc = tiktoken.encoding_for_model(model_name)
|
||||||
|
else:
|
||||||
|
enc = tiktoken.get_encoding(encoding_name)
|
||||||
|
|
||||||
def _tiktoken_encoder(text: str, **kwargs: Any) -> int:
|
def _tiktoken_encoder(text: str, **kwargs: Any) -> int:
|
||||||
return len(
|
return len(
|
||||||
@ -193,6 +196,7 @@ class TokenTextSplitter(TextSplitter):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
encoding_name: str = "gpt2",
|
encoding_name: str = "gpt2",
|
||||||
|
model_name: Optional[str] = None,
|
||||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
|
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
|
||||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
@ -207,8 +211,12 @@ class TokenTextSplitter(TextSplitter):
|
|||||||
"This is needed in order to for TokenTextSplitter. "
|
"This is needed in order to for TokenTextSplitter. "
|
||||||
"Please install it with `pip install tiktoken`."
|
"Please install it with `pip install tiktoken`."
|
||||||
)
|
)
|
||||||
# create a GPT-3 encoder instance
|
|
||||||
self._tokenizer = tiktoken.get_encoding(encoding_name)
|
if model_name is not None:
|
||||||
|
enc = tiktoken.encoding_for_model(model_name)
|
||||||
|
else:
|
||||||
|
enc = tiktoken.get_encoding(encoding_name)
|
||||||
|
self._tokenizer = enc
|
||||||
self._allowed_special = allowed_special
|
self._allowed_special = allowed_special
|
||||||
self._disallowed_special = disallowed_special
|
self._disallowed_special = disallowed_special
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user