mirror of https://github.com/hwchase17/langchain.git (synced 2025-08-01 09:04:03 +00:00)
huggingface tokenizer (#75)
commit d87e73ddb1
parent b542941234
examples/huggingface_tokenizer_text_splitter.ipynb (new file, 104 lines)
@@ -0,0 +1,104 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "e82c4685",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.text_splitter import CharacterTextSplitter"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "a8ce51d5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import GPT2TokenizerFast\n",
+    "\n",
+    "tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "ca5e72c0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with open('state_of_the_union.txt') as f:\n",
+    "    state_of_the_union = f.read()\n",
+    "text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(tokenizer, chunk_size=1000, chunk_overlap=0)\n",
+    "texts = text_splitter.split_text(state_of_the_union)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "37cdfbeb",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n",
+      "\n",
+      "Last year COVID-19 kept us apart. This year we are finally together again. \n",
+      "\n",
+      "Tonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. \n",
+      "\n",
+      "With a duty to one another to the American people to the Constitution. \n",
+      "\n",
+      "And with an unwavering resolve that freedom will always triumph over tyranny. \n",
+      "\n",
+      "Six days ago, Russia’s Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \n",
+      "\n",
+      "He thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \n",
+      "\n",
+      "He met the Ukrainian people. \n",
+      "\n",
+      "From President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. \n",
+      "\n",
+      "Groups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland. \n"
+     ]
+    }
+   ],
+   "source": [
+    "print(texts[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d214aec2",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
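A note on the notebook above: chunk_size=1000 is measured in GPT-2 tokens, not characters, because the splitter counts length with the tokenizer. A minimal sketch of that counting, using only the public transformers API (the helper name token_length is illustrative, not part of the library):

from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")


def token_length(text: str) -> int:
    # What the splitter plugs in as its length function:
    # the number of GPT-2 tokens, not the number of characters.
    return len(tokenizer.encode(text))


text = "Madam Speaker, Madam Vice President, our First Lady and Second Gentleman."
print(len(text), token_length(text))  # character count vs. (much smaller) token count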
langchain/text_splitter.py
@@ -1,12 +1,18 @@
 """Functionality for splitting text."""
-from abc import abstractmethod
-from typing import Iterable, List
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Iterable, List


-class TextSplitter:
+class TextSplitter(ABC):
     """Interface for splitting text into chunks."""

-    def __init__(self, separator: str, chunk_size: int, chunk_overlap: int):
+    def __init__(
+        self,
+        separator: str = "\n\n",
+        chunk_size: int = 4000,
+        chunk_overlap: int = 200,
+        length_function: Callable[[str], int] = len,
+    ):
         """Create a new TextSplitter."""
         if chunk_overlap > chunk_size:
             raise ValueError(
@@ -16,6 +22,7 @@ class TextSplitter:
         self._separator = separator
         self._chunk_size = chunk_size
         self._chunk_overlap = chunk_overlap
+        self._length_function = length_function

     @abstractmethod
     def split_text(self, text: str) -> List[str]:
@@ -28,29 +35,43 @@ class TextSplitter:
         current_doc: List[str] = []
         total = 0
         for d in splits:
-            if total > self._chunk_size:
+            if total >= self._chunk_size:
                 docs.append(self._separator.join(current_doc))
                 while total > self._chunk_overlap:
-                    total -= len(current_doc[0])
+                    total -= self._length_function(current_doc[0])
                     current_doc = current_doc[1:]
             current_doc.append(d)
-            total += len(d)
+            total += self._length_function(d)
         docs.append(self._separator.join(current_doc))
         return docs

+    @classmethod
+    def from_huggingface_tokenizer(
+        cls, tokenizer: Any, **kwargs: Any
+    ) -> "TextSplitter":
+        """Text splitter that uses a HuggingFace tokenizer to count length."""
+        try:
+            from transformers import PreTrainedTokenizerBase
+
+            if not isinstance(tokenizer, PreTrainedTokenizerBase):
+                raise ValueError(
+                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"
+                )
+
+            def _huggingface_tokenizer_length(text: str) -> int:
+                return len(tokenizer.encode(text))
+
+        except ImportError:
+            raise ValueError(
+                "Could not import transformers python package. "
+                "Please install it with `pip install transformers`."
+            )
+        return cls(length_function=_huggingface_tokenizer_length, **kwargs)
+

 class CharacterTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at characters."""

-    def __init__(
-        self, separator: str = "\n\n", chunk_size: int = 4000, chunk_overlap: int = 200
-    ):
-        """Create a new CharacterTextSplitter."""
-        super(CharacterTextSplitter, self).__init__(
-            separator, chunk_size, chunk_overlap
-        )
+    def __init__(self, separator: str = "\n\n", **kwargs: Any):
+        """Create a new CharacterTextSplitter."""
+        super(CharacterTextSplitter, self).__init__(**kwargs)
         self._separator = separator

     def split_text(self, text: str) -> List[str]:
         """Split incoming text and return chunks."""
         # First we naively split the large input into a bunch of smaller ones.
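For readers tracing the new merge logic above: it is a sliding window over the already-split pieces. A standalone sketch, assuming length_function=len for readability (merge_splits here is an illustrative copy of TextSplitter._merge_splits, not library API):

from typing import Callable, Iterable, List


def merge_splits(
    splits: Iterable[str],
    separator: str,
    chunk_size: int,
    chunk_overlap: int,
    length_function: Callable[[str], int] = len,
) -> List[str]:
    docs: List[str] = []
    current_doc: List[str] = []
    total = 0
    for d in splits:
        # Once the window is full, emit a chunk...
        if total >= chunk_size:
            docs.append(separator.join(current_doc))
            # ...then drop pieces from the front until at most
            # chunk_overlap worth of trailing context remains.
            while total > chunk_overlap:
                total -= length_function(current_doc[0])
                current_doc = current_doc[1:]
        current_doc.append(d)
        total += length_function(d)
    docs.append(separator.join(current_doc))
    return docs


print(merge_splits(["foo", "bar"], " ", chunk_size=1, chunk_overlap=0))
# ['foo', 'bar'], mirroring the integration test added below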
@@ -10,6 +10,7 @@ wikipedia
 huggingface_hub
 faiss-cpu
 sentence_transformers
+transformers
 manifest-ml
 spacy
 nltk
tests/integration_tests/test_text_splitter.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+"""Test text splitters that require an integration."""
+
+import pytest
+
+from langchain.text_splitter import CharacterTextSplitter
+
+
+def test_huggingface_type_check() -> None:
+    """Test that type checks are done properly on input."""
+    with pytest.raises(ValueError):
+        CharacterTextSplitter.from_huggingface_tokenizer("foo")
+
+
+def test_huggingface_tokenizer() -> None:
+    """Test text splitter that uses a HuggingFace tokenizer."""
+    from transformers import GPT2TokenizerFast
+
+    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+    text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(
+        tokenizer, separator=" ", chunk_size=1, chunk_overlap=0
+    )
+    output = text_splitter.split_text("foo bar")
+    assert output == ["foo", "bar"]