mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 11:08:55 +00:00
fix text splitter (#375)
This commit is contained in:
parent
3474f39e21
commit
e7b625fe03
@ -1,9 +1,12 @@
|
|||||||
"""Functionality for splitting text."""
|
"""Functionality for splitting text."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Iterable, List
|
from typing import Any, Callable, Iterable, List
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
|
||||||
class TextSplitter(ABC):
|
class TextSplitter(ABC):
|
||||||
"""Interface for splitting text into chunks."""
|
"""Interface for splitting text into chunks."""
|
||||||
@ -37,13 +40,20 @@ class TextSplitter(ABC):
|
|||||||
current_doc: List[str] = []
|
current_doc: List[str] = []
|
||||||
total = 0
|
total = 0
|
||||||
for d in splits:
|
for d in splits:
|
||||||
if total >= self._chunk_size:
|
_len = self._length_function(d)
|
||||||
docs.append(self._separator.join(current_doc))
|
if total + _len >= self._chunk_size:
|
||||||
while total > self._chunk_overlap:
|
if total > self._chunk_size:
|
||||||
total -= self._length_function(current_doc[0])
|
logger.warning(
|
||||||
current_doc = current_doc[1:]
|
f"Created a chunk of size {total}, "
|
||||||
|
f"which is longer than the specified {self._chunk_size}"
|
||||||
|
)
|
||||||
|
if len(current_doc) > 0:
|
||||||
|
docs.append(self._separator.join(current_doc))
|
||||||
|
while total > self._chunk_overlap:
|
||||||
|
total -= self._length_function(current_doc[0])
|
||||||
|
current_doc = current_doc[1:]
|
||||||
current_doc.append(d)
|
current_doc.append(d)
|
||||||
total += self._length_function(d)
|
total += _len
|
||||||
docs.append(self._separator.join(current_doc))
|
docs.append(self._separator.join(current_doc))
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
|
@ -7,12 +7,21 @@ from langchain.text_splitter import CharacterTextSplitter
|
|||||||
def test_character_text_splitter() -> None:
|
def test_character_text_splitter() -> None:
|
||||||
"""Test splitting by character count."""
|
"""Test splitting by character count."""
|
||||||
text = "foo bar baz 123"
|
text = "foo bar baz 123"
|
||||||
splitter = CharacterTextSplitter(separator=" ", chunk_size=5, chunk_overlap=3)
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=7, chunk_overlap=3)
|
||||||
output = splitter.split_text(text)
|
output = splitter.split_text(text)
|
||||||
expected_output = ["foo bar", "bar baz", "baz 123"]
|
expected_output = ["foo bar", "bar baz", "baz 123"]
|
||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_character_text_splitter_long() -> None:
|
||||||
|
"""Test splitting by character count on long words."""
|
||||||
|
text = "foo bar baz a a"
|
||||||
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=3, chunk_overlap=1)
|
||||||
|
output = splitter.split_text(text)
|
||||||
|
expected_output = ["foo", "bar", "baz", "a a"]
|
||||||
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
def test_character_text_splitter_longer_words() -> None:
|
def test_character_text_splitter_longer_words() -> None:
|
||||||
"""Test splitting by characters when splits not found easily."""
|
"""Test splitting by characters when splits not found easily."""
|
||||||
text = "foo bar baz 123"
|
text = "foo bar baz 123"
|
||||||
|
Loading…
Reference in New Issue
Block a user