Mirror of https://github.com/hwchase17/langchain.git

Harrison/tiktoken spec (#964)

Co-authored-by: James Briggs <35938317+jamescalam@users.noreply.github.com>
Co-authored-by: Harrison Chase <harrisonchase@Harrisons-MBP.attlocal.net>

parent 5f8082bdd7
commit ba54d36787
@@ -3,7 +3,17 @@ from __future__ import annotations
 
 import logging
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Iterable, List, Optional
+from typing import (
+    AbstractSet,
+    Any,
+    Callable,
+    Collection,
+    Iterable,
+    List,
+    Literal,
+    Optional,
+    Union,
+)
 
 from langchain.docstore.document import Document
 
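The AbstractSet, Collection, Literal, and Union imports are there to type the new allowed_special and disallowed_special parameters, which this commit forwards to tiktoken. As background, a minimal sketch of tiktoken's special-token behavior, assuming tiktoken is installed (this sketch is not part of the commit):

    import tiktoken

    enc = tiktoken.get_encoding("gpt2")
    text = "Hello<|endoftext|>world"

    # Default behavior (allowed_special=set(), disallowed_special="all"):
    # encoding text that contains a special token raises a ValueError,
    # so a bare enc.encode(text) would fail here.

    # Opting in encodes the marker as its single special-token id.
    ids = enc.encode(text, allowed_special={"<|endoftext|>"})

    # Alternatively, treat the marker as ordinary text by disabling the check.
    ids_plain = enc.encode(text, disallowed_special=())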
@@ -114,7 +124,11 @@ class TextSplitter(ABC):
 
     @classmethod
     def from_tiktoken_encoder(
-        cls, encoding_name: str = "gpt2", **kwargs: Any
+        cls,
+        encoding_name: str = "gpt2",
+        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
+        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        **kwargs: Any,
     ) -> TextSplitter:
         """Text splitter that uses tiktoken encoder to count length."""
         try:
@@ -125,11 +139,19 @@ class TextSplitter(ABC):
                 "This is needed in order to calculate max_tokens_for_prompt. "
                 "Please it install it with `pip install tiktoken`."
             )
 
         # create a GPT-3 encoder instance
         enc = tiktoken.get_encoding(encoding_name)
 
-        def _tiktoken_encoder(text: str) -> int:
-            return len(enc.encode(text))
+        def _tiktoken_encoder(text: str, **kwargs: Any) -> int:
+            return len(
+                enc.encode(
+                    text,
+                    allowed_special=allowed_special,
+                    disallowed_special=disallowed_special,
+                    **kwargs,
+                )
+            )
 
         return cls(length_function=_tiktoken_encoder, **kwargs)
 
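With this change, callers of from_tiktoken_encoder can pass the special-token options through to the token-counting length function. A hypothetical usage sketch (chunk_size and chunk_overlap flow through **kwargs to the splitter constructor):

    from langchain.text_splitter import CharacterTextSplitter

    splitter = CharacterTextSplitter.from_tiktoken_encoder(
        encoding_name="gpt2",
        allowed_special={"<|endoftext|>"},  # count this marker instead of raising
        chunk_size=1000,
        chunk_overlap=0,
    )
    chunks = splitter.split_text("some long document<|endoftext|>more text")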
@@ -155,7 +177,13 @@ class CharacterTextSplitter(TextSplitter):
 class TokenTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at tokens."""
 
-    def __init__(self, encoding_name: str = "gpt2", **kwargs: Any):
+    def __init__(
+        self,
+        encoding_name: str = "gpt2",
+        allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
+        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+        **kwargs: Any,
+    ):
         """Create a new TextSplitter."""
         super().__init__(**kwargs)
         try:
@@ -168,11 +196,17 @@ class TokenTextSplitter(TextSplitter):
             )
         # create a GPT-3 encoder instance
         self._tokenizer = tiktoken.get_encoding(encoding_name)
+        self._allowed_special = allowed_special
+        self._disallowed_special = disallowed_special
 
     def split_text(self, text: str) -> List[str]:
         """Split incoming text and return chunks."""
         splits = []
-        input_ids = self._tokenizer.encode(text)
+        input_ids = self._tokenizer.encode(
+            text,
+            allowed_special=self._allowed_special,
+            disallowed_special=self._disallowed_special,
+        )
         start_idx = 0
         cur_idx = min(start_idx + self._chunk_size, len(input_ids))
         chunk_ids = input_ids[start_idx:cur_idx]
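TokenTextSplitter gains the same two parameters, stores them on the instance, and applies them on every split_text call. A hypothetical usage sketch:

    from langchain.text_splitter import TokenTextSplitter

    splitter = TokenTextSplitter(
        encoding_name="gpt2",
        allowed_special={"<|endoftext|>"},  # tolerate the marker in inputs
        chunk_size=10,
        chunk_overlap=2,
    )
    for chunk in splitter.split_text("first doc<|endoftext|>second doc"):
        print(chunk)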