From d19664681121def24aea6dce1ce9a7c8f7c1c466 Mon Sep 17 00:00:00 2001
From: andrijdavid
Date: Mon, 15 Jan 2024 21:29:14 +0100
Subject: [PATCH] community[patch]: Refactor OpenAIWhisperParserLocal (#15150)

This PR addresses an issue in OpenAIWhisperParserLocal where requesting
CUDA when it is not available leads to an AttributeError (#15143).

Changes:

- Refactored logic for CUDA availability: initialization now checks
  whether CUDA is available and falls back to the CPU when it is not,
  so the parser works without manual intervention.
- Parameterized batch size and chunk length: batch_size and
  chunk_length are now configurable parameters, allowing decoding to be
  tuned to the requirements of a given use case.
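As a quick illustration (not part of this patch; the audio path below
is a hypothetical placeholder), the refactored parser can be used like
so:

    from langchain_community.document_loaders.blob_loaders import Blob
    from langchain_community.document_loaders.parsers.audio import (
        OpenAIWhisperParserLocal,
    )

    # Requesting a GPU now degrades gracefully: if CUDA is unavailable,
    # the parser falls back to CPU instead of raising AttributeError.
    parser = OpenAIWhisperParserLocal(
        device="0",       # any value other than "cpu" requests CUDA
        batch_size=4,     # new: decoding batch size, defaults to 8
        chunk_length=20,  # new: chunk length in seconds, defaults to 30
    )

    docs = list(parser.lazy_parse(Blob.from_path("speech.mp3")))
    print(docs[0].page_content)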
---------

Co-authored-by: Harrison Chase
---
 .../document_loaders/parsers/audio.py         | 58 +++++++++----------
 1 file changed, 27 insertions(+), 31 deletions(-)

diff --git a/libs/community/langchain_community/document_loaders/parsers/audio.py b/libs/community/langchain_community/document_loaders/parsers/audio.py
index ab54c67ed37..3b96f9860c5 100644
--- a/libs/community/langchain_community/document_loaders/parsers/audio.py
+++ b/libs/community/langchain_community/document_loaders/parsers/audio.py
@@ -64,7 +64,7 @@ class OpenAIWhisperParser(BaseBlobParser):
                 file_obj.name = f"part_{split_number}.mp3"
 
             # Transcribe
-            print(f"Transcribing part {split_number+1}!")
+            print(f"Transcribing part {split_number + 1}!")
             attempts = 0
             while attempts < 3:
                 try:
@@ -116,6 +116,8 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
         self,
         device: str = "0",
         lang_model: Optional[str] = None,
+        batch_size: int = 8,
+        chunk_length: int = 30,
         forced_decoder_ids: Optional[Tuple[Dict]] = None,
     ):
         """Initialize the parser.
@@ -126,6 +128,10 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
                 Defaults to None.
             forced_decoder_ids: id states for decoder in a multilanguage model.
                 Defaults to None.
+            batch_size: batch size used for decoding
+                Defaults to 8.
+            chunk_length: chunk length used during inference.
+                Defaults to 30s.
         """
         try:
             from transformers import pipeline
@@ -141,47 +147,37 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
                 "torch package not found, please install it with "
                 "`pip install torch`"
             )
 
-        # set device, cpu by default check if there is a GPU available
+        # Determine the device to use
         if device == "cpu":
             self.device = "cpu"
-            if lang_model is not None:
-                self.lang_model = lang_model
-                print("WARNING! Model override. Using model: ", self.lang_model)
-            else:
-                # unless overridden, use the small base model on cpu
-                self.lang_model = "openai/whisper-base"
         else:
-            if torch.cuda.is_available():
-                self.device = "cuda:0"
-                # check GPU memory and select automatically the model
-                mem = torch.cuda.get_device_properties(self.device).total_memory / (
-                    1024**2
-                )
-                if mem < 5000:
-                    rec_model = "openai/whisper-base"
-                elif mem < 7000:
-                    rec_model = "openai/whisper-small"
-                elif mem < 12000:
-                    rec_model = "openai/whisper-medium"
-                else:
-                    rec_model = "openai/whisper-large"
+            self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
-                # check if model is overridden
-                if lang_model is not None:
-                    self.lang_model = lang_model
-                    print("WARNING! Model override. Might not fit in your GPU")
-                else:
-                    self.lang_model = rec_model
+        if self.device == "cpu":
+            default_model = "openai/whisper-base"
+            self.lang_model = lang_model if lang_model else default_model
+        else:
+            # Set the language model based on the device and available memory
+            mem = torch.cuda.get_device_properties(self.device).total_memory / (1024**2)
+            if mem < 5000:
+                rec_model = "openai/whisper-base"
+            elif mem < 7000:
+                rec_model = "openai/whisper-small"
+            elif mem < 12000:
+                rec_model = "openai/whisper-medium"
             else:
-                "cpu"
+                rec_model = "openai/whisper-large"
+            self.lang_model = lang_model if lang_model else rec_model
 
         print("Using the following model: ", self.lang_model)
 
+        self.batch_size = batch_size
+
         # load model for inference
         self.pipe = pipeline(
             "automatic-speech-recognition",
             model=self.lang_model,
-            chunk_length_s=30,
+            chunk_length_s=chunk_length,
             device=self.device,
         )
         if forced_decoder_ids is not None:
@@ -224,7 +220,7 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
 
         y, sr = librosa.load(file_obj, sr=16000)
 
-        prediction = self.pipe(y.copy(), batch_size=8)["text"]
+        prediction = self.pipe(y.copy(), batch_size=self.batch_size)["text"]
 
         yield Document(
             page_content=prediction,