refactor(langchain): improve type annotations in url_playwright and its test

This commit is contained in:
Youngwook Kim 2023-08-09 15:56:24 +09:00
parent 04fcd2d2e0
commit 429de77b3b
2 changed files with 34 additions and 16 deletions

View File

@ -2,11 +2,16 @@
""" """
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional from typing import TYPE_CHECKING, List, Optional
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader from langchain.document_loaders.base import BaseLoader
if TYPE_CHECKING:
from playwright.async_api import AsyncBrowser, AsyncPage, AsyncResponse
from playwright.sync_api import Browser, Page, Response
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -18,7 +23,7 @@ class PlaywrightEvaluator(ABC):
""" """
@abstractmethod @abstractmethod
def evaluate(self, page, browser, response): def evaluate(self, page: "Page", browser: "Browser", response: "Response") -> str:
"""Synchronously process the page and return the resulting text. """Synchronously process the page and return the resulting text.
Args: Args:
@ -32,7 +37,9 @@ class PlaywrightEvaluator(ABC):
pass pass
@abstractmethod @abstractmethod
async def evaluate_async(self, page, browser, response): async def evaluate_async(
self, page: "AsyncPage", browser: "AsyncBrowser", response: "AsyncResponse"
) -> str:
"""Asynchronously process the page and return the resulting text. """Asynchronously process the page and return the resulting text.
Args: Args:
@ -50,7 +57,7 @@ class UnstructuredHtmlEvaluator(PlaywrightEvaluator):
"""Evaluates the page HTML content using the `unstructured` library.""" """Evaluates the page HTML content using the `unstructured` library."""
def __init__(self, remove_selectors: Optional[List[str]] = None): def __init__(self, remove_selectors: Optional[List[str]] = None):
"""Initialize UnstructuredHtmlEvaluator and check if `unstructured` package is installed.""" """Initialize UnstructuredHtmlEvaluator."""
try: try:
import unstructured # noqa:F401 import unstructured # noqa:F401
except ImportError: except ImportError:
@ -61,8 +68,8 @@ class UnstructuredHtmlEvaluator(PlaywrightEvaluator):
self.remove_selectors = remove_selectors self.remove_selectors = remove_selectors
def evaluate(self, page, browser, response): def evaluate(self, page: "Page", browser: "Browser", response: "Response") -> str:
"""Synchronously process the HTML content of the page and return a text string.""" """Synchronously process the HTML content of the page."""
from unstructured.partition.html import partition_html from unstructured.partition.html import partition_html
for selector in self.remove_selectors or []: for selector in self.remove_selectors or []:
@ -75,8 +82,10 @@ class UnstructuredHtmlEvaluator(PlaywrightEvaluator):
elements = partition_html(text=page_source) elements = partition_html(text=page_source)
return "\n\n".join([str(el) for el in elements]) return "\n\n".join([str(el) for el in elements])
async def evaluate_async(self, page, browser, response): async def evaluate_async(
"""Asynchronously process the HTML content of the page and return a text string.""" self, page: "AsyncPage", browser: "AsyncBrowser", response: "AsyncResponse"
) -> str:
"""Asynchronously process the HTML content of the page."""
from unstructured.partition.html import partition_html from unstructured.partition.html import partition_html
for selector in self.remove_selectors or []: for selector in self.remove_selectors or []:
@ -126,7 +135,7 @@ class PlaywrightURLLoader(BaseLoader):
"`remove_selectors` and `evaluator` cannot be both not None" "`remove_selectors` and `evaluator` cannot be both not None"
) )
# Use the provided evaluator, if any, otherwise, use the default UnstructuredHtmlEvaluator. # Use the provided evaluator, if any, otherwise, use the default.
self.evaluator = evaluator or UnstructuredHtmlEvaluator(remove_selectors) self.evaluator = evaluator or UnstructuredHtmlEvaluator(remove_selectors)
def load(self) -> List[Document]: def load(self) -> List[Document]:

View File

@ -1,16 +1,25 @@
"""Tests for the Playwright URL loader""" """Tests for the Playwright URL loader"""
from typing import TYPE_CHECKING
import pytest import pytest
from langchain.document_loaders import PlaywrightURLLoader from langchain.document_loaders import PlaywrightURLLoader
from langchain.document_loaders.url_playwright import PlaywrightEvaluator
if TYPE_CHECKING:
from playwright.async_api import AsyncBrowser, AsyncPage, AsyncResponse
from playwright.sync_api import Browser, Page, Response
class TestEvaluator(PageEvaluator): class TestEvaluator(PlaywrightEvaluator):
"""A simple evaluator for testing purposes.""" """A simple evaluator for testing purposes."""
def evaluate(self, page, browser, response): def evaluate(self, page: "Page", browser: "Browser", response: "Response") -> str:
return "test" return "test"
async def evaluate_async(self, page, browser, response): async def evaluate_async(
self, page: "AsyncPage", browser: "AsyncBrowser", response: "AsyncResponse"
) -> str:
return "test" return "test"
@ -56,13 +65,13 @@ def test_playwright_url_loader_with_custom_evaluator() -> None:
urls = ["https://www.youtube.com/watch?v=dQw4w9WgXcQ"] urls = ["https://www.youtube.com/watch?v=dQw4w9WgXcQ"]
loader = PlaywrightURLLoader( loader = PlaywrightURLLoader(
urls=urls, urls=urls,
page_evaluator=TestEvaluator(), evaluator=TestEvaluator(),
continue_on_failure=False, continue_on_failure=False,
headless=True, headless=True,
) )
docs = loader.load() docs = loader.load()
assert len(docs) == 1 assert len(docs) == 1
assert docs[0].page_content == "test-" assert docs[0].page_content == "test"
@pytest.mark.asyncio @pytest.mark.asyncio
@ -71,10 +80,10 @@ async def test_playwright_async_url_loader_with_custom_evaluator() -> None:
urls = ["https://www.youtube.com/watch?v=dQw4w9WgXcQ"] urls = ["https://www.youtube.com/watch?v=dQw4w9WgXcQ"]
loader = PlaywrightURLLoader( loader = PlaywrightURLLoader(
urls=urls, urls=urls,
page_evaluator=TestEvaluator(), evaluator=TestEvaluator(),
continue_on_failure=False, continue_on_failure=False,
headless=True, headless=True,
) )
docs = await loader.aload() docs = await loader.aload()
assert len(docs) == 2 assert len(docs) == 1
assert docs[0].page_content == "test" assert docs[0].page_content == "test"