community[patch]: Refactor OpenAIWhisperParserLocal (#15150)

This PR addresses an issue in `OpenAIWhisperParserLocal` where requesting
CUDA on a machine without an available GPU raised an `AttributeError` (#15143).

Changes:

- Refactored logic for CUDA availability: initialization now checks whether
CUDA is actually available and falls back to the CPU when it is not, so the
parser works without manual intervention.
- Parameterized batch size and chunk length: `batch_size` and `chunk_length`
are now configurable parameters, offering greater flexibility and
optimization options based on the requirements of the use case (see the
usage sketch after this list).
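
A minimal usage sketch of the new parameters (the `./audio` directory, the
glob, and the loader wiring are illustrative, not part of this PR):

```python
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers.audio import OpenAIWhisperParserLocal

# device="0" requests a GPU; with this patch the parser falls back to CPU
# when CUDA is unavailable instead of raising an AttributeError.
parser = OpenAIWhisperParserLocal(device="0", batch_size=4, chunk_length=20)
loader = GenericLoader.from_filesystem("./audio", glob="*.mp3", parser=parser)
docs = loader.load()  # one transcribed Document per audio file
```

Smaller `batch_size` / `chunk_length` values trade throughput for lower peak
memory, which is the point of exposing them.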

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>

@@ -64,7 +64,7 @@ class OpenAIWhisperParser(BaseBlobParser):
                 file_obj.name = f"part_{split_number}.mp3"
 
             # Transcribe
-            print(f"Transcribing part {split_number+1}!")
+            print(f"Transcribing part {split_number + 1}!")
             attempts = 0
             while attempts < 3:
                 try:
@@ -116,6 +116,8 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
         self,
         device: str = "0",
         lang_model: Optional[str] = None,
+        batch_size: int = 8,
+        chunk_length: int = 30,
         forced_decoder_ids: Optional[Tuple[Dict]] = None,
     ):
         """Initialize the parser.
@@ -126,6 +128,10 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
                 Defaults to None.
             forced_decoder_ids: id states for decoder in a multilanguage model.
                 Defaults to None.
+            batch_size: batch size used for decoding
+                Defaults to 8.
+            chunk_length: chunk length used during inference.
+                Defaults to 30s.
         """
         try:
             from transformers import pipeline
@@ -141,47 +147,37 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
                 "torch package not found, please install it with " "`pip install torch`"
             )
 
-        # set device, cpu by default check if there is a GPU available
+        # Determine the device to use
         if device == "cpu":
             self.device = "cpu"
-            if lang_model is not None:
-                self.lang_model = lang_model
-                print("WARNING! Model override. Using model: ", self.lang_model)
-            else:
-                # unless overridden, use the small base model on cpu
-                self.lang_model = "openai/whisper-base"
         else:
-            if torch.cuda.is_available():
-                self.device = "cuda:0"
-                # check GPU memory and select automatically the model
-                mem = torch.cuda.get_device_properties(self.device).total_memory / (
-                    1024**2
-                )
-                if mem < 5000:
-                    rec_model = "openai/whisper-base"
-                elif mem < 7000:
-                    rec_model = "openai/whisper-small"
-                elif mem < 12000:
-                    rec_model = "openai/whisper-medium"
-                else:
-                    rec_model = "openai/whisper-large"
-
-                # check if model is overridden
-                if lang_model is not None:
-                    self.lang_model = lang_model
-                    print("WARNING! Model override. Might not fit in your GPU")
-                else:
-                    self.lang_model = rec_model
-            else:
-                "cpu"
+            self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+        if self.device == "cpu":
+            default_model = "openai/whisper-base"
+            self.lang_model = lang_model if lang_model else default_model
+        else:
+            # Set the language model based on the device and available memory
+            mem = torch.cuda.get_device_properties(self.device).total_memory / (1024**2)
+            if mem < 5000:
+                rec_model = "openai/whisper-base"
+            elif mem < 7000:
+                rec_model = "openai/whisper-small"
+            elif mem < 12000:
+                rec_model = "openai/whisper-medium"
+            else:
+                rec_model = "openai/whisper-large"
+            self.lang_model = lang_model if lang_model else rec_model
 
         print("Using the following model: ", self.lang_model)
 
+        self.batch_size = batch_size
+
         # load model for inference
         self.pipe = pipeline(
             "automatic-speech-recognition",
             model=self.lang_model,
-            chunk_length_s=30,
+            chunk_length_s=chunk_length,
             device=self.device,
         )
         if forced_decoder_ids is not None:
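
The selection logic from the hunk above, pulled out as a standalone sketch
for readability (the helper names `resolve_device` and `pick_model` are
illustrative, not part of the patch):

```python
from typing import Optional

import torch


def resolve_device(device: str) -> str:
    """Fall back to CPU whenever CUDA is requested but unavailable."""
    if device == "cpu":
        return "cpu"
    return "cuda:0" if torch.cuda.is_available() else "cpu"


def pick_model(device: str, lang_model: Optional[str]) -> str:
    """An explicit lang_model always wins; otherwise size the model to VRAM."""
    if device == "cpu":
        return lang_model or "openai/whisper-base"
    # Total GPU memory in MiB decides the recommended checkpoint.
    mem = torch.cuda.get_device_properties(device).total_memory / (1024**2)
    if mem < 5000:
        rec_model = "openai/whisper-base"
    elif mem < 7000:
        rec_model = "openai/whisper-small"
    elif mem < 12000:
        rec_model = "openai/whisper-medium"
    else:
        rec_model = "openai/whisper-large"
    return lang_model or rec_model
```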
@@ -224,7 +220,7 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
 
         y, sr = librosa.load(file_obj, sr=16000)
 
-        prediction = self.pipe(y.copy(), batch_size=8)["text"]
+        prediction = self.pipe(y.copy(), batch_size=self.batch_size)["text"]
 
         yield Document(
             page_content=prediction,
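
For context on the original bug: the old code's final `else:` branch
contained only the bare string expression `"cpu"`, a no-op, so on a
CUDA-less machine neither `self.device` nor `self.lang_model` was ever
assigned and `__init__` raised `AttributeError` (#15143). A quick check of
the fixed behavior, assuming a machine without a CUDA-capable GPU and with
`torch` and `transformers` installed:

```python
from langchain_community.document_loaders.parsers.audio import OpenAIWhisperParserLocal

# With the old code this path left self.device and self.lang_model unset,
# so construction raised AttributeError; the GPU request now degrades to CPU.
parser = OpenAIWhisperParserLocal(device="0")
print(parser.device)      # cpu
print(parser.lang_model)  # openai/whisper-base
```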