Merge pull request #18421

* Implement lazy_load() for AssemblyAIAudioTranscriptLoader
This commit is contained in:
Christophe Bornet 2024-03-06 19:16:05 +01:00 committed by GitHub
parent bb284eebe4
commit 15b1770326
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 62 additions and 26 deletions

View File

@ -32,6 +32,7 @@ from langchain_community.document_loaders.apify_dataset import ApifyDatasetLoade
from langchain_community.document_loaders.arcgis_loader import ArcGISLoader
from langchain_community.document_loaders.arxiv import ArxivLoader
from langchain_community.document_loaders.assemblyai import (
AssemblyAIAudioLoaderById,
AssemblyAIAudioTranscriptLoader,
)
from langchain_community.document_loaders.astradb import AstraDBLoader
@ -262,6 +263,7 @@ __all__ = [
"ApifyDatasetLoader",
"ArcGISLoader",
"ArxivLoader",
"AssemblyAIAudioLoaderById",
"AssemblyAIAudioTranscriptLoader",
"AstraDBLoader",
"AsyncHtmlLoader",

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Iterator, Optional
import requests
from langchain_core.documents import Document
@ -75,7 +75,7 @@ class AssemblyAIAudioTranscriptLoader(BaseLoader):
self.transcript_format = transcript_format
self.transcriber = assemblyai.Transcriber(config=config)
def load(self) -> List[Document]:
def lazy_load(self) -> Iterator[Document]:
"""Transcribes the audio file and loads the transcript into documents.
It uses the AssemblyAI API to transcribe the audio file and blocks until
@ -88,27 +88,21 @@ class AssemblyAIAudioTranscriptLoader(BaseLoader):
raise ValueError(f"Could not transcribe file: {transcript.error}")
if self.transcript_format == TranscriptFormat.TEXT:
return [
Document(
page_content=transcript.text, metadata=transcript.json_response
)
]
yield Document(
page_content=transcript.text, metadata=transcript.json_response
)
elif self.transcript_format == TranscriptFormat.SENTENCES:
sentences = transcript.get_sentences()
return [
Document(page_content=s.text, metadata=s.dict(exclude={"text"}))
for s in sentences
]
for s in sentences:
yield Document(page_content=s.text, metadata=s.dict(exclude={"text"}))
elif self.transcript_format == TranscriptFormat.PARAGRAPHS:
paragraphs = transcript.get_paragraphs()
return [
Document(page_content=p.text, metadata=p.dict(exclude={"text"}))
for p in paragraphs
]
for p in paragraphs:
yield Document(page_content=p.text, metadata=p.dict(exclude={"text"}))
elif self.transcript_format == TranscriptFormat.SUBTITLES_SRT:
return [Document(page_content=transcript.export_subtitles_srt())]
yield Document(page_content=transcript.export_subtitles_srt())
elif self.transcript_format == TranscriptFormat.SUBTITLES_VTT:
return [Document(page_content=transcript.export_subtitles_vtt())]
yield Document(page_content=transcript.export_subtitles_vtt())
else:
raise ValueError("Unknown transcript format.")
@ -140,7 +134,7 @@ class AssemblyAIAudioLoaderById(BaseLoader):
self.transcript_id = transcript_id
self.transcript_format = transcript_format
def load(self) -> List[Document]:
def lazy_load(self) -> Iterator[Document]:
"""Load data into Document objects."""
HEADERS = {"authorization": self.api_key}
@ -157,9 +151,7 @@ class AssemblyAIAudioLoaderById(BaseLoader):
transcript = transcript_response.json()["text"]
return [
Document(page_content=transcript, metadata=transcript_response.json())
]
yield Document(page_content=transcript, metadata=transcript_response.json())
elif self.transcript_format == TranscriptFormat.PARAGRAPHS:
try:
paragraphs_response = requests.get(
@ -173,7 +165,8 @@ class AssemblyAIAudioLoaderById(BaseLoader):
paragraphs = paragraphs_response.json()["paragraphs"]
return [Document(page_content=p["text"], metadata=p) for p in paragraphs]
for p in paragraphs:
yield Document(page_content=p["text"], metadata=p)
elif self.transcript_format == TranscriptFormat.SENTENCES:
try:
@ -188,7 +181,8 @@ class AssemblyAIAudioLoaderById(BaseLoader):
sentences = sentences_response.json()["sentences"]
return [Document(page_content=s["text"], metadata=s) for s in sentences]
for s in sentences:
yield Document(page_content=s["text"], metadata=s)
elif self.transcript_format == TranscriptFormat.SUBTITLES_SRT:
try:
@ -203,7 +197,7 @@ class AssemblyAIAudioLoaderById(BaseLoader):
srt = srt_response.text
return [Document(page_content=srt)]
yield Document(page_content=srt)
elif self.transcript_format == TranscriptFormat.SUBTITLES_VTT:
try:
@ -218,6 +212,6 @@ class AssemblyAIAudioLoaderById(BaseLoader):
vtt = vtt_response.text
return [Document(page_content=vtt)]
yield Document(page_content=vtt)
else:
raise ValueError("Unknown transcript format.")

View File

@ -1,7 +1,12 @@
import pytest
import responses
from pytest_mock import MockerFixture
from requests import HTTPError
from langchain_community.document_loaders import AssemblyAIAudioTranscriptLoader
from langchain_community.document_loaders import (
AssemblyAIAudioLoaderById,
AssemblyAIAudioTranscriptLoader,
)
from langchain_community.document_loaders.assemblyai import TranscriptFormat
@ -46,3 +51,37 @@ def test_transcription_error(mocker: MockerFixture) -> None:
expected_error = "Could not transcribe file: Test error"
with pytest.raises(ValueError, match=expected_error):
loader.load()
@pytest.mark.requires("assemblyai")
@responses.activate
def test_load_by_id() -> None:
responses.add(
responses.GET,
"https://api.assemblyai.com/v2/transcript/1234",
json={"text": "Test transcription text", "id": "1234"},
status=200,
)
loader = AssemblyAIAudioLoaderById(
transcript_id="1234", api_key="api_key", transcript_format=TranscriptFormat.TEXT
)
docs = loader.load()
assert len(docs) == 1
assert docs[0].page_content == "Test transcription text"
assert docs[0].metadata == {"text": "Test transcription text", "id": "1234"}
@pytest.mark.requires("assemblyai")
@responses.activate
def test_transcription_error_by_id() -> None:
responses.add(
responses.GET,
"https://api.assemblyai.com/v2/transcript/1234",
status=404,
)
loader = AssemblyAIAudioLoaderById(
transcript_id="1234", api_key="api_key", transcript_format=TranscriptFormat.TEXT
)
with pytest.raises(HTTPError):
loader.load()

View File

@ -20,6 +20,7 @@ EXPECTED_ALL = [
"ApifyDatasetLoader",
"ArcGISLoader",
"ArxivLoader",
"AssemblyAIAudioLoaderById",
"AssemblyAIAudioTranscriptLoader",
"AstraDBLoader",
"AsyncHtmlLoader",