From ea331f31364266f7a6414dc76265553a52279b0a Mon Sep 17 00:00:00 2001
From: Kane Sweet <71854758+sweetkane@users.noreply.github.com>
Date: Mon, 18 Dec 2023 19:15:57 -0600
Subject: [PATCH] Fix token text splitter duplicates (#14848)

- **Description:**
  - Add a break case to `text_splitter.py::split_text_on_tokens()` to avoid an unwanted extra item at the end of the result.
  - Add a test case to enforce the behavior.
- **Issue:**
  - #14649
  - #5897
- **Dependencies:** n/a

---

**Quick illustration of change:**

```
text = "foo bar baz 123"
tokenizer = Tokenizer(
    chunk_overlap=3,
    tokens_per_chunk=7
)
output = split_text_on_tokens(text=text, tokenizer=tokenizer)
```

Output before change: `["foo bar", "bar baz", "baz 123", "123"]`
Output after change: `["foo bar", "bar baz", "baz 123"]`

---
 libs/langchain/langchain/text_splitter.py          |  2 ++
 .../tests/unit_tests/test_text_splitter.py         | 17 +++++++++++++++++
 2 files changed, 19 insertions(+)

diff --git a/libs/langchain/langchain/text_splitter.py b/libs/langchain/langchain/text_splitter.py
index c0ce1c74fac..efb55c1b984 100644
--- a/libs/langchain/langchain/text_splitter.py
+++ b/libs/langchain/langchain/text_splitter.py
@@ -670,6 +670,8 @@ def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
     chunk_ids = input_ids[start_idx:cur_idx]
     while start_idx < len(input_ids):
         splits.append(tokenizer.decode(chunk_ids))
+        if cur_idx == len(input_ids):
+            break
         start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
         cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
         chunk_ids = input_ids[start_idx:cur_idx]
diff --git a/libs/langchain/tests/unit_tests/test_text_splitter.py b/libs/langchain/tests/unit_tests/test_text_splitter.py
index f09366f1539..2f9cf2ac600 100644
--- a/libs/langchain/tests/unit_tests/test_text_splitter.py
+++ b/libs/langchain/tests/unit_tests/test_text_splitter.py
@@ -13,6 +13,8 @@ from langchain.text_splitter import (
     MarkdownHeaderTextSplitter,
     PythonCodeTextSplitter,
     RecursiveCharacterTextSplitter,
+    Tokenizer,
+    split_text_on_tokens,
 )

 FAKE_PYTHON_TEXT = """
@@ -1175,3 +1177,18 @@ def test_html_header_text_splitter(tmp_path: Path) -> None:
     docs_from_file = splitter.split_text_from_file(tmp_path / "doc.html")

     assert docs_from_file == expected
+
+
+def test_split_text_on_tokens() -> None:
+    """Test splitting by tokens per chunk."""
+    text = "foo bar baz 123"
+
+    tokenizer = Tokenizer(
+        chunk_overlap=3,
+        tokens_per_chunk=7,
+        decode=(lambda it: "".join(chr(i) for i in it)),
+        encode=(lambda it: [ord(c) for c in it]),
+    )
+    output = split_text_on_tokens(text=text, tokenizer=tokenizer)
+    expected_output = ["foo bar", "bar baz", "baz 123"]
+    assert output == expected_output
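
For readers who want to reproduce the before/after outputs shown in the illustration, here is a standalone, runnable sketch. It is not part of the patch: `Tokenizer` is a local stand-in with the same fields as `langchain.text_splitter.Tokenizer`, `split_before` and `split_after` are hypothetical names for the loop without and with the added break, and the character-level `encode`/`decode` callables are taken from the new unit test.

```
# Standalone illustration (not part of the patch above).
# `Tokenizer` is a local stand-in mirroring the fields used by
# langchain.text_splitter.Tokenizer; `split_before`/`split_after` are
# hypothetical names for the loop before and after the fix.
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class Tokenizer:
    chunk_overlap: int
    tokens_per_chunk: int
    decode: Callable[[List[int]], str]
    encode: Callable[[str], List[int]]


def split_before(text: str, tokenizer: Tokenizer) -> List[str]:
    # Old loop: when the current chunk already ends at the last token, the
    # loop still advances once more and emits the remaining tail tokens as
    # an extra, duplicated chunk.
    splits: List[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits


def split_after(text: str, tokenizer: Tokenizer) -> List[str]:
    # Patched loop: break as soon as the emitted chunk reaches the end of
    # the input, so no trailing duplicate chunk is produced.
    splits: List[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        if cur_idx == len(input_ids):
            break
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits


if __name__ == "__main__":
    # Character-level encode/decode, as in the new unit test.
    tokenizer = Tokenizer(
        chunk_overlap=3,
        tokens_per_chunk=7,
        decode=lambda it: "".join(chr(i) for i in it),
        encode=lambda it: [ord(c) for c in it],
    )
    text = "foo bar baz 123"
    print(split_before(text=text, tokenizer=tokenizer))  # ['foo bar', 'bar baz', 'baz 123', '123']
    print(split_after(text=text, tokenizer=tokenizer))   # ['foo bar', 'bar baz', 'baz 123']
```

The only difference between the two helpers is the two-line break condition that the patch adds to `split_text_on_tokens()`.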