mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-08 06:23:20 +00:00
community: Add arguments to whisper parser (#20378)
**Description:** Added four optional arguments (`language`, `prompt`, `response_format`, `temperature`) to the Whisper parser; any that are set are forwarded to the underlying OpenAI v1 transcription call. The prompt is especially important for steering transcription output. --------- Co-authored-by: Roi Perlman <roi@fivesigmalabs.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Dict, Iterator, Optional, Tuple
|
||||
from typing import Any, Dict, Iterator, Literal, Optional, Tuple, Union
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
@@ -31,12 +31,32 @@ class OpenAIWhisperParser(BaseBlobParser):
|
||||
*,
|
||||
chunk_duration_threshold: float = 0.1,
|
||||
base_url: Optional[str] = None,
|
||||
language: Union[str, None] = None,
|
||||
prompt: Union[str, None] = None,
|
||||
response_format: Union[
|
||||
Literal["json", "text", "srt", "verbose_json", "vtt"], None
|
||||
] = None,
|
||||
temperature: Union[float, None] = None,
|
||||
):
|
||||
self.api_key = api_key
|
||||
self.chunk_duration_threshold = chunk_duration_threshold
|
||||
self.base_url = (
|
||||
base_url if base_url is not None else os.environ.get("OPENAI_API_BASE")
|
||||
)
|
||||
self.language = language
|
||||
self.prompt = prompt
|
||||
self.response_format = response_format
|
||||
self.temperature = temperature
|
||||
|
||||
@property
def _create_params(self) -> Dict[str, Any]:
    """Return the optional transcription arguments that were explicitly set.

    Reads ``language``, ``prompt``, ``response_format`` and ``temperature``
    from the instance and omits every entry whose value is ``None``, so the
    result can be unpacked as keyword arguments containing only the options
    that were actually provided.

    Returns:
        A dict mapping option name to value for each option that is not
        ``None``.
    """
    candidates = {
        "language": self.language,
        "prompt": self.prompt,
        "response_format": self.response_format,
        "temperature": self.temperature,
    }
    # Compare against None rather than relying on truthiness so that
    # falsy-but-valid values (e.g. a temperature of 0) are still included.
    return {name: value for name, value in candidates.items() if value is not None}
|
||||
|
||||
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
|
||||
"""Lazily parse the blob."""
|
||||
@@ -95,7 +115,7 @@ class OpenAIWhisperParser(BaseBlobParser):
|
||||
try:
|
||||
if is_openai_v1():
|
||||
transcript = client.audio.transcriptions.create(
|
||||
model="whisper-1", file=file_obj
|
||||
model="whisper-1", file=file_obj, **self._create_params
|
||||
)
|
||||
else:
|
||||
transcript = openai.Audio.transcribe("whisper-1", file_obj)
|
||||
|
Reference in New Issue
Block a user