mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-31 03:59:25 +00:00
Harrison/tiktoken spec (#964)
Co-authored-by: James Briggs <35938317+jamescalam@users.noreply.github.com> Co-authored-by: Harrison Chase <harrisonchase@Harrisons-MBP.attlocal.net>
This commit is contained in:
parent
5f8082bdd7
commit
ba54d36787
@ -3,7 +3,17 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Iterable, List, Optional
|
||||
from typing import (
|
||||
AbstractSet,
|
||||
Any,
|
||||
Callable,
|
||||
Collection,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
@ -114,7 +124,11 @@ class TextSplitter(ABC):
|
||||
|
||||
@classmethod
|
||||
def from_tiktoken_encoder(
|
||||
cls, encoding_name: str = "gpt2", **kwargs: Any
|
||||
cls,
|
||||
encoding_name: str = "gpt2",
|
||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
**kwargs: Any,
|
||||
) -> TextSplitter:
|
||||
"""Text splitter that uses tiktoken encoder to count length."""
|
||||
try:
|
||||
@ -125,11 +139,19 @@ class TextSplitter(ABC):
|
||||
"This is needed in order to calculate max_tokens_for_prompt. "
|
||||
"Please it install it with `pip install tiktoken`."
|
||||
)
|
||||
|
||||
# create a GPT-3 encoder instance
|
||||
enc = tiktoken.get_encoding(encoding_name)
|
||||
|
||||
def _tiktoken_encoder(text: str) -> int:
|
||||
return len(enc.encode(text))
|
||||
def _tiktoken_encoder(text: str, **kwargs: Any) -> int:
|
||||
return len(
|
||||
enc.encode(
|
||||
text,
|
||||
allowed_special=allowed_special,
|
||||
disallowed_special=disallowed_special,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
return cls(length_function=_tiktoken_encoder, **kwargs)
|
||||
|
||||
@ -155,7 +177,13 @@ class CharacterTextSplitter(TextSplitter):
|
||||
class TokenTextSplitter(TextSplitter):
|
||||
"""Implementation of splitting text that looks at tokens."""
|
||||
|
||||
def __init__(self, encoding_name: str = "gpt2", **kwargs: Any):
|
||||
def __init__(
|
||||
self,
|
||||
encoding_name: str = "gpt2",
|
||||
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Create a new TextSplitter."""
|
||||
super().__init__(**kwargs)
|
||||
try:
|
||||
@ -168,11 +196,17 @@ class TokenTextSplitter(TextSplitter):
|
||||
)
|
||||
# create a GPT-3 encoder instance
|
||||
self._tokenizer = tiktoken.get_encoding(encoding_name)
|
||||
self._allowed_special = allowed_special
|
||||
self._disallowed_special = disallowed_special
|
||||
|
||||
def split_text(self, text: str) -> List[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
splits = []
|
||||
input_ids = self._tokenizer.encode(text)
|
||||
input_ids = self._tokenizer.encode(
|
||||
text,
|
||||
allowed_special=self._allowed_special,
|
||||
disallowed_special=self._disallowed_special,
|
||||
)
|
||||
start_idx = 0
|
||||
cur_idx = min(start_idx + self._chunk_size, len(input_ids))
|
||||
chunk_ids = input_ids[start_idx:cur_idx]
|
||||
|
Loading…
Reference in New Issue
Block a user