tensorflow_datasets document loader (#8721)

This PR adds the `tensorflow_datasets` document loader.
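For context, a minimal usage sketch of the new loader, assembled from the docstring and integration tests added in this PR (the `decode_to_str` helper and the `mlqa/en` dataset come from that example; `tensorflow` and `tensorflow-datasets` must be installed):

```python
import tensorflow as tf

from langchain.document_loaders import TensorflowDatasetLoader
from langchain.schema import Document


def decode_to_str(item: tf.Tensor) -> str:
    # TFDS yields byte tensors; decode them to Python strings.
    return item.numpy().decode("utf-8")


def mlqaen_example_to_document(example: dict) -> Document:
    # Map one MLQA/en sample to a Document: the context becomes the
    # page content, the remaining fields go into the metadata.
    return Document(
        page_content=decode_to_str(example["context"]),
        metadata={
            "id": decode_to_str(example["id"]),
            "title": decode_to_str(example["title"]),
            "question": decode_to_str(example["question"]),
            "answer": decode_to_str(example["answers"]["text"][0]),
        },
    )


loader = TensorflowDatasetLoader(
    dataset_name="mlqa/en",
    split_name="test",
    load_max_docs=100,
    sample_to_document_function=mlqaen_example_to_document,
)
docs = loader.load()  # downloads the split and returns at most 100 Documents
```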
@@ -147,6 +147,7 @@ from langchain.document_loaders.telegram import (
)
from langchain.document_loaders.tencent_cos_directory import TencentCOSDirectoryLoader
from langchain.document_loaders.tencent_cos_file import TencentCOSFileLoader
from langchain.document_loaders.tensorflow_datasets import TensorflowDatasetLoader
from langchain.document_loaders.text import TextLoader
from langchain.document_loaders.tomarkdown import ToMarkdownLoader
from langchain.document_loaders.toml import TomlLoader
@@ -299,6 +300,7 @@ __all__ = [
    "TelegramChatApiLoader",
    "TelegramChatFileLoader",
    "TelegramChatLoader",
    "TensorflowDatasetLoader",
    "TencentCOSDirectoryLoader",
    "TencentCOSFileLoader",
    "TextLoader",
@@ -8,7 +8,6 @@ from langchain.utilities.arxiv import ArxivAPIWrapper
class ArxivLoader(BaseLoader):
    """Loads a query result from arxiv.org into a list of Documents.

    Each document represents one Document.
    The loader converts the original PDF format into the text.
    """
@@ -0,0 +1,79 @@
from typing import Callable, Dict, Iterator, List, Optional

from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from langchain.utilities.tensorflow_datasets import TensorflowDatasets


class TensorflowDatasetLoader(BaseLoader):
    """Loads from TensorFlow Datasets into a list of Documents.

    Attributes:
        dataset_name: the name of the dataset to load
        split_name: the name of the split to load.
        load_max_docs: a limit to the number of loaded documents. Defaults to 100.
        sample_to_document_function: a function that converts a dataset sample
            into a Document

    Example:
        .. code-block:: python

            from langchain.document_loaders import TensorflowDatasetLoader

            def mlqaen_example_to_document(example: dict) -> Document:
                return Document(
                    page_content=decode_to_str(example["context"]),
                    metadata={
                        "id": decode_to_str(example["id"]),
                        "title": decode_to_str(example["title"]),
                        "question": decode_to_str(example["question"]),
                        "answer": decode_to_str(example["answers"]["text"][0]),
                    },
                )

            tsds_client = TensorflowDatasetLoader(
                dataset_name="mlqa/en",
                split_name="test",
                load_max_docs=100,
                sample_to_document_function=mlqaen_example_to_document,
            )

    """

    def __init__(
        self,
        dataset_name: str,
        split_name: str,
        load_max_docs: Optional[int] = 100,
        sample_to_document_function: Optional[Callable[[Dict], Document]] = None,
    ):
        """Initialize the TensorflowDatasetLoader.

        Args:
            dataset_name: the name of the dataset to load
            split_name: the name of the split to load.
            load_max_docs: a limit to the number of loaded documents. Defaults to 100.
            sample_to_document_function: a function that converts a dataset sample
                into a Document.
        """
        self.dataset_name: str = dataset_name
        self.split_name: str = split_name
        self.load_max_docs = load_max_docs
        """The maximum number of documents to load."""
        self.sample_to_document_function: Optional[
            Callable[[Dict], Document]
        ] = sample_to_document_function
        """Custom function that transforms a dataset sample into a Document."""

        self._tfds_client = TensorflowDatasets(
            dataset_name=self.dataset_name,
            split_name=self.split_name,
            load_max_docs=self.load_max_docs,
            sample_to_document_function=self.sample_to_document_function,
        )

    def lazy_load(self) -> Iterator[Document]:
        yield from self._tfds_client.lazy_load()

    def load(self) -> List[Document]:
        return list(self.lazy_load())
@@ -29,6 +29,7 @@ from langchain.utilities.searx_search import SearxSearchWrapper
from langchain.utilities.serpapi import SerpAPIWrapper
from langchain.utilities.spark_sql import SparkSQL
from langchain.utilities.sql_database import SQLDatabase
from langchain.utilities.tensorflow_datasets import TensorflowDatasets
from langchain.utilities.twilio import TwilioAPIWrapper
from langchain.utilities.wikipedia import WikipediaAPIWrapper
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
@@ -62,6 +63,7 @@ __all__ = [
    "SearxSearchWrapper",
    "SerpAPIWrapper",
    "SparkSQL",
    "TensorflowDatasets",
    "TextRequestsWrapper",
    "TextRequestsWrapper",
    "TwilioAPIWrapper",
@@ -21,7 +21,7 @@ class ArxivAPIWrapper(BaseModel):
    It limits the Document content by doc_content_chars_max.
    Set doc_content_chars_max=None if you don't want to limit the content size.

    Args:
    Attributes:
        top_k_results: number of the top-scored document used for the arxiv tool
        ARXIV_MAX_QUERY_LENGTH: the cut limit on the query used for the arxiv tool.
        load_max_docs: a limit to the number of loaded documents
libs/langchain/langchain/utilities/tensorflow_datasets.py (new file, 111 lines)
@@ -0,0 +1,111 @@
import logging
from typing import Any, Callable, Dict, Iterator, List, Optional

from pydantic import BaseModel, root_validator

from langchain.schema import Document

logger = logging.getLogger(__name__)


class TensorflowDatasets(BaseModel):
    """Access to the TensorFlow Datasets.

    The current implementation works only with datasets that fit in memory.

    `TensorFlow Datasets` is a collection of datasets ready to use, with TensorFlow
    or other Python ML frameworks, such as JAX. All datasets are exposed
    as `tf.data.Datasets`.
    To get started see the Guide: https://www.tensorflow.org/datasets/overview and
    the list of datasets: https://www.tensorflow.org/datasets/catalog/overview#all_datasets

    You have to provide the sample_to_document_function: a function that converts
    a sample from the dataset-specific format to a Document.

    Attributes:
        dataset_name: the name of the dataset to load
        split_name: the name of the split to load. Defaults to "train".
        load_max_docs: a limit to the number of loaded documents. Defaults to 100.
        sample_to_document_function: a function that converts a dataset sample
            to a Document

    Example:
        .. code-block:: python

            from langchain.utilities import TensorflowDatasets

            def mlqaen_example_to_document(example: dict) -> Document:
                return Document(
                    page_content=decode_to_str(example["context"]),
                    metadata={
                        "id": decode_to_str(example["id"]),
                        "title": decode_to_str(example["title"]),
                        "question": decode_to_str(example["question"]),
                        "answer": decode_to_str(example["answers"]["text"][0]),
                    },
                )

            tsds_client = TensorflowDatasets(
                dataset_name="mlqa/en",
                split_name="train",
                load_max_docs=MAX_DOCS,
                sample_to_document_function=mlqaen_example_to_document,
            )

    """

    dataset_name: str = ""
    split_name: str = "train"
    load_max_docs: int = 100
    sample_to_document_function: Optional[Callable[[Dict], Document]] = None
    dataset: Any  #: :meta private:

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in environment."""
        try:
            import tensorflow  # noqa: F401
        except ImportError:
            raise ImportError(
                "Could not import tensorflow python package. "
                "Please install it with `pip install tensorflow`."
            )
        try:
            import tensorflow_datasets
        except ImportError:
            raise ImportError(
                "Could not import tensorflow_datasets python package. "
                "Please install it with `pip install tensorflow-datasets`."
            )
        if values["sample_to_document_function"] is None:
            raise ValueError(
                "sample_to_document_function is None. "
                "Please provide a function that converts a dataset sample to"
                " a Document."
            )
        values["dataset"] = tensorflow_datasets.load(
            values["dataset_name"], split=values["split_name"]
        )

        return values

    def lazy_load(self) -> Iterator[Document]:
        """Download a selected dataset lazily.

        Returns: an iterator of Documents.

        """
        return (
            self.sample_to_document_function(s)
            for s in self.dataset.take(self.load_max_docs)
            if self.sample_to_document_function is not None
        )

    def load(self) -> List[Document]:
        """Download a selected dataset.

        Returns: a list of Documents.

        """
        return list(self.lazy_load())
@@ -0,0 +1,105 @@
"""Integration tests for the TensorFlow Dataset Loader."""

import pytest
from pydantic.error_wrappers import ValidationError

from langchain.document_loaders.tensorflow_datasets import TensorflowDatasetLoader
from langchain.schema.document import Document

# adding tensorflow and tensorflow_datasets to pyproject.toml is not working
# these tests can be run in isolation only
tensorflow = pytest.importorskip("tensorflow")
tensorflow_datasets = pytest.importorskip("tensorflow_datasets")

# placed here after checking for tensorflow package installation
import tensorflow as tf  # noqa: E402


def decode_to_str(item: tf.Tensor) -> str:
    return item.numpy().decode("utf-8")


def mlqaen_example_to_document(example: dict) -> Document:
    return Document(
        page_content=decode_to_str(example["context"]),
        metadata={
            "id": decode_to_str(example["id"]),
            "title": decode_to_str(example["title"]),
            "question": decode_to_str(example["question"]),
            "answer": decode_to_str(example["answers"]["text"][0]),
        },
    )


MAX_DOCS = 10


@pytest.fixture
def tfds_client() -> TensorflowDatasetLoader:
    return TensorflowDatasetLoader(
        dataset_name="mlqa/en",
        split_name="test",
        load_max_docs=MAX_DOCS,
        sample_to_document_function=mlqaen_example_to_document,
    )


def test_load_success(tfds_client: TensorflowDatasetLoader) -> None:
    """Test that returns the correct answer"""

    output = tfds_client.load()
    assert isinstance(output, list)
    assert len(output) == MAX_DOCS

    assert isinstance(output[0], Document)
    assert len(output[0].page_content) > 0
    assert isinstance(output[0].page_content, str)
    assert isinstance(output[0].metadata, dict)


def test_lazy_load_success(tfds_client: TensorflowDatasetLoader) -> None:
    """Test that returns the correct answer"""

    output = list(tfds_client.lazy_load())
    assert isinstance(output, list)
    assert len(output) == MAX_DOCS

    assert isinstance(output[0], Document)
    assert len(output[0].page_content) > 0
    assert isinstance(output[0].page_content, str)
    assert isinstance(output[0].metadata, dict)


def test_load_fail_wrong_dataset_name() -> None:
    """Test that fails to load"""
    with pytest.raises(ValidationError) as exc_info:
        TensorflowDatasetLoader(
            dataset_name="wrong_dataset_name",
            split_name="test",
            load_max_docs=MAX_DOCS,
            sample_to_document_function=mlqaen_example_to_document,
        )
    assert "the dataset name is spelled correctly" in str(exc_info.value)


def test_load_fail_wrong_split_name() -> None:
    """Test that fails to load"""
    with pytest.raises(ValidationError) as exc_info:
        TensorflowDatasetLoader(
            dataset_name="mlqa/en",
            split_name="wrong_split_name",
            load_max_docs=MAX_DOCS,
            sample_to_document_function=mlqaen_example_to_document,
        )
    assert "Unknown split" in str(exc_info.value)


def test_load_fail_no_func() -> None:
    """Test that fails to load"""
    with pytest.raises(ValidationError) as exc_info:
        TensorflowDatasetLoader(
            dataset_name="mlqa/en",
            split_name="test",
            load_max_docs=MAX_DOCS,
        )
    assert "Please provide a function" in str(exc_info.value)
@@ -0,0 +1,90 @@
"""Integration tests for the TensorFlow Dataset client."""

import pytest
import tensorflow as tf
from pydantic.error_wrappers import ValidationError

from langchain.schema.document import Document
from langchain.utilities.tensorflow_datasets import TensorflowDatasets

# adding tensorflow and tensorflow_datasets to pyproject.toml is not working
# these tests can be run in isolation only
tensorflow = pytest.importorskip("tensorflow")
tensorflow_datasets = pytest.importorskip("tensorflow_datasets")


def decode_to_str(item: tf.Tensor) -> str:
    return item.numpy().decode("utf-8")


def mlqaen_example_to_document(example: dict) -> Document:
    return Document(
        page_content=decode_to_str(example["context"]),
        metadata={
            "id": decode_to_str(example["id"]),
            "title": decode_to_str(example["title"]),
            "question": decode_to_str(example["question"]),
            "answer": decode_to_str(example["answers"]["text"][0]),
        },
    )


MAX_DOCS = 10


@pytest.fixture
def tfds_client() -> TensorflowDatasets:
    return TensorflowDatasets(
        dataset_name="mlqa/en",
        split_name="test",
        load_max_docs=MAX_DOCS,
        sample_to_document_function=mlqaen_example_to_document,
    )


def test_load_success(tfds_client: TensorflowDatasets) -> None:
    """Test that returns the correct answer"""

    output = tfds_client.load()
    assert isinstance(output, list)
    assert len(output) == MAX_DOCS

    assert isinstance(output[0], Document)
    assert len(output[0].page_content) > 0
    assert isinstance(output[0].page_content, str)
    assert isinstance(output[0].metadata, dict)


def test_load_fail_wrong_dataset_name() -> None:
    """Test that fails to load"""
    with pytest.raises(ValidationError) as exc_info:
        TensorflowDatasets(
            dataset_name="wrong_dataset_name",
            split_name="test",
            load_max_docs=MAX_DOCS,
            sample_to_document_function=mlqaen_example_to_document,
        )
    assert "the dataset name is spelled correctly" in str(exc_info.value)


def test_load_fail_wrong_split_name() -> None:
    """Test that fails to load"""
    with pytest.raises(ValidationError) as exc_info:
        TensorflowDatasets(
            dataset_name="mlqa/en",
            split_name="wrong_split_name",
            load_max_docs=MAX_DOCS,
            sample_to_document_function=mlqaen_example_to_document,
        )
    assert "Unknown split" in str(exc_info.value)


def test_load_fail_no_func() -> None:
    """Test that fails to load"""
    with pytest.raises(ValidationError) as exc_info:
        TensorflowDatasets(
            dataset_name="mlqa/en",
            split_name="test",
            load_max_docs=MAX_DOCS,
        )
    assert "Please provide a function" in str(exc_info.value)