mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-04 02:33:05 +00:00
Add new parameter forced_decoder_ids to OpenAIWhisperParserLocal + small bug fix (#8793)
- Description: new parameter forced_decoder_ids for OpenAIWhisperParserLocal to force input language, and enable optional translate mode. Usage example: processor = WhisperProcessor.from_pretrained("openai/whisper-medium") forced_decoder_ids = processor.get_decoder_prompt_ids(language="french", task="transcribe") #forced_decoder_ids = processor.get_decoder_prompt_ids(language="french", task="translate") loader = GenericLoader(YoutubeAudioLoader(urls, save_dir), OpenAIWhisperParserLocal(lang_model="openai/whisper-medium",forced_decoder_ids=forced_decoder_ids)) - Issue #8792 - Tag maintainer: @rlancemartin, @eyurtsev --------- Co-authored-by: idcore <eugene.novozhilov@gmail.com>
This commit is contained in:
parent
40079d4936
commit
fe78aff1f2
@ -1,10 +1,13 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Iterator, Optional
|
||||
from typing import Dict, Iterator, Optional, Tuple
|
||||
|
||||
from langchain.document_loaders.base import BaseBlobParser
|
||||
from langchain.document_loaders.blob_loaders import Blob
|
||||
from langchain.schema import Document
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIWhisperParser(BaseBlobParser):
|
||||
"""Transcribe and parse audio files.
|
||||
@ -77,12 +80,31 @@ class OpenAIWhisperParser(BaseBlobParser):
|
||||
|
||||
class OpenAIWhisperParserLocal(BaseBlobParser):
|
||||
"""Transcribe and parse audio files.
|
||||
Audio transcription is with OpenAI Whisper model locally from transformers
|
||||
NOTE: By default uses the gpu if available, if you want to use cpu,
|
||||
please set device = "cpu"
|
||||
Audio transcription with OpenAI Whisper model locally from transformers
|
||||
Parameters:
|
||||
device - device to use
|
||||
NOTE: By default uses the gpu if available,
|
||||
if you want to use cpu, please set device = "cpu"
|
||||
lang_model - whisper model to use, for example "openai/whisper-medium"
|
||||
forced_decoder_ids - id states for decoder in multilanguage model,
|
||||
usage example:
|
||||
from transformers import WhisperProcessor
|
||||
processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
|
||||
forced_decoder_ids = WhisperProcessor.get_decoder_prompt_ids(language="french",
|
||||
task="transcribe")
|
||||
forced_decoder_ids = WhisperProcessor.get_decoder_prompt_ids(language="french",
|
||||
task="translate")
|
||||
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, device: str = "0", lang_model: Optional[str] = None):
|
||||
def __init__(
|
||||
self,
|
||||
device: str = "0",
|
||||
lang_model: Optional[str] = None,
|
||||
forced_decoder_ids: Optional[Tuple[Dict]] = None,
|
||||
):
|
||||
try:
|
||||
from transformers import pipeline
|
||||
except ImportError:
|
||||
@ -136,10 +158,19 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
|
||||
# load model for inference
|
||||
self.pipe = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model="openai/whisper-medium",
|
||||
model=self.lang_model,
|
||||
chunk_length_s=30,
|
||||
device=self.device,
|
||||
)
|
||||
if forced_decoder_ids is not None:
|
||||
try:
|
||||
self.pipe.model.config.forced_decoder_ids = forced_decoder_ids
|
||||
except Exception as exception_text:
|
||||
logger.info(
|
||||
"Unable to set forced_decoder_ids parameter for whisper model"
|
||||
f"Text of exception: {exception_text}"
|
||||
"Therefore whisper model will use default mode for decoder"
|
||||
)
|
||||
|
||||
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
|
||||
"""Lazily parse the blob."""
|
||||
|
Loading…
Reference in New Issue
Block a user