Fix test_parser_with_table

This commit is contained in:
Philippe Prados 2025-02-26 13:49:46 +01:00
parent be47099747
commit 898e2a5b51

View File

@ -247,55 +247,9 @@ def test_parser_with_table(
mode: str,
extract_tables: str,
) -> None:
from PIL.Image import Image
from langchain_community.document_loaders.parsers.images import BaseImageBlobParser
def _std_assert_with_parser(parser: BaseBlobParser) -> None:
    """Check that parsed output contains tables only when they are requested.

    Parses ``LAYOUT_PARSER_PAPER_PDF`` with ``parser`` and searches the
    ``page_content`` of every resulting document for tables written in the
    format named by ``extract_tables`` (taken from the enclosing test).

    Args:
        parser (BaseBlobParser): The parser to test.
    """
    # Materialize every document before scanning any of them.
    pages = list(parser.lazy_parse(Blob.from_path(LAYOUT_PARSER_PAPER_PDF)))

    # The requested format does not change between documents, so the regex
    # is chosen a single time. Unknown or falsy formats leave it unset.
    table_regex = None
    if extract_tables == "markdown":
        # Header row, separator row (|---|---|), then one or more body rows.
        table_regex = (
            r"(?s)("
            r"(?:(?:[^\n]*\|)\n)"
            r"(?:\|(?:\s?:?---*:?\s?\|)+)\n"
            r"(?:(?:[^\n]*\|)\n)+"
            r")"
        )
    elif extract_tables == "html":
        table_regex = r"(?s)(<table[^>]*>(?:.*?)<\/table>)"
    elif extract_tables == "csv":
        # Two or more consecutive lines, each with at least three fields;
        # a field is either double-quoted ("" escapes a quote) or bare.
        table_regex = (
            r"((?:(?:"
            r'(?:"(?:[^"]*(?:""[^"]*)*)"'
            r"|[^\n,]*),){2,}"
            r"(?:"
            r'(?:"(?:[^"]*(?:""[^"]*)*)"'
            r"|[^\n]*))\n){2,})"
        )

    found = []
    if table_regex:
        for page in pages:
            found.extend(re.findall(table_regex, page.page_content))

    if not extract_tables:
        assert len(found) == 0
    else:
        assert len(found) > 0
class EmptyImageBlobParser(BaseImageBlobParser):
    """Image parser stub that yields the same markdown link for any image."""

    def _analyze_image(self, img: Image) -> str:
        """Return a fixed markdown image reference; ``img`` is not inspected."""
        placeholder = "![image](.)"
        return placeholder
parser_class = getattr(pdf_parsers, parser_factory)
parser = parser_class(
@ -304,4 +258,46 @@ def test_parser_with_table(
images_parser=EmptyImageBlobParser(),
**params,
)
_std_assert_with_parser(parser)
_std_assert_table_with_parser(extract_tables, parser)
def _std_assert_table_with_parser(extract_tables: str, parser: BaseBlobParser) -> None:
    """Verify that the parser emits tables in the requested format.

    Parses ``LAYOUT_PARSER_PAPER_PDF`` with ``parser`` and searches the
    ``page_content`` of every resulting document for tables written in the
    format named by ``extract_tables``.

    Args:
        extract_tables (str): Table format to look for: ``"markdown"``,
            ``"html"`` or ``"csv"``. Any other value selects no pattern, so
            no table can be found. A falsy value means no table is expected.
        parser (BaseBlobParser): The parser to test.

    Raises:
        AssertionError: If ``extract_tables`` is truthy and no table in that
            format is found in the parsed documents.
    """
    blob = Blob.from_path(LAYOUT_PARSER_PAPER_PDF)
    docs = list(parser.lazy_parse(blob))

    # The format is fixed for the whole call, so the pattern is selected once
    # here rather than once per document.
    if extract_tables == "markdown":
        # Header row, separator row (|---|---|), then one or more body rows.
        pattern = (
            r"(?s)("
            r"(?:(?:[^\n]*\|)\n)"
            r"(?:\|(?:\s?:?---*:?\s?\|)+)\n"
            r"(?:(?:[^\n]*\|)\n)+"
            r")"
        )
    elif extract_tables == "html":
        pattern = r"(?s)(<table[^>]*>(?:.*?)<\/table>)"
    elif extract_tables == "csv":
        # Two or more consecutive lines, each with at least three fields;
        # a field is either double-quoted ("" escapes a quote) or bare.
        pattern = (
            r"((?:(?:"
            r'(?:"(?:[^"]*(?:""[^"]*)*)"'
            r"|[^\n,]*),){2,}"
            r"(?:"
            r'(?:"(?:[^"]*(?:""[^"]*)*)"'
            r"|[^\n]*))\n){2,})"
        )
    else:
        pattern = None

    tables = []
    if pattern:
        table_regex = re.compile(pattern)
        # Each pattern has exactly one capturing group, so findall() returns
        # the text of every matched table.
        for doc in docs:
            tables.extend(table_regex.findall(doc.page_content))

    if extract_tables:
        assert tables, f"no {extract_tables!r} table found in {len(docs)} document(s)"
    else:
        # NOTE(review): with a falsy extract_tables no pattern is selected, so
        # this check cannot fail; it does not prove the output is table-free.
        assert not tables, f"unexpected tables found: {len(tables)}"