mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 15:03:21 +00:00
community: Add arguments to whisper parser (#20378)
**Description:** Added a few additional arguments to the whisper parser, which can be consumed by the underlying API. The prompt is especially important to fine-tune transcriptions. --------- Co-authored-by: Roi Perlman <roi@fivesigmalabs.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com> Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Dict, Iterator, Optional, Tuple
|
from typing import Any, Dict, Iterator, Literal, Optional, Tuple, Union
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
@@ -31,12 +31,32 @@ class OpenAIWhisperParser(BaseBlobParser):
|
|||||||
*,
|
*,
|
||||||
chunk_duration_threshold: float = 0.1,
|
chunk_duration_threshold: float = 0.1,
|
||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
|
language: Union[str, None] = None,
|
||||||
|
prompt: Union[str, None] = None,
|
||||||
|
response_format: Union[
|
||||||
|
Literal["json", "text", "srt", "verbose_json", "vtt"], None
|
||||||
|
] = None,
|
||||||
|
temperature: Union[float, None] = None,
|
||||||
):
|
):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.chunk_duration_threshold = chunk_duration_threshold
|
self.chunk_duration_threshold = chunk_duration_threshold
|
||||||
self.base_url = (
|
self.base_url = (
|
||||||
base_url if base_url is not None else os.environ.get("OPENAI_API_BASE")
|
base_url if base_url is not None else os.environ.get("OPENAI_API_BASE")
|
||||||
)
|
)
|
||||||
|
self.language = language
|
||||||
|
self.prompt = prompt
|
||||||
|
self.response_format = response_format
|
||||||
|
self.temperature = temperature
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _create_params(self) -> Dict[str, Any]:
|
||||||
|
params = {
|
||||||
|
"language": self.language,
|
||||||
|
"prompt": self.prompt,
|
||||||
|
"response_format": self.response_format,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
}
|
||||||
|
return {k: v for k, v in params.items() if v is not None}
|
||||||
|
|
||||||
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
|
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
|
||||||
"""Lazily parse the blob."""
|
"""Lazily parse the blob."""
|
||||||
@@ -95,7 +115,7 @@ class OpenAIWhisperParser(BaseBlobParser):
|
|||||||
try:
|
try:
|
||||||
if is_openai_v1():
|
if is_openai_v1():
|
||||||
transcript = client.audio.transcriptions.create(
|
transcript = client.audio.transcriptions.create(
|
||||||
model="whisper-1", file=file_obj
|
model="whisper-1", file=file_obj, **self._create_params
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
transcript = openai.Audio.transcribe("whisper-1", file_obj)
|
transcript = openai.Audio.transcribe("whisper-1", file_obj)
|
||||||
|
Reference in New Issue
Block a user