tensorflow_datasets document loader (#8721)

This PR adds the `tensorflow_datasets` document loader
Leonid Ganeline
2023-08-08 12:19:28 -07:00
committed by GitHub
parent fad26e79a3
commit 33a2f58fbf
10 changed files with 741 additions and 2 deletions
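
In brief, the loader wraps a new `TensorflowDatasets` utility: you name a dataset and a split, cap the number of documents, and supply a function that turns one raw TFDS sample into a `Document`. A minimal usage sketch (not part of the diff; it assumes `tensorflow` and `tensorflow-datasets` are installed and uses the `mlqa/en` dataset from the tests below):

import tensorflow as tf
from langchain.document_loaders import TensorflowDatasetLoader
from langchain.schema import Document

def decode_to_str(item: tf.Tensor) -> str:
    # TFDS yields text fields as scalar tf.string tensors; decode to a Python str
    return item.numpy().decode("utf-8")

def mlqaen_example_to_document(example: dict) -> Document:
    return Document(
        page_content=decode_to_str(example["context"]),
        metadata={"title": decode_to_str(example["title"])},
    )

loader = TensorflowDatasetLoader(
    dataset_name="mlqa/en",
    split_name="test",
    load_max_docs=10,
    sample_to_document_function=mlqaen_example_to_document,
)
docs = loader.load()  # at most 10 Documents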

View File

@@ -147,6 +147,7 @@ from langchain.document_loaders.telegram import (
)
from langchain.document_loaders.tencent_cos_directory import TencentCOSDirectoryLoader
from langchain.document_loaders.tencent_cos_file import TencentCOSFileLoader
from langchain.document_loaders.tensorflow_datasets import TensorflowDatasetLoader
from langchain.document_loaders.text import TextLoader
from langchain.document_loaders.tomarkdown import ToMarkdownLoader
from langchain.document_loaders.toml import TomlLoader
@@ -299,6 +300,7 @@ __all__ = [
"TelegramChatApiLoader",
"TelegramChatFileLoader",
"TelegramChatLoader",
"TensorflowDatasetLoader",
"TencentCOSDirectoryLoader",
"TencentCOSFileLoader",
"TextLoader",

View File

@@ -8,7 +8,6 @@ from langchain.utilities.arxiv import ArxivAPIWrapper
class ArxivLoader(BaseLoader):
    """Loads a query result from arxiv.org into a list of Documents.

    The loader converts the original PDF format into plain text.
    """

View File

@@ -0,0 +1,79 @@
from typing import Callable, Dict, Iterator, List, Optional

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from langchain.utilities.tensorflow_datasets import TensorflowDatasets


class TensorflowDatasetLoader(BaseLoader):
    """Loads from TensorFlow Datasets into a list of Documents.

    Attributes:
        dataset_name: the name of the dataset to load
        split_name: the name of the split to load.
        load_max_docs: a limit to the number of loaded documents. Defaults to 100.
        sample_to_document_function: a function that converts a dataset sample
            into a Document

    Example:
        .. code-block:: python

            from langchain.document_loaders import TensorflowDatasetLoader

            def mlqaen_example_to_document(example: dict) -> Document:
                return Document(
                    page_content=decode_to_str(example["context"]),
                    metadata={
                        "id": decode_to_str(example["id"]),
                        "title": decode_to_str(example["title"]),
                        "question": decode_to_str(example["question"]),
                        "answer": decode_to_str(example["answers"]["text"][0]),
                    },
                )

            tsds_client = TensorflowDatasetLoader(
                dataset_name="mlqa/en",
                split_name="test",
                load_max_docs=100,
                sample_to_document_function=mlqaen_example_to_document,
            )
    """

    def __init__(
        self,
        dataset_name: str,
        split_name: str,
        load_max_docs: Optional[int] = 100,
        sample_to_document_function: Optional[Callable[[Dict], Document]] = None,
    ):
        """Initialize the TensorflowDatasetLoader.

        Args:
            dataset_name: the name of the dataset to load
            split_name: the name of the split to load.
            load_max_docs: a limit to the number of loaded documents.
                Defaults to 100.
            sample_to_document_function: a function that converts a dataset sample
                into a Document.
        """
        self.dataset_name: str = dataset_name
        self.split_name: str = split_name
        self.load_max_docs = load_max_docs
        """The maximum number of documents to load."""
        self.sample_to_document_function: Optional[
            Callable[[Dict], Document]
        ] = sample_to_document_function
        """Custom function that transforms a dataset sample into a Document."""
        self._tfds_client = TensorflowDatasets(
            dataset_name=self.dataset_name,
            split_name=self.split_name,
            load_max_docs=self.load_max_docs,
            sample_to_document_function=self.sample_to_document_function,
        )

    def lazy_load(self) -> Iterator[Document]:
        yield from self._tfds_client.lazy_load()

    def load(self) -> List[Document]:
        return list(self.lazy_load())
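
Since `load` just materializes `lazy_load`, callers that need only a few documents can stream them instead of building the full list; a small sketch, assuming the `tsds_client` constructed in the docstring example above:

# Stream documents one at a time; nothing is materialized up front
for doc in tsds_client.lazy_load():
    print(doc.metadata["title"])
    break  # stop after the first document, for illustration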

View File

@@ -29,6 +29,7 @@ from langchain.utilities.searx_search import SearxSearchWrapper
from langchain.utilities.serpapi import SerpAPIWrapper
from langchain.utilities.spark_sql import SparkSQL
from langchain.utilities.sql_database import SQLDatabase
from langchain.utilities.tensorflow_datasets import TensorflowDatasets
from langchain.utilities.twilio import TwilioAPIWrapper
from langchain.utilities.wikipedia import WikipediaAPIWrapper
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
@@ -62,6 +63,7 @@ __all__ = [
"SearxSearchWrapper",
"SerpAPIWrapper",
"SparkSQL",
"TensorflowDatasets",
"TextRequestsWrapper",
"TextRequestsWrapper",
"TwilioAPIWrapper",

View File

@@ -21,7 +21,7 @@ class ArxivAPIWrapper(BaseModel):
    It limits the Document content by doc_content_chars_max.
    Set doc_content_chars_max=None if you don't want to limit the content size.

    Attributes:
        top_k_results: number of the top-scored documents used for the arxiv tool
        ARXIV_MAX_QUERY_LENGTH: the cut limit on the query used for the arxiv tool.
        load_max_docs: a limit to the number of loaded documents

View File

@@ -0,0 +1,111 @@
import logging
from typing import Any, Callable, Dict, Iterator, List, Optional

from pydantic import BaseModel, root_validator

from langchain.schema import Document

logger = logging.getLogger(__name__)


class TensorflowDatasets(BaseModel):
    """Access to TensorFlow Datasets.

    The current implementation works only with datasets that fit in memory.

    `TensorFlow Datasets` is a collection of datasets ready to use, with TensorFlow
    or other Python ML frameworks, such as Jax. All datasets are exposed
    as `tf.data.Datasets`.
    To get started see the Guide: https://www.tensorflow.org/datasets/overview and
    the list of datasets: https://www.tensorflow.org/datasets/catalog/overview#all_datasets

    You have to provide the sample_to_document_function: a function that converts
    a sample from the dataset-specific format to a Document.

    Attributes:
        dataset_name: the name of the dataset to load
        split_name: the name of the split to load. Defaults to "train".
        load_max_docs: a limit to the number of loaded documents. Defaults to 100.
        sample_to_document_function: a function that converts a dataset sample
            to a Document

    Example:
        .. code-block:: python

            from langchain.utilities import TensorflowDatasets

            def mlqaen_example_to_document(example: dict) -> Document:
                return Document(
                    page_content=decode_to_str(example["context"]),
                    metadata={
                        "id": decode_to_str(example["id"]),
                        "title": decode_to_str(example["title"]),
                        "question": decode_to_str(example["question"]),
                        "answer": decode_to_str(example["answers"]["text"][0]),
                    },
                )

            tsds_client = TensorflowDatasets(
                dataset_name="mlqa/en",
                split_name="train",
                load_max_docs=MAX_DOCS,
                sample_to_document_function=mlqaen_example_to_document,
            )
    """

    dataset_name: str = ""
    split_name: str = "train"
    load_max_docs: int = 100
    sample_to_document_function: Optional[Callable[[Dict], Document]] = None
    dataset: Any  #: :meta private:

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in environment."""
        try:
            import tensorflow  # noqa: F401
        except ImportError:
            raise ImportError(
                "Could not import tensorflow python package. "
                "Please install it with `pip install tensorflow`."
            )
        try:
            import tensorflow_datasets
        except ImportError:
            raise ImportError(
                "Could not import tensorflow_datasets python package. "
                "Please install it with `pip install tensorflow-datasets`."
            )
        if values["sample_to_document_function"] is None:
            raise ValueError(
                "sample_to_document_function is None. "
                "Please provide a function that converts a dataset sample to"
                " a Document."
            )
        values["dataset"] = tensorflow_datasets.load(
            values["dataset_name"], split=values["split_name"]
        )
        return values

    def lazy_load(self) -> Iterator[Document]:
        """Download a selected dataset lazily.

        Returns: an iterator of Documents.
        """
        return (
            self.sample_to_document_function(s)
            for s in self.dataset.take(self.load_max_docs)
            if self.sample_to_document_function is not None
        )

    def load(self) -> List[Document]:
        """Download a selected dataset.

        Returns: a list of Documents.
        """
        return list(self.lazy_load())
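
Note that `validate_environment` fetches the dataset eagerly, so merely constructing the client already triggers `tensorflow_datasets.load` (and a download on first use). A minimal sketch of using the utility directly, assuming the `mlqaen_example_to_document` converter defined above:

from langchain.utilities.tensorflow_datasets import TensorflowDatasets

client = TensorflowDatasets(
    dataset_name="mlqa/en",
    split_name="test",
    load_max_docs=10,
    sample_to_document_function=mlqaen_example_to_document,
)  # the dataset is loaded here, inside the root_validator
docs = client.load()  # converts up to 10 samples into Documents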

View File

@@ -0,0 +1,105 @@
"""Integration tests for the TensorFlow Dataset Loader."""
import pytest
from pydantic.error_wrappers import ValidationError
from langchain.document_loaders.tensorflow_datasets import TensorflowDatasetLoader
from langchain.schema.document import Document
# adding tensorflow and tensorflow_datasets to pyproject.toml is not working
# these tests can be run in isolation only
tensorflow = pytest.importorskip("tensorflow")
tensorflow_datasets = pytest.importorskip("tensorflow_datasets")
# placed here after checking for tensorflow package installation
import tensorflow as tf # noqa: E402
def decode_to_str(item: tf.Tensor) -> str:
return item.numpy().decode("utf-8")
def mlqaen_example_to_document(example: dict) -> Document:
return Document(
page_content=decode_to_str(example["context"]),
metadata={
"id": decode_to_str(example["id"]),
"title": decode_to_str(example["title"]),
"question": decode_to_str(example["question"]),
"answer": decode_to_str(example["answers"]["text"][0]),
},
)
MAX_DOCS = 10
@pytest.fixture
def tfds_client() -> TensorflowDatasetLoader:
return TensorflowDatasetLoader(
dataset_name="mlqa/en",
split_name="test",
load_max_docs=MAX_DOCS,
sample_to_document_function=mlqaen_example_to_document,
)
def test_load_success(tfds_client: TensorflowDatasetLoader) -> None:
"""Test that returns the correct answer"""
output = tfds_client.load()
assert isinstance(output, list)
assert len(output) == MAX_DOCS
assert isinstance(output[0], Document)
assert len(output[0].page_content) > 0
assert isinstance(output[0].page_content, str)
assert isinstance(output[0].metadata, dict)
def test_lazy_load_success(tfds_client: TensorflowDatasetLoader) -> None:
"""Test that returns the correct answer"""
output = list(tfds_client.lazy_load())
assert isinstance(output, list)
assert len(output) == MAX_DOCS
assert isinstance(output[0], Document)
assert len(output[0].page_content) > 0
assert isinstance(output[0].page_content, str)
assert isinstance(output[0].metadata, dict)
def test_load_fail_wrong_dataset_name() -> None:
"""Test that fails to load"""
with pytest.raises(ValidationError) as exc_info:
TensorflowDatasetLoader(
dataset_name="wrong_dataset_name",
split_name="test",
load_max_docs=MAX_DOCS,
sample_to_document_function=mlqaen_example_to_document,
)
assert "the dataset name is spelled correctly" in str(exc_info.value)
def test_load_fail_wrong_split_name() -> None:
"""Test that fails to load"""
with pytest.raises(ValidationError) as exc_info:
TensorflowDatasetLoader(
dataset_name="mlqa/en",
split_name="wrong_split_name",
load_max_docs=MAX_DOCS,
sample_to_document_function=mlqaen_example_to_document,
)
assert "Unknown split" in str(exc_info.value)
def test_load_fail_no_func() -> None:
"""Test that fails to load"""
with pytest.raises(ValidationError) as exc_info:
TensorflowDatasetLoader(
dataset_name="mlqa/en",
split_name="test",
load_max_docs=MAX_DOCS,
)
assert "Please provide a function" in str(exc_info.value)

View File

@@ -0,0 +1,90 @@
"""Integration tests for the TensorFlow Dataset client."""
import pytest
import tensorflow as tf
from pydantic.error_wrappers import ValidationError
from langchain.schema.document import Document
from langchain.utilities.tensorflow_datasets import TensorflowDatasets
# adding tensorflow and tensorflow_datasets to pyproject.toml is not working
# these tests can be tested in isolation only
tensorflow = pytest.importorskip("tensorflow")
tensorflow_datasets = pytest.importorskip("tensorflow_datasets")
def decode_to_str(item: tf.Tensor) -> str:
return item.numpy().decode("utf-8")
def mlqaen_example_to_document(example: dict) -> Document:
return Document(
page_content=decode_to_str(example["context"]),
metadata={
"id": decode_to_str(example["id"]),
"title": decode_to_str(example["title"]),
"question": decode_to_str(example["question"]),
"answer": decode_to_str(example["answers"]["text"][0]),
},
)
MAX_DOCS = 10
@pytest.fixture
def tfds_client() -> TensorflowDatasets:
return TensorflowDatasets(
dataset_name="mlqa/en",
split_name="test",
load_max_docs=MAX_DOCS,
sample_to_document_function=mlqaen_example_to_document,
)
def test_load_success(tfds_client: TensorflowDatasets) -> None:
"""Test that returns the correct answer"""
output = tfds_client.load()
assert isinstance(output, list)
assert len(output) == MAX_DOCS
assert isinstance(output[0], Document)
assert len(output[0].page_content) > 0
assert isinstance(output[0].page_content, str)
assert isinstance(output[0].metadata, dict)
def test_load_fail_wrong_dataset_name() -> None:
"""Test that fails to load"""
with pytest.raises(ValidationError) as exc_info:
TensorflowDatasets(
dataset_name="wrong_dataset_name",
split_name="test",
load_max_docs=MAX_DOCS,
sample_to_document_function=mlqaen_example_to_document,
)
assert "the dataset name is spelled correctly" in str(exc_info.value)
def test_load_fail_wrong_split_name() -> None:
"""Test that fails to load"""
with pytest.raises(ValidationError) as exc_info:
TensorflowDatasets(
dataset_name="mlqa/en",
split_name="wrong_split_name",
load_max_docs=MAX_DOCS,
sample_to_document_function=mlqaen_example_to_document,
)
assert "Unknown split" in str(exc_info.value)
def test_load_fail_no_func() -> None:
"""Test that fails to load"""
with pytest.raises(ValidationError) as exc_info:
TensorflowDatasets(
dataset_name="mlqa/en",
split_name="test",
load_max_docs=MAX_DOCS,
)
assert "Please provide a function" in str(exc_info.value)