mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
Merge pull request #18421
* Implement lazy_load() for AssemblyAIAudioTranscriptLoader
This commit is contained in:
parent
bb284eebe4
commit
15b1770326
@ -32,6 +32,7 @@ from langchain_community.document_loaders.apify_dataset import ApifyDatasetLoade
|
||||
from langchain_community.document_loaders.arcgis_loader import ArcGISLoader
|
||||
from langchain_community.document_loaders.arxiv import ArxivLoader
|
||||
from langchain_community.document_loaders.assemblyai import (
|
||||
AssemblyAIAudioLoaderById,
|
||||
AssemblyAIAudioTranscriptLoader,
|
||||
)
|
||||
from langchain_community.document_loaders.astradb import AstraDBLoader
|
||||
@ -262,6 +263,7 @@ __all__ = [
|
||||
"ApifyDatasetLoader",
|
||||
"ArcGISLoader",
|
||||
"ArxivLoader",
|
||||
"AssemblyAIAudioLoaderById",
|
||||
"AssemblyAIAudioTranscriptLoader",
|
||||
"AstraDBLoader",
|
||||
"AsyncHtmlLoader",
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Iterator, Optional
|
||||
|
||||
import requests
|
||||
from langchain_core.documents import Document
|
||||
@ -75,7 +75,7 @@ class AssemblyAIAudioTranscriptLoader(BaseLoader):
|
||||
self.transcript_format = transcript_format
|
||||
self.transcriber = assemblyai.Transcriber(config=config)
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""Transcribes the audio file and loads the transcript into documents.
|
||||
|
||||
It uses the AssemblyAI API to transcribe the audio file and blocks until
|
||||
@ -88,27 +88,21 @@ class AssemblyAIAudioTranscriptLoader(BaseLoader):
|
||||
raise ValueError(f"Could not transcribe file: {transcript.error}")
|
||||
|
||||
if self.transcript_format == TranscriptFormat.TEXT:
|
||||
return [
|
||||
Document(
|
||||
page_content=transcript.text, metadata=transcript.json_response
|
||||
)
|
||||
]
|
||||
yield Document(
|
||||
page_content=transcript.text, metadata=transcript.json_response
|
||||
)
|
||||
elif self.transcript_format == TranscriptFormat.SENTENCES:
|
||||
sentences = transcript.get_sentences()
|
||||
return [
|
||||
Document(page_content=s.text, metadata=s.dict(exclude={"text"}))
|
||||
for s in sentences
|
||||
]
|
||||
for s in sentences:
|
||||
yield Document(page_content=s.text, metadata=s.dict(exclude={"text"}))
|
||||
elif self.transcript_format == TranscriptFormat.PARAGRAPHS:
|
||||
paragraphs = transcript.get_paragraphs()
|
||||
return [
|
||||
Document(page_content=p.text, metadata=p.dict(exclude={"text"}))
|
||||
for p in paragraphs
|
||||
]
|
||||
for p in paragraphs:
|
||||
yield Document(page_content=p.text, metadata=p.dict(exclude={"text"}))
|
||||
elif self.transcript_format == TranscriptFormat.SUBTITLES_SRT:
|
||||
return [Document(page_content=transcript.export_subtitles_srt())]
|
||||
yield Document(page_content=transcript.export_subtitles_srt())
|
||||
elif self.transcript_format == TranscriptFormat.SUBTITLES_VTT:
|
||||
return [Document(page_content=transcript.export_subtitles_vtt())]
|
||||
yield Document(page_content=transcript.export_subtitles_vtt())
|
||||
else:
|
||||
raise ValueError("Unknown transcript format.")
|
||||
|
||||
@ -140,7 +134,7 @@ class AssemblyAIAudioLoaderById(BaseLoader):
|
||||
self.transcript_id = transcript_id
|
||||
self.transcript_format = transcript_format
|
||||
|
||||
def load(self) -> List[Document]:
|
||||
def lazy_load(self) -> Iterator[Document]:
|
||||
"""Load data into Document objects."""
|
||||
HEADERS = {"authorization": self.api_key}
|
||||
|
||||
@ -157,9 +151,7 @@ class AssemblyAIAudioLoaderById(BaseLoader):
|
||||
|
||||
transcript = transcript_response.json()["text"]
|
||||
|
||||
return [
|
||||
Document(page_content=transcript, metadata=transcript_response.json())
|
||||
]
|
||||
yield Document(page_content=transcript, metadata=transcript_response.json())
|
||||
elif self.transcript_format == TranscriptFormat.PARAGRAPHS:
|
||||
try:
|
||||
paragraphs_response = requests.get(
|
||||
@ -173,7 +165,8 @@ class AssemblyAIAudioLoaderById(BaseLoader):
|
||||
|
||||
paragraphs = paragraphs_response.json()["paragraphs"]
|
||||
|
||||
return [Document(page_content=p["text"], metadata=p) for p in paragraphs]
|
||||
for p in paragraphs:
|
||||
yield Document(page_content=p["text"], metadata=p)
|
||||
|
||||
elif self.transcript_format == TranscriptFormat.SENTENCES:
|
||||
try:
|
||||
@ -188,7 +181,8 @@ class AssemblyAIAudioLoaderById(BaseLoader):
|
||||
|
||||
sentences = sentences_response.json()["sentences"]
|
||||
|
||||
return [Document(page_content=s["text"], metadata=s) for s in sentences]
|
||||
for s in sentences:
|
||||
yield Document(page_content=s["text"], metadata=s)
|
||||
|
||||
elif self.transcript_format == TranscriptFormat.SUBTITLES_SRT:
|
||||
try:
|
||||
@ -203,7 +197,7 @@ class AssemblyAIAudioLoaderById(BaseLoader):
|
||||
|
||||
srt = srt_response.text
|
||||
|
||||
return [Document(page_content=srt)]
|
||||
yield Document(page_content=srt)
|
||||
|
||||
elif self.transcript_format == TranscriptFormat.SUBTITLES_VTT:
|
||||
try:
|
||||
@ -218,6 +212,6 @@ class AssemblyAIAudioLoaderById(BaseLoader):
|
||||
|
||||
vtt = vtt_response.text
|
||||
|
||||
return [Document(page_content=vtt)]
|
||||
yield Document(page_content=vtt)
|
||||
else:
|
||||
raise ValueError("Unknown transcript format.")
|
||||
|
@ -1,7 +1,12 @@
|
||||
import pytest
|
||||
import responses
|
||||
from pytest_mock import MockerFixture
|
||||
from requests import HTTPError
|
||||
|
||||
from langchain_community.document_loaders import AssemblyAIAudioTranscriptLoader
|
||||
from langchain_community.document_loaders import (
|
||||
AssemblyAIAudioLoaderById,
|
||||
AssemblyAIAudioTranscriptLoader,
|
||||
)
|
||||
from langchain_community.document_loaders.assemblyai import TranscriptFormat
|
||||
|
||||
|
||||
@ -46,3 +51,37 @@ def test_transcription_error(mocker: MockerFixture) -> None:
|
||||
expected_error = "Could not transcribe file: Test error"
|
||||
with pytest.raises(ValueError, match=expected_error):
|
||||
loader.load()
|
||||
|
||||
|
||||
@pytest.mark.requires("assemblyai")
|
||||
@responses.activate
|
||||
def test_load_by_id() -> None:
|
||||
responses.add(
|
||||
responses.GET,
|
||||
"https://api.assemblyai.com/v2/transcript/1234",
|
||||
json={"text": "Test transcription text", "id": "1234"},
|
||||
status=200,
|
||||
)
|
||||
|
||||
loader = AssemblyAIAudioLoaderById(
|
||||
transcript_id="1234", api_key="api_key", transcript_format=TranscriptFormat.TEXT
|
||||
)
|
||||
docs = loader.load()
|
||||
assert len(docs) == 1
|
||||
assert docs[0].page_content == "Test transcription text"
|
||||
assert docs[0].metadata == {"text": "Test transcription text", "id": "1234"}
|
||||
|
||||
|
||||
@pytest.mark.requires("assemblyai")
|
||||
@responses.activate
|
||||
def test_transcription_error_by_id() -> None:
|
||||
responses.add(
|
||||
responses.GET,
|
||||
"https://api.assemblyai.com/v2/transcript/1234",
|
||||
status=404,
|
||||
)
|
||||
loader = AssemblyAIAudioLoaderById(
|
||||
transcript_id="1234", api_key="api_key", transcript_format=TranscriptFormat.TEXT
|
||||
)
|
||||
with pytest.raises(HTTPError):
|
||||
loader.load()
|
||||
|
@ -20,6 +20,7 @@ EXPECTED_ALL = [
|
||||
"ApifyDatasetLoader",
|
||||
"ArcGISLoader",
|
||||
"ArxivLoader",
|
||||
"AssemblyAIAudioLoaderById",
|
||||
"AssemblyAIAudioTranscriptLoader",
|
||||
"AstraDBLoader",
|
||||
"AsyncHtmlLoader",
|
||||
|
Loading…
Reference in New Issue
Block a user