text-splitters[minor], langchain[minor], community[patch], templates, docs: langchain-text-splitters 0.0.1 (#18346)
libs/text-splitters/langchain_text_splitters/html.py (new file, 160 lines)
@@ -0,0 +1,160 @@
from __future__ import annotations

import pathlib
from io import BytesIO, StringIO
from typing import Any, Dict, List, Tuple, TypedDict

import requests
from langchain_core.documents import Document


class ElementType(TypedDict):
    """Element type as typed dict."""

    url: str
    xpath: str
    content: str
    metadata: Dict[str, str]


class HTMLHeaderTextSplitter:
    """
    Splitting HTML files based on specified headers.
    Requires lxml package.
    """

    def __init__(
        self,
        headers_to_split_on: List[Tuple[str, str]],
        return_each_element: bool = False,
    ):
        """Create a new HTMLHeaderTextSplitter.

        Args:
            headers_to_split_on: list of tuples of headers we want to track mapped to
                (arbitrary) keys for metadata. Allowed header values: h1, h2, h3, h4,
                h5, h6, e.g. [("h1", "Header 1"), ("h2", "Header 2")].
            return_each_element: Return each element w/ associated headers.
        """
        # Output element-by-element or aggregated into chunks w/ common headers
        self.return_each_element = return_each_element
        self.headers_to_split_on = sorted(headers_to_split_on)
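
    # Illustrative usage (a sketch, not part of the original file): map the
    # headers to track onto metadata keys, then split. Each resulting Document
    # is expected to carry the text of its enclosing headers as metadata.
    #
    #   splitter = HTMLHeaderTextSplitter(
    #       headers_to_split_on=[("h1", "Header 1"), ("h2", "Header 2")]
    #   )
    #   docs = splitter.split_text("<h1>Intro</h1><p>Hello world.</p>")
    #   # docs[i].metadata should include {"Header 1": "Intro"} for content
    #   # that appears under the <h1>.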

    def aggregate_elements_to_chunks(
        self, elements: List[ElementType]
    ) -> List[Document]:
        """Combine elements with common metadata into chunks.

        Args:
            elements: HTML element content with associated identifying info and metadata
        """
        aggregated_chunks: List[ElementType] = []

        for element in elements:
            if (
                aggregated_chunks
                and aggregated_chunks[-1]["metadata"] == element["metadata"]
            ):
                # If the last element in the aggregated list
                # has the same metadata as the current element,
                # append the current content to the last element's content
                aggregated_chunks[-1]["content"] += "  \n" + element["content"]
            else:
                # Otherwise, append the current element to the aggregated list
                aggregated_chunks.append(element)

        return [
            Document(page_content=chunk["content"], metadata=chunk["metadata"])
            for chunk in aggregated_chunks
        ]
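
    # A sketch of the merge behavior on hypothetical elements: consecutive
    # elements with identical metadata collapse into a single Document,
    # joined with a markdown-style line break.
    #
    #   a = ElementType(url="", xpath="", content="one", metadata={"k": "A"})
    #   b = ElementType(url="", xpath="", content="two", metadata={"k": "A"})
    #   c = ElementType(url="", xpath="", content="three", metadata={"k": "B"})
    #   # aggregate_elements_to_chunks([a, b, c]) ->
    #   #   [Document(page_content="one  \ntwo", metadata={"k": "A"}),
    #   #    Document(page_content="three", metadata={"k": "B"})]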

    def split_text_from_url(self, url: str) -> List[Document]:
        """Split HTML from web URL.

        Args:
            url: web URL
        """
        r = requests.get(url)
        return self.split_text_from_file(BytesIO(r.content))

    def split_text(self, text: str) -> List[Document]:
        """Split HTML text string.

        Args:
            text: HTML text
        """
        return self.split_text_from_file(StringIO(text))

    def split_text_from_file(self, file: Any) -> List[Document]:
        """Split HTML file.

        Args:
            file: HTML file path or file-like object
        """
        try:
            from lxml import etree
        except ImportError as e:
            raise ImportError(
                "Unable to import lxml, please install with `pip install lxml`."
            ) from e
        # use lxml library to parse html document and return xml ElementTree
        # Explicitly encoding in utf-8 allows non-English
        # html files to be processed without garbled characters
        parser = etree.HTMLParser(encoding="utf-8")
        tree = etree.parse(file, parser)

        # document transformation for "structure-aware" chunking is handled with xsl.
        # see comments in html_chunks_with_headers.xslt for more detailed information.
        xslt_path = pathlib.Path(__file__).parent / "xsl/html_chunks_with_headers.xslt"
        # str() keeps this compatible with older lxml versions that do not
        # accept pathlib.Path objects directly
        xslt_tree = etree.parse(str(xslt_path))
        transform = etree.XSLT(xslt_tree)
        result = transform(tree)
        result_dom = etree.fromstring(str(result))

        # create filter and mapping for header metadata
        header_filter = [header[0] for header in self.headers_to_split_on]
        header_mapping = dict(self.headers_to_split_on)

        # map xhtml namespace prefix
        ns_map = {"h": "http://www.w3.org/1999/xhtml"}

        # build list of elements from DOM
        elements = []
        for element in result_dom.findall("*//*", ns_map):
            if element.findall("*[@class='headers']") or element.findall(
                "*[@class='chunk']"
            ):
                elements.append(
                    ElementType(
                        url=file,
                        xpath="".join(
                            [
                                node.text or ""
                                for node in element.findall("*[@class='xpath']", ns_map)
                            ]
                        ),
                        content="".join(
                            [
                                node.text or ""
                                for node in element.findall("*[@class='chunk']", ns_map)
                            ]
                        ),
                        metadata={
                            # Add text of specified headers to metadata using header
                            # mapping.
                            header_mapping[node.tag]: node.text or ""
                            for node in filter(
                                lambda x: x.tag in header_filter,
                                element.findall("*[@class='headers']/*", ns_map),
                            )
                        },
                    )
                )

        if not self.return_each_element:
            return self.aggregate_elements_to_chunks(elements)
        else:
            return [
                Document(page_content=chunk["content"], metadata=chunk["metadata"])
                for chunk in elements
            ]
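
A minimal end-to-end sketch of how the new splitter is used (the HTML string and variable names are illustrative, not from this diff; assumes the langchain-text-splitters package is installed):

    from langchain_text_splitters.html import HTMLHeaderTextSplitter

    html_string = (
        "<html><body>"
        "<h1>Foo</h1><p>Intro about Foo.</p>"
        "<h2>Bar</h2><p>Details about Bar.</p>"
        "</body></html>"
    )

    splitter = HTMLHeaderTextSplitter(
        headers_to_split_on=[("h1", "Header 1"), ("h2", "Header 2")]
    )
    for doc in splitter.split_text(html_string):
        print(doc.metadata, doc.page_content)

Content under <h2>Bar</h2> should come back with metadata along the lines of {"Header 1": "Foo", "Header 2": "Bar"}; passing return_each_element=True yields one Document per element instead of chunks aggregated by common headers.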