From 15b1770326e82e65222800a552b3c77e69ca7ff4 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Wed, 6 Mar 2024 19:16:05 +0100 Subject: [PATCH] Merge pull request #18421 * Implement lazy_load() for AssemblyAIAudioTranscriptLoader --- .../document_loaders/__init__.py | 2 + .../document_loaders/assemblyai.py | 44 ++++++++----------- .../document_loaders/test_assemblyai.py | 41 ++++++++++++++++- .../document_loaders/test_imports.py | 1 + 4 files changed, 62 insertions(+), 26 deletions(-) diff --git a/libs/community/langchain_community/document_loaders/__init__.py b/libs/community/langchain_community/document_loaders/__init__.py index 21d9f39d685..b6abd537a3d 100644 --- a/libs/community/langchain_community/document_loaders/__init__.py +++ b/libs/community/langchain_community/document_loaders/__init__.py @@ -32,6 +32,7 @@ from langchain_community.document_loaders.apify_dataset import ApifyDatasetLoade from langchain_community.document_loaders.arcgis_loader import ArcGISLoader from langchain_community.document_loaders.arxiv import ArxivLoader from langchain_community.document_loaders.assemblyai import ( + AssemblyAIAudioLoaderById, AssemblyAIAudioTranscriptLoader, ) from langchain_community.document_loaders.astradb import AstraDBLoader @@ -262,6 +263,7 @@ __all__ = [ "ApifyDatasetLoader", "ArcGISLoader", "ArxivLoader", + "AssemblyAIAudioLoaderById", "AssemblyAIAudioTranscriptLoader", "AstraDBLoader", "AsyncHtmlLoader", diff --git a/libs/community/langchain_community/document_loaders/assemblyai.py b/libs/community/langchain_community/document_loaders/assemblyai.py index 3b1b4060b98..cc74c552e2d 100644 --- a/libs/community/langchain_community/document_loaders/assemblyai.py +++ b/libs/community/langchain_community/document_loaders/assemblyai.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, Iterator, Optional import requests from langchain_core.documents import Document @@ -75,7 +75,7 @@ class AssemblyAIAudioTranscriptLoader(BaseLoader): self.transcript_format = transcript_format self.transcriber = assemblyai.Transcriber(config=config) - def load(self) -> List[Document]: + def lazy_load(self) -> Iterator[Document]: """Transcribes the audio file and loads the transcript into documents. It uses the AssemblyAI API to transcribe the audio file and blocks until @@ -88,27 +88,21 @@ class AssemblyAIAudioTranscriptLoader(BaseLoader): raise ValueError(f"Could not transcribe file: {transcript.error}") if self.transcript_format == TranscriptFormat.TEXT: - return [ - Document( - page_content=transcript.text, metadata=transcript.json_response - ) - ] + yield Document( + page_content=transcript.text, metadata=transcript.json_response + ) elif self.transcript_format == TranscriptFormat.SENTENCES: sentences = transcript.get_sentences() - return [ - Document(page_content=s.text, metadata=s.dict(exclude={"text"})) - for s in sentences - ] + for s in sentences: + yield Document(page_content=s.text, metadata=s.dict(exclude={"text"})) elif self.transcript_format == TranscriptFormat.PARAGRAPHS: paragraphs = transcript.get_paragraphs() - return [ - Document(page_content=p.text, metadata=p.dict(exclude={"text"})) - for p in paragraphs - ] + for p in paragraphs: + yield Document(page_content=p.text, metadata=p.dict(exclude={"text"})) elif self.transcript_format == TranscriptFormat.SUBTITLES_SRT: - return [Document(page_content=transcript.export_subtitles_srt())] + yield Document(page_content=transcript.export_subtitles_srt()) elif self.transcript_format == TranscriptFormat.SUBTITLES_VTT: - return [Document(page_content=transcript.export_subtitles_vtt())] + yield Document(page_content=transcript.export_subtitles_vtt()) else: raise ValueError("Unknown transcript format.") @@ -140,7 +134,7 @@ class AssemblyAIAudioLoaderById(BaseLoader): self.transcript_id = transcript_id self.transcript_format = transcript_format - def load(self) -> List[Document]: + def lazy_load(self) -> Iterator[Document]: """Load data into Document objects.""" HEADERS = {"authorization": self.api_key} @@ -157,9 +151,7 @@ class AssemblyAIAudioLoaderById(BaseLoader): transcript = transcript_response.json()["text"] - return [ - Document(page_content=transcript, metadata=transcript_response.json()) - ] + yield Document(page_content=transcript, metadata=transcript_response.json()) elif self.transcript_format == TranscriptFormat.PARAGRAPHS: try: paragraphs_response = requests.get( @@ -173,7 +165,8 @@ class AssemblyAIAudioLoaderById(BaseLoader): paragraphs = paragraphs_response.json()["paragraphs"] - return [Document(page_content=p["text"], metadata=p) for p in paragraphs] + for p in paragraphs: + yield Document(page_content=p["text"], metadata=p) elif self.transcript_format == TranscriptFormat.SENTENCES: try: @@ -188,7 +181,8 @@ class AssemblyAIAudioLoaderById(BaseLoader): sentences = sentences_response.json()["sentences"] - return [Document(page_content=s["text"], metadata=s) for s in sentences] + for s in sentences: + yield Document(page_content=s["text"], metadata=s) elif self.transcript_format == TranscriptFormat.SUBTITLES_SRT: try: @@ -203,7 +197,7 @@ class AssemblyAIAudioLoaderById(BaseLoader): srt = srt_response.text - return [Document(page_content=srt)] + yield Document(page_content=srt) elif self.transcript_format == TranscriptFormat.SUBTITLES_VTT: try: @@ -218,6 +212,6 @@ class AssemblyAIAudioLoaderById(BaseLoader): vtt = vtt_response.text - return [Document(page_content=vtt)] + yield Document(page_content=vtt) else: raise ValueError("Unknown transcript format.") diff --git a/libs/community/tests/unit_tests/document_loaders/test_assemblyai.py b/libs/community/tests/unit_tests/document_loaders/test_assemblyai.py index 36226371189..46cdf3a9e7c 100644 --- a/libs/community/tests/unit_tests/document_loaders/test_assemblyai.py +++ b/libs/community/tests/unit_tests/document_loaders/test_assemblyai.py @@ -1,7 +1,12 @@ import pytest +import responses from pytest_mock import MockerFixture +from requests import HTTPError -from langchain_community.document_loaders import AssemblyAIAudioTranscriptLoader +from langchain_community.document_loaders import ( + AssemblyAIAudioLoaderById, + AssemblyAIAudioTranscriptLoader, +) from langchain_community.document_loaders.assemblyai import TranscriptFormat @@ -46,3 +51,37 @@ def test_transcription_error(mocker: MockerFixture) -> None: expected_error = "Could not transcribe file: Test error" with pytest.raises(ValueError, match=expected_error): loader.load() + + +@pytest.mark.requires("assemblyai") +@responses.activate +def test_load_by_id() -> None: + responses.add( + responses.GET, + "https://api.assemblyai.com/v2/transcript/1234", + json={"text": "Test transcription text", "id": "1234"}, + status=200, + ) + + loader = AssemblyAIAudioLoaderById( + transcript_id="1234", api_key="api_key", transcript_format=TranscriptFormat.TEXT + ) + docs = loader.load() + assert len(docs) == 1 + assert docs[0].page_content == "Test transcription text" + assert docs[0].metadata == {"text": "Test transcription text", "id": "1234"} + + +@pytest.mark.requires("assemblyai") +@responses.activate +def test_transcription_error_by_id() -> None: + responses.add( + responses.GET, + "https://api.assemblyai.com/v2/transcript/1234", + status=404, + ) + loader = AssemblyAIAudioLoaderById( + transcript_id="1234", api_key="api_key", transcript_format=TranscriptFormat.TEXT + ) + with pytest.raises(HTTPError): + loader.load() diff --git a/libs/community/tests/unit_tests/document_loaders/test_imports.py b/libs/community/tests/unit_tests/document_loaders/test_imports.py index 387bf60a0f6..ad4da720162 100644 --- a/libs/community/tests/unit_tests/document_loaders/test_imports.py +++ b/libs/community/tests/unit_tests/document_loaders/test_imports.py @@ -20,6 +20,7 @@ EXPECTED_ALL = [ "ApifyDatasetLoader", "ArcGISLoader", "ArxivLoader", + "AssemblyAIAudioLoaderById", "AssemblyAIAudioTranscriptLoader", "AstraDBLoader", "AsyncHtmlLoader",