mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-03 10:12:33 +00:00
Harrison/fix splitting (#563)
fix issue where text splitting could possibly create empty docs
This commit is contained in:
parent
1192cc0767
commit
1511606799
@ -44,6 +44,14 @@ class TextSplitter(ABC):
|
|||||||
documents.append(Document(page_content=chunk, metadata=_metadatas[i]))
|
documents.append(Document(page_content=chunk, metadata=_metadatas[i]))
|
||||||
return documents
|
return documents
|
||||||
|
|
||||||
|
def _join_docs(self, docs: List[str], separator: str) -> Optional[str]:
|
||||||
|
text = separator.join(docs)
|
||||||
|
text = text.strip()
|
||||||
|
if text == "":
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
return text
|
||||||
|
|
||||||
def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
|
def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
|
||||||
# We now want to combine these smaller pieces into medium size
|
# We now want to combine these smaller pieces into medium size
|
||||||
# chunks to send to the LLM.
|
# chunks to send to the LLM.
|
||||||
@ -59,7 +67,9 @@ class TextSplitter(ABC):
|
|||||||
f"which is longer than the specified {self._chunk_size}"
|
f"which is longer than the specified {self._chunk_size}"
|
||||||
)
|
)
|
||||||
if len(current_doc) > 0:
|
if len(current_doc) > 0:
|
||||||
docs.append(separator.join(current_doc))
|
doc = self._join_docs(current_doc, separator)
|
||||||
|
if doc is not None:
|
||||||
|
docs.append(doc)
|
||||||
# Keep on popping if:
|
# Keep on popping if:
|
||||||
# - we have a larger chunk than in the chunk overlap
|
# - we have a larger chunk than in the chunk overlap
|
||||||
# - or if we still have any chunks and the length is long
|
# - or if we still have any chunks and the length is long
|
||||||
@ -70,7 +80,9 @@ class TextSplitter(ABC):
|
|||||||
current_doc = current_doc[1:]
|
current_doc = current_doc[1:]
|
||||||
current_doc.append(d)
|
current_doc.append(d)
|
||||||
total += _len
|
total += _len
|
||||||
docs.append(separator.join(current_doc))
|
doc = self._join_docs(current_doc, separator)
|
||||||
|
if doc is not None:
|
||||||
|
docs.append(doc)
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -17,6 +17,15 @@ def test_character_text_splitter() -> None:
|
|||||||
assert output == expected_output
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_character_text_splitter_empty_doc() -> None:
|
||||||
|
"""Test splitting by character count doesn't create empty documents."""
|
||||||
|
text = "foo bar"
|
||||||
|
splitter = CharacterTextSplitter(separator=" ", chunk_size=2, chunk_overlap=0)
|
||||||
|
output = splitter.split_text(text)
|
||||||
|
expected_output = ["foo", "bar"]
|
||||||
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
def test_character_text_splitter_long() -> None:
|
def test_character_text_splitter_long() -> None:
|
||||||
"""Test splitting by character count on long words."""
|
"""Test splitting by character count on long words."""
|
||||||
text = "foo bar baz a a"
|
text = "foo bar baz a a"
|
||||||
|
Loading…
Reference in New Issue
Block a user