mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-04 18:53:02 +00:00
Add new parameter forced_decoder_ids to OpenAIWhisperParserLocal + small bug fix (#8793)
- Description: new parameter forced_decoder_ids for OpenAIWhisperParserLocal to force input language, and enable optional translate mode. Usage example: processor = WhisperProcessor.from_pretrained("openai/whisper-medium") forced_decoder_ids = processor.get_decoder_prompt_ids(language="french", task="transcribe") #forced_decoder_ids = processor.get_decoder_prompt_ids(language="french", task="translate") loader = GenericLoader(YoutubeAudioLoader(urls, save_dir), OpenAIWhisperParserLocal(lang_model="openai/whisper-medium",forced_decoder_ids=forced_decoder_ids)) - Issue #8792 - Tag maintainer: @rlancemartin, @eyurtsev --------- Co-authored-by: idcore <eugene.novozhilov@gmail.com>
This commit is contained in:
parent
40079d4936
commit
fe78aff1f2
@ -1,10 +1,13 @@
|
|||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Iterator, Optional
|
from typing import Dict, Iterator, Optional, Tuple
|
||||||
|
|
||||||
from langchain.document_loaders.base import BaseBlobParser
|
from langchain.document_loaders.base import BaseBlobParser
|
||||||
from langchain.document_loaders.blob_loaders import Blob
|
from langchain.document_loaders.blob_loaders import Blob
|
||||||
from langchain.schema import Document
|
from langchain.schema import Document
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIWhisperParser(BaseBlobParser):
|
class OpenAIWhisperParser(BaseBlobParser):
|
||||||
"""Transcribe and parse audio files.
|
"""Transcribe and parse audio files.
|
||||||
@ -77,12 +80,31 @@ class OpenAIWhisperParser(BaseBlobParser):
|
|||||||
|
|
||||||
class OpenAIWhisperParserLocal(BaseBlobParser):
|
class OpenAIWhisperParserLocal(BaseBlobParser):
|
||||||
"""Transcribe and parse audio files.
|
"""Transcribe and parse audio files.
|
||||||
Audio transcription is with OpenAI Whisper model locally from transformers
|
Audio transcription with OpenAI Whisper model locally from transformers
|
||||||
NOTE: By default uses the gpu if available, if you want to use cpu,
|
Parameters:
|
||||||
please set device = "cpu"
|
device - device to use
|
||||||
|
NOTE: By default uses the gpu if available,
|
||||||
|
if you want to use cpu, please set device = "cpu"
|
||||||
|
lang_model - whisper model to use, for example "openai/whisper-medium"
|
||||||
|
forced_decoder_ids - id states for decoder in multilanguage model,
|
||||||
|
usage example:
|
||||||
|
from transformers import WhisperProcessor
|
||||||
|
processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
|
||||||
|
forced_decoder_ids = WhisperProcessor.get_decoder_prompt_ids(language="french",
|
||||||
|
task="transcribe")
|
||||||
|
forced_decoder_ids = WhisperProcessor.get_decoder_prompt_ids(language="french",
|
||||||
|
task="translate")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, device: str = "0", lang_model: Optional[str] = None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
device: str = "0",
|
||||||
|
lang_model: Optional[str] = None,
|
||||||
|
forced_decoder_ids: Optional[Tuple[Dict]] = None,
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -136,10 +158,19 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
|
|||||||
# load model for inference
|
# load model for inference
|
||||||
self.pipe = pipeline(
|
self.pipe = pipeline(
|
||||||
"automatic-speech-recognition",
|
"automatic-speech-recognition",
|
||||||
model="openai/whisper-medium",
|
model=self.lang_model,
|
||||||
chunk_length_s=30,
|
chunk_length_s=30,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
if forced_decoder_ids is not None:
|
||||||
|
try:
|
||||||
|
self.pipe.model.config.forced_decoder_ids = forced_decoder_ids
|
||||||
|
except Exception as exception_text:
|
||||||
|
logger.info(
|
||||||
|
"Unable to set forced_decoder_ids parameter for whisper model"
|
||||||
|
f"Text of exception: {exception_text}"
|
||||||
|
"Therefore whisper model will use default mode for decoder"
|
||||||
|
)
|
||||||
|
|
||||||
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
|
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
|
||||||
"""Lazily parse the blob."""
|
"""Lazily parse the blob."""
|
||||||
|
Loading…
Reference in New Issue
Block a user