community[patch]: add NotebookLoader unit test (#17721)

Thank you for contributing to LangChain!

- **Description:** Added unit tests for NotebookLoader. Linked PR:
https://github.com/langchain-ai/langchain/pull/17614
- **Issue:**
[#17614](https://github.com/langchain-ai/langchain/pull/17614)
- **Twitter handle:** @paulodoestech
- [x] Pass lint and test: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified to check that you're
passing lint and testing. See contribution guidelines for more
information on how to write/run tests, lint, etc:
https://python.langchain.com/docs/contributing/
- [x] Add tests and docs: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, hwchase17.

---------

Co-authored-by: lachiewalker <lachiewalker1@hotmail.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Paulo Nascimento 2024-03-28 17:27:46 -07:00 committed by GitHub
parent 4c3a67122f
commit 44a3484503
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 109 additions and 21 deletions

View File

@ -25,7 +25,11 @@ def concatenate_cells(
""" """
cell_type = cell["cell_type"] cell_type = cell["cell_type"]
source = cell["source"] source = cell["source"]
output = cell["outputs"] if include_outputs:
try:
output = cell["outputs"]
except KeyError:
pass
if include_outputs and cell_type == "code" and output: if include_outputs and cell_type == "code" and output:
if "ename" in output[0].keys(): if "ename" in output[0].keys():
@ -58,14 +62,13 @@ def concatenate_cells(
def remove_newlines(x: Any) -> Any: def remove_newlines(x: Any) -> Any:
"""Recursively remove newlines, no matter the data structure they are stored in.""" """Recursively remove newlines, no matter the data structure they are stored in."""
import pandas as pd
if isinstance(x, str): if isinstance(x, str):
return x.replace("\n", "") return x.replace("\n", "")
elif isinstance(x, list): elif isinstance(x, list):
return [remove_newlines(elem) for elem in x] return [remove_newlines(elem) for elem in x]
elif isinstance(x, pd.DataFrame): elif isinstance(x, dict):
return x.applymap(remove_newlines) return {k: remove_newlines(v) for (k, v) in x.items()}
else: else:
return x return x
@ -104,29 +107,29 @@ class NotebookLoader(BaseLoader):
self, self,
) -> List[Document]: ) -> List[Document]:
"""Load documents.""" """Load documents."""
try:
import pandas as pd
except ImportError:
raise ImportError(
"pandas is needed for Notebook Loader, "
"please install with `pip install pandas`"
)
p = Path(self.file_path) p = Path(self.file_path)
with open(p, encoding="utf8") as f: with open(p, encoding="utf8") as f:
d = json.load(f) d = json.load(f)
data = pd.json_normalize(d["cells"]) filtered_data = [
filtered_data = data[["cell_type", "source", "outputs"]] {k: v for (k, v) in cell.items() if k in ["cell_type", "source", "outputs"]}
if self.remove_newline: for cell in d["cells"]
filtered_data = filtered_data.applymap(remove_newlines) ]
text = filtered_data.apply( if self.remove_newline:
lambda x: concatenate_cells( filtered_data = list(map(remove_newlines, filtered_data))
x, self.include_outputs, self.max_output_length, self.traceback
), text = "".join(
axis=1, list(
).str.cat(sep=" ") map(
lambda x: concatenate_cells(
x, self.include_outputs, self.max_output_length, self.traceback
),
filtered_data,
)
)
)
metadata = {"source": str(p)} metadata = {"source": str(p)}

View File

@ -0,0 +1,85 @@
import json
from pytest_mock import MockerFixture
from langchain_community.document_loaders.notebook import NotebookLoader
def test_initialization() -> None:
    """Check that the path given to the constructor is exposed as ``file_path``."""
    notebook_path = "./testfile.ipynb"
    loader = NotebookLoader(path=notebook_path)
    assert loader.file_path == notebook_path
def test_load_no_outputs(mocker: MockerFixture) -> None:
    """Load a single-cell notebook without passing ``include_outputs``.

    The fixture cell carries an ``outputs`` entry, but the expected page
    content consists only of the cell type and the cell source.
    """
    # The one notebook cell used as the fixture.
    cell = {
        "cell_type": "markdown",
        "metadata": {},
        "source": ["# Test notebook\n", "This is a test notebook."],
        "outputs": [
            {
                "name": "stdout",
                "output_type": "stream",
                "text": ["Hello World!\n"],
            }
        ],
    }
    mock_notebook_content = {"cells": [cell]}
    mocked_cell_type = cell["cell_type"]
    mocked_source = cell["source"]

    # Serialise the notebook so the patched ``open`` has data to hand back.
    serialized_notebook = json.dumps(mock_notebook_content)

    # Replace both file access and JSON decoding; no real file is read.
    mocker.patch("builtins.open", mocker.mock_open(read_data=serialized_notebook))
    mocker.patch("json.load", return_value=mock_notebook_content)

    docs = NotebookLoader(path="./testfile.ipynb").load()

    assert len(docs) == 1
    expected_content = f"'{mocked_cell_type}' cell: '{mocked_source}'\n\n"
    assert docs[0].page_content == expected_content
    assert docs[0].metadata == {"source": "testfile.ipynb"}
def test_load_with_outputs(mocker: MockerFixture) -> None:
    """Load a single code cell with ``include_outputs=True``.

    The expected page content holds the cell type and source, followed by a
    ``with output:`` line built from the text of the first output entry.
    """
    # The one notebook cell used as the fixture.
    cell: dict = {
        "cell_type": "code",
        "metadata": {},
        "source": ["# Test notebook\n", "This is a test notebook."],
        "outputs": [
            {
                "name": "stdout",
                "output_type": "stream",
                "text": ["Hello World!\n"],
            }
        ],
    }
    mock_notebook_content: dict = {"cells": [cell]}
    mocked_cell_type = cell["cell_type"]
    mocked_source = cell["source"]
    mocked_output = cell["outputs"][0]["text"]

    # Serialise the notebook so the patched ``open`` has data to hand back.
    serialized_notebook = json.dumps(mock_notebook_content)

    # Replace both file access and JSON decoding; no real file is read.
    mocker.patch("builtins.open", mocker.mock_open(read_data=serialized_notebook))
    mocker.patch("json.load", return_value=mock_notebook_content)

    docs = NotebookLoader(path="./testfile.ipynb", include_outputs=True).load()

    assert len(docs) == 1
    first_line = f"'{mocked_cell_type}' cell: '{mocked_source}'\n"
    second_line = f" with output: '{mocked_output}'\n\n"
    expected_content = first_line + second_line
    assert docs[0].page_content == expected_content
    assert docs[0].metadata == {"source": "testfile.ipynb"}