Merge pull request #18421

* Implement lazy_load() for AssemblyAIAudioTranscriptLoader
This commit is contained in:
Christophe Bornet 2024-03-06 19:16:05 +01:00 committed by GitHub
parent bb284eebe4
commit 15b1770326
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 62 additions and 26 deletions

View File

@ -32,6 +32,7 @@ from langchain_community.document_loaders.apify_dataset import ApifyDatasetLoade
from langchain_community.document_loaders.arcgis_loader import ArcGISLoader from langchain_community.document_loaders.arcgis_loader import ArcGISLoader
from langchain_community.document_loaders.arxiv import ArxivLoader from langchain_community.document_loaders.arxiv import ArxivLoader
from langchain_community.document_loaders.assemblyai import ( from langchain_community.document_loaders.assemblyai import (
AssemblyAIAudioLoaderById,
AssemblyAIAudioTranscriptLoader, AssemblyAIAudioTranscriptLoader,
) )
from langchain_community.document_loaders.astradb import AstraDBLoader from langchain_community.document_loaders.astradb import AstraDBLoader
@ -262,6 +263,7 @@ __all__ = [
"ApifyDatasetLoader", "ApifyDatasetLoader",
"ArcGISLoader", "ArcGISLoader",
"ArxivLoader", "ArxivLoader",
"AssemblyAIAudioLoaderById",
"AssemblyAIAudioTranscriptLoader", "AssemblyAIAudioTranscriptLoader",
"AstraDBLoader", "AstraDBLoader",
"AsyncHtmlLoader", "AsyncHtmlLoader",

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Iterator, Optional
import requests import requests
from langchain_core.documents import Document from langchain_core.documents import Document
@ -75,7 +75,7 @@ class AssemblyAIAudioTranscriptLoader(BaseLoader):
self.transcript_format = transcript_format self.transcript_format = transcript_format
self.transcriber = assemblyai.Transcriber(config=config) self.transcriber = assemblyai.Transcriber(config=config)
def load(self) -> List[Document]: def lazy_load(self) -> Iterator[Document]:
"""Transcribes the audio file and loads the transcript into documents. """Transcribes the audio file and loads the transcript into documents.
It uses the AssemblyAI API to transcribe the audio file and blocks until It uses the AssemblyAI API to transcribe the audio file and blocks until
@ -88,27 +88,21 @@ class AssemblyAIAudioTranscriptLoader(BaseLoader):
raise ValueError(f"Could not transcribe file: {transcript.error}") raise ValueError(f"Could not transcribe file: {transcript.error}")
if self.transcript_format == TranscriptFormat.TEXT: if self.transcript_format == TranscriptFormat.TEXT:
return [ yield Document(
Document( page_content=transcript.text, metadata=transcript.json_response
page_content=transcript.text, metadata=transcript.json_response )
)
]
elif self.transcript_format == TranscriptFormat.SENTENCES: elif self.transcript_format == TranscriptFormat.SENTENCES:
sentences = transcript.get_sentences() sentences = transcript.get_sentences()
return [ for s in sentences:
Document(page_content=s.text, metadata=s.dict(exclude={"text"})) yield Document(page_content=s.text, metadata=s.dict(exclude={"text"}))
for s in sentences
]
elif self.transcript_format == TranscriptFormat.PARAGRAPHS: elif self.transcript_format == TranscriptFormat.PARAGRAPHS:
paragraphs = transcript.get_paragraphs() paragraphs = transcript.get_paragraphs()
return [ for p in paragraphs:
Document(page_content=p.text, metadata=p.dict(exclude={"text"})) yield Document(page_content=p.text, metadata=p.dict(exclude={"text"}))
for p in paragraphs
]
elif self.transcript_format == TranscriptFormat.SUBTITLES_SRT: elif self.transcript_format == TranscriptFormat.SUBTITLES_SRT:
return [Document(page_content=transcript.export_subtitles_srt())] yield Document(page_content=transcript.export_subtitles_srt())
elif self.transcript_format == TranscriptFormat.SUBTITLES_VTT: elif self.transcript_format == TranscriptFormat.SUBTITLES_VTT:
return [Document(page_content=transcript.export_subtitles_vtt())] yield Document(page_content=transcript.export_subtitles_vtt())
else: else:
raise ValueError("Unknown transcript format.") raise ValueError("Unknown transcript format.")
@ -140,7 +134,7 @@ class AssemblyAIAudioLoaderById(BaseLoader):
self.transcript_id = transcript_id self.transcript_id = transcript_id
self.transcript_format = transcript_format self.transcript_format = transcript_format
def load(self) -> List[Document]: def lazy_load(self) -> Iterator[Document]:
"""Load data into Document objects.""" """Load data into Document objects."""
HEADERS = {"authorization": self.api_key} HEADERS = {"authorization": self.api_key}
@ -157,9 +151,7 @@ class AssemblyAIAudioLoaderById(BaseLoader):
transcript = transcript_response.json()["text"] transcript = transcript_response.json()["text"]
return [ yield Document(page_content=transcript, metadata=transcript_response.json())
Document(page_content=transcript, metadata=transcript_response.json())
]
elif self.transcript_format == TranscriptFormat.PARAGRAPHS: elif self.transcript_format == TranscriptFormat.PARAGRAPHS:
try: try:
paragraphs_response = requests.get( paragraphs_response = requests.get(
@ -173,7 +165,8 @@ class AssemblyAIAudioLoaderById(BaseLoader):
paragraphs = paragraphs_response.json()["paragraphs"] paragraphs = paragraphs_response.json()["paragraphs"]
return [Document(page_content=p["text"], metadata=p) for p in paragraphs] for p in paragraphs:
yield Document(page_content=p["text"], metadata=p)
elif self.transcript_format == TranscriptFormat.SENTENCES: elif self.transcript_format == TranscriptFormat.SENTENCES:
try: try:
@ -188,7 +181,8 @@ class AssemblyAIAudioLoaderById(BaseLoader):
sentences = sentences_response.json()["sentences"] sentences = sentences_response.json()["sentences"]
return [Document(page_content=s["text"], metadata=s) for s in sentences] for s in sentences:
yield Document(page_content=s["text"], metadata=s)
elif self.transcript_format == TranscriptFormat.SUBTITLES_SRT: elif self.transcript_format == TranscriptFormat.SUBTITLES_SRT:
try: try:
@ -203,7 +197,7 @@ class AssemblyAIAudioLoaderById(BaseLoader):
srt = srt_response.text srt = srt_response.text
return [Document(page_content=srt)] yield Document(page_content=srt)
elif self.transcript_format == TranscriptFormat.SUBTITLES_VTT: elif self.transcript_format == TranscriptFormat.SUBTITLES_VTT:
try: try:
@ -218,6 +212,6 @@ class AssemblyAIAudioLoaderById(BaseLoader):
vtt = vtt_response.text vtt = vtt_response.text
return [Document(page_content=vtt)] yield Document(page_content=vtt)
else: else:
raise ValueError("Unknown transcript format.") raise ValueError("Unknown transcript format.")

View File

@ -1,7 +1,12 @@
import pytest import pytest
import responses
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from requests import HTTPError
from langchain_community.document_loaders import AssemblyAIAudioTranscriptLoader from langchain_community.document_loaders import (
AssemblyAIAudioLoaderById,
AssemblyAIAudioTranscriptLoader,
)
from langchain_community.document_loaders.assemblyai import TranscriptFormat from langchain_community.document_loaders.assemblyai import TranscriptFormat
@ -46,3 +51,37 @@ def test_transcription_error(mocker: MockerFixture) -> None:
expected_error = "Could not transcribe file: Test error" expected_error = "Could not transcribe file: Test error"
with pytest.raises(ValueError, match=expected_error): with pytest.raises(ValueError, match=expected_error):
loader.load() loader.load()
@pytest.mark.requires("assemblyai")
@responses.activate
def test_load_by_id() -> None:
    """Check that loading a transcript by id in TEXT format yields one document.

    The transcript endpoint is stubbed with ``responses`` so no network call
    is made. The single returned document must carry the payload's ``"text"``
    as its content and the whole JSON payload as its metadata.
    """
    stubbed_payload = {"text": "Test transcription text", "id": "1234"}
    responses.add(
        responses.GET,
        "https://api.assemblyai.com/v2/transcript/1234",
        json=stubbed_payload,
        status=200,
    )

    by_id_loader = AssemblyAIAudioLoaderById(
        transcript_id="1234",
        api_key="api_key",
        transcript_format=TranscriptFormat.TEXT,
    )
    documents = by_id_loader.load()

    assert len(documents) == 1
    only_document = documents[0]
    assert only_document.page_content == "Test transcription text"
    assert only_document.metadata == {"text": "Test transcription text", "id": "1234"}
@pytest.mark.requires("assemblyai")
@responses.activate
def test_transcription_error_by_id() -> None:
    """Check that a 404 from the stubbed transcript endpoint raises ``HTTPError``.

    The endpoint is registered with ``responses`` to answer with status 404;
    calling ``load()`` on the by-id loader is then expected to raise.
    """
    responses.add(
        responses.GET,
        "https://api.assemblyai.com/v2/transcript/1234",
        status=404,
    )

    failing_loader = AssemblyAIAudioLoaderById(
        transcript_id="1234",
        api_key="api_key",
        transcript_format=TranscriptFormat.TEXT,
    )

    # The loader is expected to surface the HTTP failure instead of returning documents.
    with pytest.raises(HTTPError):
        failing_loader.load()

View File

@ -20,6 +20,7 @@ EXPECTED_ALL = [
"ApifyDatasetLoader", "ApifyDatasetLoader",
"ArcGISLoader", "ArcGISLoader",
"ArxivLoader", "ArxivLoader",
"AssemblyAIAudioLoaderById",
"AssemblyAIAudioTranscriptLoader", "AssemblyAIAudioTranscriptLoader",
"AstraDBLoader", "AstraDBLoader",
"AsyncHtmlLoader", "AsyncHtmlLoader",