mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-17 23:41:46 +00:00
community: Added new Utility runnables for NVIDIA Riva. (#15966)
**Please tag this issue with `nvidia_genai`** - **Description:** Added new Runnables for integration NVIDIA Riva into LCEL chains for Automatic Speech Recognition (ASR) and Text To Speech (TTS). - **Issue:** N/A - **Dependencies:** To use these runnables, the NVIDIA Riva client libraries are required. It they are not installed, an error will be raised instructing how to install them. The Runnables can be safely imported without the riva client libraries. - **Twitter handle:** N/A All of the Riva Runnables are inside a single folder in the Utilities module. In this folder are four files: - common.py - Contains all code that is common to both TTS and ASR - stream.py - Contains a class representing an audio stream that allows the end user to put data into the stream like a queue. - asr.py - Contains the RivaASR runnable - tts.py - Contains the RivaTTS runnable The following Python function is an example of creating a chain that makes use of both of these Runnables: ```python def create( config: Configuration, audio_encoding: RivaAudioEncoding, sample_rate: int, audio_channels: int = 1, ) -> Runnable[ASRInputType, TTSOutputType]: """Create a new instance of the chain.""" _LOGGER.info("Instantiating the chain.") # create the riva asr client riva_asr = RivaASR( url=str(config.riva_asr.service.url), ssl_cert=config.riva_asr.service.ssl_cert, encoding=audio_encoding, audio_channel_count=audio_channels, sample_rate_hertz=sample_rate, profanity_filter=config.riva_asr.profanity_filter, enable_automatic_punctuation=config.riva_asr.enable_automatic_punctuation, language_code=config.riva_asr.language_code, ) # create the prompt template prompt = PromptTemplate.from_template("{user_input}") # model = ChatOpenAI() model = ChatNVIDIA(model="mixtral_8x7b") # type: ignore # create the riva tts client riva_tts = RivaTTS( url=str(config.riva_asr.service.url), ssl_cert=config.riva_asr.service.ssl_cert, output_directory=config.riva_tts.output_directory, language_code=config.riva_tts.language_code, voice_name=config.riva_tts.voice_name, ) # construct and return the chain return {"user_input": riva_asr} | prompt | model | riva_tts # type: ignore ``` The following code is an example of creating a new audio stream for Riva: ```python input_stream = AudioStream(maxsize=1000) # Send bytes into the stream for chunk in audio_chunks: await input_stream.aput(chunk) input_stream.close() ``` The following code is an example of how to execute the chain with RivaASR and RivaTTS ```python output_stream = asyncio.Queue() while not input_stream.complete: async for chunk in chain.astream(input_stream): output_stream.put(chunk) ``` Everything should be async safe and thread safe. Audio data can be put into the input stream while the chain is running without interruptions. --------- Co-authored-by: Hayden Wolff <hwolff@nvidia.com> Co-authored-by: Hayden Wolff <hwolff@Haydens-Laptop.local> Co-authored-by: Hayden Wolff <haydenwolff99@gmail.com> Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
@@ -24,6 +24,9 @@ EXPECTED_ALL = [
|
||||
"MaxComputeAPIWrapper",
|
||||
"MetaphorSearchAPIWrapper",
|
||||
"NasaAPIWrapper",
|
||||
"NVIDIARivaASR",
|
||||
"NVIDIARivaTTS",
|
||||
"NVIDIARivaStream",
|
||||
"OpenWeatherMapAPIWrapper",
|
||||
"OutlineAPIWrapper",
|
||||
"Portkey",
|
||||
|
@@ -0,0 +1,177 @@
|
||||
"""Unit tests to verify function of the Riva ASR implementation."""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_community.utilities.nvidia_riva import (
|
||||
AudioStream,
|
||||
RivaASR,
|
||||
RivaAudioEncoding,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import riva.client
|
||||
import riva.client.proto.riva_asr_pb2 as rasr
|
||||
|
||||
AUDIO_DATA_MOCK = [
|
||||
b"This",
|
||||
b"is",
|
||||
b"a",
|
||||
b"test.",
|
||||
b"_",
|
||||
b"Hello.",
|
||||
b"World",
|
||||
]
|
||||
AUDIO_TEXT_MOCK = b" ".join(AUDIO_DATA_MOCK).decode().strip().split("_")
|
||||
|
||||
SVC_URI = "not-a-url.asdf:9999"
|
||||
SVC_USE_SSL = True
|
||||
CONFIG = {
|
||||
"audio_channel_count": 9,
|
||||
"profanity_filter": False,
|
||||
"enable_automatic_punctuation": False,
|
||||
"url": f"{'https' if SVC_USE_SSL else 'http'}://{SVC_URI}",
|
||||
"ssl_cert": "/dev/null",
|
||||
"encoding": RivaAudioEncoding.ALAW,
|
||||
"language_code": "not-a-language",
|
||||
"sample_rate_hertz": 5,
|
||||
}
|
||||
|
||||
|
||||
def response_generator(
|
||||
transcript: str = "",
|
||||
empty: bool = False,
|
||||
final: bool = False,
|
||||
alternatives: bool = True,
|
||||
) -> "rasr.StreamingRecognizeResponse":
|
||||
"""Create a pseudo streaming response."""
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
import riva.client.proto.riva_asr_pb2 as rasr
|
||||
|
||||
if empty:
|
||||
return rasr.StreamingRecognizeResponse()
|
||||
|
||||
if not alternatives:
|
||||
return rasr.StreamingRecognizeResponse(
|
||||
results=[
|
||||
rasr.StreamingRecognitionResult(
|
||||
is_final=final,
|
||||
alternatives=[],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return rasr.StreamingRecognizeResponse(
|
||||
results=[
|
||||
rasr.StreamingRecognitionResult(
|
||||
is_final=final,
|
||||
alternatives=[
|
||||
rasr.SpeechRecognitionAlternative(transcript=transcript.strip())
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def streaming_recognize_mock(
|
||||
generator: Generator["rasr.StreamingRecognizeRequest", None, None], **_: Any
|
||||
) -> Generator["rasr.StreamingRecognizeResponse", None, None]:
|
||||
"""A mock function to fake a streaming call to Riva."""
|
||||
yield response_generator(empty=True)
|
||||
yield response_generator(alternatives=False)
|
||||
|
||||
output_transcript = ""
|
||||
for streaming_requests in generator:
|
||||
input_bytes = streaming_requests.audio_content.decode()
|
||||
|
||||
final = input_bytes == "_"
|
||||
if final:
|
||||
input_bytes = ""
|
||||
|
||||
output_transcript += input_bytes + " "
|
||||
|
||||
yield response_generator(final=final, transcript=output_transcript)
|
||||
if final:
|
||||
output_transcript = ""
|
||||
|
||||
yield response_generator(final=True, transcript=output_transcript)
|
||||
|
||||
|
||||
def riva_asr_stub_init_patch(
|
||||
self: "riva.client.proto.riva_asr_pb2_grpc.RivaSpeechRecognitionStub", _: Any
|
||||
) -> None:
|
||||
"""Patch for the Riva asr library."""
|
||||
self.StreamingRecognize = streaming_recognize_mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def asr() -> RivaASR:
|
||||
"""Initialize a copy of the runnable."""
|
||||
return RivaASR(**CONFIG)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def stream() -> AudioStream:
|
||||
"""Initialize and populate a sample audio stream."""
|
||||
s = AudioStream()
|
||||
for val in AUDIO_DATA_MOCK:
|
||||
s.put(val)
|
||||
s.close()
|
||||
return s
|
||||
|
||||
|
||||
@pytest.mark.requires("riva.client")
|
||||
def test_init(asr: RivaASR) -> None:
|
||||
"""Test that ASR accepts valid arguments."""
|
||||
for key, expected_val in CONFIG.items():
|
||||
assert getattr(asr, key, None) == expected_val
|
||||
|
||||
|
||||
@pytest.mark.requires("riva.client")
|
||||
def test_init_defaults() -> None:
|
||||
"""Ensure the runnable can be loaded with no arguments."""
|
||||
_ = RivaASR()
|
||||
|
||||
|
||||
@pytest.mark.requires("riva.client")
|
||||
def test_config(asr: RivaASR) -> None:
|
||||
"""Verify the Riva config is properly assembled."""
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
import riva.client.proto.riva_asr_pb2 as rasr
|
||||
|
||||
expected = rasr.StreamingRecognitionConfig(
|
||||
interim_results=True,
|
||||
config=rasr.RecognitionConfig(
|
||||
encoding=CONFIG["encoding"],
|
||||
sample_rate_hertz=CONFIG["sample_rate_hertz"],
|
||||
audio_channel_count=CONFIG["audio_channel_count"],
|
||||
max_alternatives=1,
|
||||
profanity_filter=CONFIG["profanity_filter"],
|
||||
enable_automatic_punctuation=CONFIG["enable_automatic_punctuation"],
|
||||
language_code=CONFIG["language_code"],
|
||||
),
|
||||
)
|
||||
assert asr.config == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("riva.client")
|
||||
def test_get_service(asr: RivaASR) -> None:
|
||||
"""Test generating an asr service class."""
|
||||
svc = asr._get_service()
|
||||
assert str(svc.auth.ssl_cert) == CONFIG["ssl_cert"]
|
||||
assert svc.auth.use_ssl == SVC_USE_SSL
|
||||
assert svc.auth.uri == SVC_URI
|
||||
|
||||
|
||||
@pytest.mark.requires("riva.client")
|
||||
@patch(
|
||||
"riva.client.proto.riva_asr_pb2_grpc.RivaSpeechRecognitionStub.__init__",
|
||||
riva_asr_stub_init_patch,
|
||||
)
|
||||
def test_invoke(asr: RivaASR, stream: AudioStream) -> None:
|
||||
"""Test the invoke method."""
|
||||
got = asr.invoke(stream)
|
||||
expected = " ".join([s.strip() for s in AUDIO_TEXT_MOCK]).strip()
|
||||
assert got == expected
|
@@ -0,0 +1,153 @@
|
||||
"""Unit tests to verify function of the Riva TTS implementation."""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_community.utilities.nvidia_riva import RivaAudioEncoding, RivaTTS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import riva.client
|
||||
import riva.client.proto.riva_tts_pb2 as rtts
|
||||
|
||||
AUDIO_TEXT_MOCK = ["This is a test.", "Hello world"]
|
||||
AUDIO_DATA_MOCK = [s.encode() for s in AUDIO_TEXT_MOCK]
|
||||
|
||||
SVC_URI = "not-a-url.asdf:9999"
|
||||
SVC_USE_SSL = True
|
||||
CONFIG = {
|
||||
"voice_name": "English-Test",
|
||||
"output_directory": None,
|
||||
"url": f"{'https' if SVC_USE_SSL else 'http'}://{SVC_URI}",
|
||||
"ssl_cert": "/dev/null",
|
||||
"encoding": RivaAudioEncoding.ALAW,
|
||||
"language_code": "not-a-language",
|
||||
"sample_rate_hertz": 5,
|
||||
}
|
||||
|
||||
|
||||
def synthesize_online_mock(
|
||||
request: "rtts.SynthesizeSpeechRequest", **_: Any
|
||||
) -> Generator["rtts.SynthesizeSpeechResponse", None, None]:
|
||||
"""A mock function to fake a streaming call to Riva."""
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
import riva.client.proto.riva_tts_pb2 as rtts
|
||||
|
||||
yield rtts.SynthesizeSpeechResponse(
|
||||
audio=f"[{request.language_code},{request.encoding},{request.sample_rate_hz},{request.voice_name}]".encode()
|
||||
)
|
||||
yield rtts.SynthesizeSpeechResponse(audio=request.text.strip().encode())
|
||||
|
||||
|
||||
def riva_tts_stub_init_patch(
|
||||
self: "riva.client.proto.riva_tts_pb2_grpc.RivaSpeechSynthesisStub", _: Any
|
||||
) -> None:
|
||||
"""Patch for the Riva TTS library."""
|
||||
self.SynthesizeOnline = synthesize_online_mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tts() -> RivaTTS:
|
||||
"""Initialize a copy of the runnable."""
|
||||
return RivaTTS(**CONFIG)
|
||||
|
||||
|
||||
@pytest.mark.requires("riva.client")
|
||||
def test_init(tts: RivaTTS) -> None:
|
||||
"""Test that ASR accepts valid arguments."""
|
||||
for key, expected_val in CONFIG.items():
|
||||
assert getattr(tts, key, None) == expected_val
|
||||
|
||||
|
||||
@pytest.mark.requires("riva.client")
|
||||
def test_init_defaults() -> None:
|
||||
"""Ensure the runnable can be loaded with no arguments."""
|
||||
_ = RivaTTS()
|
||||
|
||||
|
||||
@pytest.mark.requires("riva.client")
|
||||
def test_get_service(tts: RivaTTS) -> None:
|
||||
"""Test the get service method."""
|
||||
svc = tts._get_service()
|
||||
assert str(svc.auth.ssl_cert) == CONFIG["ssl_cert"]
|
||||
assert svc.auth.use_ssl == SVC_USE_SSL
|
||||
assert svc.auth.uri == SVC_URI
|
||||
|
||||
|
||||
@pytest.mark.requires("riva.client")
|
||||
@patch(
|
||||
"riva.client.proto.riva_tts_pb2_grpc.RivaSpeechSynthesisStub.__init__",
|
||||
riva_tts_stub_init_patch,
|
||||
)
|
||||
def test_invoke(tts: RivaTTS) -> None:
|
||||
"""Test the invoke method."""
|
||||
encoding = cast(RivaAudioEncoding, CONFIG["encoding"]).riva_pb2
|
||||
audio_synth_config = (
|
||||
f"[{CONFIG['language_code']},"
|
||||
f"{encoding},"
|
||||
f"{CONFIG['sample_rate_hertz']},"
|
||||
f"{CONFIG['voice_name']}]"
|
||||
)
|
||||
|
||||
input = " ".join(AUDIO_TEXT_MOCK).strip()
|
||||
response = tts.invoke(input)
|
||||
expected = (audio_synth_config + audio_synth_config.join(AUDIO_TEXT_MOCK)).encode()
|
||||
assert response == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("riva.client")
|
||||
@patch(
|
||||
"riva.client.proto.riva_tts_pb2_grpc.RivaSpeechSynthesisStub.__init__",
|
||||
riva_tts_stub_init_patch,
|
||||
)
|
||||
def test_transform(tts: RivaTTS) -> None:
|
||||
"""Test the transform method."""
|
||||
encoding = cast(RivaAudioEncoding, CONFIG["encoding"]).riva_pb2
|
||||
audio_synth_config = (
|
||||
f"[{CONFIG['language_code']},"
|
||||
f"{encoding},"
|
||||
f"{CONFIG['sample_rate_hertz']},"
|
||||
f"{CONFIG['voice_name']}]"
|
||||
)
|
||||
expected = (audio_synth_config + audio_synth_config.join(AUDIO_TEXT_MOCK)).encode()
|
||||
for idx, response in enumerate(tts.transform(iter(AUDIO_TEXT_MOCK))):
|
||||
if idx % 2:
|
||||
# odd indices will return the mocked data
|
||||
expected = AUDIO_DATA_MOCK[int((idx - 1) / 2)]
|
||||
else:
|
||||
# even indices will return the request config
|
||||
expected = audio_synth_config.encode()
|
||||
assert response == expected
|
||||
|
||||
|
||||
@pytest.mark.requires("riva.client")
|
||||
@patch(
|
||||
"riva.client.proto.riva_tts_pb2_grpc.RivaSpeechSynthesisStub.__init__",
|
||||
riva_tts_stub_init_patch,
|
||||
)
|
||||
async def test_atransform(tts: RivaTTS) -> None:
|
||||
"""Test the transform method."""
|
||||
encoding = cast(RivaAudioEncoding, CONFIG["encoding"]).riva_pb2
|
||||
audio_synth_config = (
|
||||
f"[{CONFIG['language_code']},"
|
||||
f"{encoding},"
|
||||
f"{CONFIG['sample_rate_hertz']},"
|
||||
f"{CONFIG['voice_name']}]"
|
||||
)
|
||||
expected = (audio_synth_config + audio_synth_config.join(AUDIO_TEXT_MOCK)).encode()
|
||||
idx = 0
|
||||
|
||||
async def _fake_async_iterable() -> AsyncGenerator[str, None]:
|
||||
for val in AUDIO_TEXT_MOCK:
|
||||
yield val
|
||||
|
||||
async for response in tts.atransform(_fake_async_iterable()):
|
||||
if idx % 2:
|
||||
# odd indices will return the mocked data
|
||||
expected = AUDIO_DATA_MOCK[int((idx - 1) / 2)]
|
||||
else:
|
||||
# even indices will return the request config
|
||||
expected = audio_synth_config.encode()
|
||||
assert response == expected
|
||||
idx += 1
|
Reference in New Issue
Block a user