community[patch]: Refactor OpenAIWhisperParserLocal (#15150)
This PR addresses an issue in OpenAIWhisperParserLocal where requesting CUDA on a machine without it raises an AttributeError (#15143).

Changes:
- Refactored the CUDA-availability logic: initialization now checks whether CUDA is available and falls back to the CPU when it is not, so the parser runs without manual intervention.
- Parameterized batch size and chunk length: batch_size and chunk_length are now constructor arguments, giving callers room to tune decoding for their use case.

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
parent 5cf06db3b3
commit d196646811
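For reference, a minimal construction sketch against the new signature in this diff (a hedged example, not part of the commit; the module path assumes the langchain_community layout this patch targets):

from langchain_community.document_loaders.parsers.audio import (
    OpenAIWhisperParserLocal,
)

# Requesting a GPU on a CUDA-less machine no longer raises AttributeError;
# the parser falls back to CPU and defaults to "openai/whisper-base".
parser = OpenAIWhisperParserLocal(
    device="0",       # GPU request; silently falls back to "cpu" without CUDA
    lang_model=None,  # None lets the parser pick a model by available GPU memory
    batch_size=8,     # previously hard-coded to 8 in lazy_parse
    chunk_length=30,  # previously hard-coded as chunk_length_s=30
)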
@@ -64,7 +64,7 @@ class OpenAIWhisperParser(BaseBlobParser):
                 file_obj.name = f"part_{split_number}.mp3"
 
             # Transcribe
-            print(f"Transcribing part {split_number+1}!")
+            print(f"Transcribing part {split_number + 1}!")
             attempts = 0
             while attempts < 3:
                 try:
@@ -116,6 +116,8 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
         self,
         device: str = "0",
         lang_model: Optional[str] = None,
+        batch_size: int = 8,
+        chunk_length: int = 30,
         forced_decoder_ids: Optional[Tuple[Dict]] = None,
     ):
         """Initialize the parser.
@@ -126,6 +128,10 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
               Defaults to None.
             forced_decoder_ids: id states for decoder in a multilanguage model.
               Defaults to None.
+            batch_size: batch size used for decoding
+              Defaults to 8.
+            chunk_length: chunk length used during inference.
+              Defaults to 30s.
         """
         try:
             from transformers import pipeline
@@ -141,47 +147,37 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
                 "torch package not found, please install it with " "`pip install torch`"
             )
 
-        # set device, cpu by default check if there is a GPU available
+        # Determine the device to use
         if device == "cpu":
             self.device = "cpu"
-            if lang_model is not None:
-                self.lang_model = lang_model
-                print("WARNING! Model override. Using model: ", self.lang_model)
-            else:
-                # unless overridden, use the small base model on cpu
-                self.lang_model = "openai/whisper-base"
         else:
-            if torch.cuda.is_available():
-                self.device = "cuda:0"
-                # check GPU memory and select automatically the model
-                mem = torch.cuda.get_device_properties(self.device).total_memory / (
-                    1024**2
-                )
-                if mem < 5000:
-                    rec_model = "openai/whisper-base"
-                elif mem < 7000:
-                    rec_model = "openai/whisper-small"
-                elif mem < 12000:
-                    rec_model = "openai/whisper-medium"
-                else:
-                    rec_model = "openai/whisper-large"
-
-                # check if model is overridden
-                if lang_model is not None:
-                    self.lang_model = lang_model
-                    print("WARNING! Model override. Might not fit in your GPU")
-                else:
-                    self.lang_model = rec_model
-            else:
-                "cpu"
+            self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+        if self.device == "cpu":
+            default_model = "openai/whisper-base"
+            self.lang_model = lang_model if lang_model else default_model
+        else:
+            # Set the language model based on the device and available memory
+            mem = torch.cuda.get_device_properties(self.device).total_memory / (1024**2)
+            if mem < 5000:
+                rec_model = "openai/whisper-base"
+            elif mem < 7000:
+                rec_model = "openai/whisper-small"
+            elif mem < 12000:
+                rec_model = "openai/whisper-medium"
+            else:
+                rec_model = "openai/whisper-large"
+            self.lang_model = lang_model if lang_model else rec_model
 
         print("Using the following model: ", self.lang_model)
 
+        self.batch_size = batch_size
+
         # load model for inference
         self.pipe = pipeline(
             "automatic-speech-recognition",
             model=self.lang_model,
-            chunk_length_s=30,
+            chunk_length_s=chunk_length,
             device=self.device,
         )
         if forced_decoder_ids is not None:
@@ -224,7 +220,7 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
 
         y, sr = librosa.load(file_obj, sr=16000)
 
-        prediction = self.pipe(y.copy(), batch_size=8)["text"]
+        prediction = self.pipe(y.copy(), batch_size=self.batch_size)["text"]
 
         yield Document(
             page_content=prediction,
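With that change, the configurable batch_size reaches the pipeline call shown above. A hedged end-to-end sketch (the audio path is hypothetical; Blob.from_path is the standard helper for wrapping a file on disk):

from langchain_community.document_loaders.blob_loaders import Blob
from langchain_community.document_loaders.parsers.audio import (
    OpenAIWhisperParserLocal,
)

parser = OpenAIWhisperParserLocal(batch_size=4, chunk_length=20)

# lazy_parse yields one Document whose page_content is the transcript.
blob = Blob.from_path("meeting_recording.mp3")  # hypothetical file
for doc in parser.lazy_parse(blob):
    print(doc.page_content)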