diff --git a/examples/huggingface_tokenizer_text_splitter.ipynb b/examples/huggingface_tokenizer_text_splitter.ipynb
new file mode 100644
index 00000000000..afd6db90d08
--- /dev/null
+++ b/examples/huggingface_tokenizer_text_splitter.ipynb
@@ -0,0 +1,104 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "e82c4685",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.text_splitter import CharacterTextSplitter"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "a8ce51d5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from transformers import GPT2TokenizerFast\n",
+    "\n",
+    "tokenizer = GPT2TokenizerFast.from_pretrained(\"gpt2\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "ca5e72c0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with open('state_of_the_union.txt') as f:\n",
+    "    state_of_the_union = f.read()\n",
+    "text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(tokenizer, chunk_size=1000, chunk_overlap=0)\n",
+    "texts = text_splitter.split_text(state_of_the_union)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "37cdfbeb",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n",
+      "\n",
+      "Last year COVID-19 kept us apart. This year we are finally together again. \n",
+      "\n",
+      "Tonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. \n",
+      "\n",
+      "With a duty to one another to the American people to the Constitution. \n",
+      "\n",
+      "And with an unwavering resolve that freedom will always triumph over tyranny. \n",
+      "\n",
+      "Six days ago, Russia’s Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \n",
+      "\n",
+      "He thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \n",
+      "\n",
+      "He met the Ukrainian people. \n",
+      "\n",
+      "From President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. \n",
+      "\n",
+      "Groups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland. \n"
+     ]
+    }
+   ],
+   "source": [
+    "print(texts[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d214aec2",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/langchain/text_splitter.py b/langchain/text_splitter.py
index 1c44d98763f..dbae51cf4db 100644
--- a/langchain/text_splitter.py
+++ b/langchain/text_splitter.py
@@ -1,12 +1,18 @@
 """Functionality for splitting text."""
-from abc import abstractmethod
-from typing import Iterable, List
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Iterable, List
 
 
-class TextSplitter:
+class TextSplitter(ABC):
     """Interface for splitting text into chunks."""
 
-    def __init__(self, separator: str, chunk_size: int, chunk_overlap: int):
+    def __init__(
+        self,
+        separator: str = "\n\n",
+        chunk_size: int = 4000,
+        chunk_overlap: int = 200,
+        length_function: Callable[[str], int] = len,
+    ):
         """Create a new TextSplitter."""
         if chunk_overlap > chunk_size:
             raise ValueError(
@@ -16,6 +22,7 @@ class TextSplitter:
         self._separator = separator
         self._chunk_size = chunk_size
         self._chunk_overlap = chunk_overlap
+        self._length_function = length_function
 
     @abstractmethod
     def split_text(self, text: str) -> List[str]:
@@ -28,29 +35,43 @@ class TextSplitter:
         current_doc: List[str] = []
         total = 0
         for d in splits:
-            if total > self._chunk_size:
+            if total >= self._chunk_size:
                 docs.append(self._separator.join(current_doc))
                 while total > self._chunk_overlap:
-                    total -= len(current_doc[0])
+                    total -= self._length_function(current_doc[0])
                     current_doc = current_doc[1:]
             current_doc.append(d)
-            total += len(d)
+            total += self._length_function(d)
         docs.append(self._separator.join(current_doc))
         return docs
 
+    @classmethod
+    def from_huggingface_tokenizer(
+        cls, tokenizer: Any, **kwargs: Any
+    ) -> "TextSplitter":
+        """Text splitter that uses a HuggingFace tokenizer to count length."""
+        try:
+            from transformers import PreTrainedTokenizerBase
+
+            if not isinstance(tokenizer, PreTrainedTokenizerBase):
+                raise ValueError(
+                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"
+                )
+
+            def _huggingface_tokenizer_length(text: str) -> int:
+                return len(tokenizer.encode(text))
+
+        except ImportError:
+            raise ValueError(
+                "Could not import transformers python package. "
+                "Please install it with `pip install transformers`."
+            )
+        return cls(length_function=_huggingface_tokenizer_length, **kwargs)
+
 
 class CharacterTextSplitter(TextSplitter):
     """Implementation of splitting text that looks at characters."""
 
-    def __init__(
-        self, separator: str = "\n\n", chunk_size: int = 4000, chunk_overlap: int = 200
-    ):
-        """Create a new CharacterTextSplitter."""
-        super(CharacterTextSplitter, self).__init__(
-            separator, chunk_size, chunk_overlap
-        )
-        self._separator = separator
-
     def split_text(self, text: str) -> List[str]:
         """Split incoming text and return chunks."""
         # First we naively split the large input into a bunch of smaller ones.
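For reviewers, a minimal sketch of what the new `length_function` hook enables beyond the HuggingFace path; `word_count` is an illustrative name, not part of this diff. Any `Callable[[str], int]` can replace the default `len`, so `chunk_size` and `chunk_overlap` become budgets in whatever unit the callable measures:

```python
from langchain.text_splitter import CharacterTextSplitter


def word_count(text: str) -> int:
    # Measure length in whitespace-separated words instead of characters.
    return len(text.split())


# chunk_size and chunk_overlap are now interpreted in words, not characters.
splitter = CharacterTextSplitter(
    separator="\n\n",
    chunk_size=200,
    chunk_overlap=20,
    length_function=word_count,
)
```

`from_huggingface_tokenizer` is just this pattern pre-packaged, with `len(tokenizer.encode(text))` as the callable.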
diff --git a/requirements.txt b/requirements.txt
index 317453c1e26..0678b5a5a71 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,6 +10,7 @@ wikipedia
 huggingface_hub
 faiss-cpu
 sentence_transformers
+transformers
 manifest-ml
 spacy
 nltk
diff --git a/tests/integration_tests/test_text_splitter.py b/tests/integration_tests/test_text_splitter.py
new file mode 100644
index 00000000000..902705c7261
--- /dev/null
+++ b/tests/integration_tests/test_text_splitter.py
@@ -0,0 +1,23 @@
+"""Test text splitters that require an integration."""
+
+import pytest
+
+from langchain.text_splitter import CharacterTextSplitter
+
+
+def test_huggingface_type_check() -> None:
+    """Test that type checks are done properly on input."""
+    with pytest.raises(ValueError):
+        CharacterTextSplitter.from_huggingface_tokenizer("foo")
+
+
+def test_huggingface_tokenizer() -> None:
+    """Test text splitter that uses a HuggingFace tokenizer."""
+    from transformers import GPT2TokenizerFast
+
+    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+    text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(
+        tokenizer, separator=" ", chunk_size=1, chunk_overlap=0
+    )
+    output = text_splitter.split_text("foo bar")
+    assert output == ["foo", "bar"]
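The test expects `["foo", "bar"]` because each word alone already reaches `chunk_size=1` under the tokenizer-based length, so no merging happens. To see the merge-and-overlap path as well, here is a quick character-based sketch of the same `_merge_splits` logic (no model download needed); the values and the traced output are my own illustration, not part of the PR:

```python
from langchain.text_splitter import CharacterTextSplitter

# chunk_size/chunk_overlap are measured by the default length function (len),
# and the running total counts the splits only, not the separators between them.
splitter = CharacterTextSplitter(separator=" ", chunk_size=7, chunk_overlap=5)
print(splitter.split_text("one two three four"))
# If I'm reading _merge_splits right: ['one two three', 'three four']
# ("three" is retained as overlap: popping "one" and "two" brings the running
# total to 5, which is no longer greater than chunk_overlap, so the loop stops).
```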