{
"cells": [
{
"cell_type": "markdown",
"id": "cc6caafa",
"metadata": {
"id": "cc6caafa"
},
"source": [
"# NVIDIA Riva: ASR and TTS\n",
"\n",
"## NVIDIA Riva\n",
"[NVIDIA Riva](https://www.nvidia.com/en-us/ai-data-science/products/riva/) is a GPU-accelerated multilingual speech and translation AI software development kit for building fully customizable, real-time conversational AI pipelines—including automatic speech recognition (ASR), text-to-speech (TTS), and neural machine translation (NMT) applications—that can be deployed in clouds, in data centers, at the edge, or on embedded devices.\n",
"\n",
"The Riva Speech API server exposes a simple API for performing speech recognition, speech synthesis, and a variety of natural language processing inferences, and is integrated into LangChain for ASR and TTS. See the instructions below on how to [set up a Riva Speech API](#3-setup) server. \n",
"\n",
"## Integrating NVIDIA Riva to LangChain Chains\n",
"The `RivaASR` and `RivaTTS` utility runnables integrate [NVIDIA Riva](https://www.nvidia.com/en-us/ai-data-science/products/riva/) into LCEL chains for Automatic Speech Recognition (ASR) and Text-to-Speech (TTS).\n",
"\n",
"This example goes over how to use these LangChain runnables to:\n",
"1. accept streamed audio,\n",
"2. convert the audio to text, \n",
"3. send the text to an LLM, \n",
"4. stream a textual LLM response, and\n",
"5. convert the response to streamed human-sounding audio. "
]
},
{
"cell_type": "markdown",
"id": "b603439f",
"metadata": {},
"source": [
"## 1. NVIDIA Riva Runnables\n",
"There are two Riva runnables:\n",
"\n",
"a. **RivaASR**: Converts audio bytes into text for an LLM using NVIDIA Riva. \n",
"\n",
"b. **RivaTTS**: Converts text into audio bytes using NVIDIA Riva.\n",
"\n",
"### a. RivaASR\n",
"The [**RivaASR**](https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/utilities/nvidia_riva.py#L404) runnable converts audio bytes into a string for an LLM using NVIDIA Riva. \n",
"\n",
"It's useful for sending an audio stream (a message containing streaming audio) into a chain and preprocessing that audio by converting it to a string to create an LLM prompt. \n",
"\n",
"```\n",
"ASRInputType = AudioStream # the AudioStream type is a custom type for a message queue containing streaming audio\n",
"ASROutputType = str\n",
"\n",
"class RivaASR(\n",
" RivaAuthMixin,\n",
" RivaCommonConfigMixin,\n",
" RunnableSerializable[ASRInputType, ASROutputType],\n",
"):\n",
" \"\"\"A runnable that performs Automatic Speech Recognition (ASR) using NVIDIA Riva.\"\"\"\n",
"\n",
" name: str = \"nvidia_riva_asr\"\n",
" description: str = (\n",
" \"A Runnable for converting audio bytes to a string.\"\n",
" \"This is useful for feeding an audio stream into a chain and\"\n",
" \"preprocessing that audio to create an LLM prompt.\"\n",
" )\n",
"\n",
" # riva options\n",
" audio_channel_count: int = Field(\n",
" 1, description=\"The number of audio channels in the input audio stream.\"\n",
" )\n",
" profanity_filter: bool = Field(\n",
" True,\n",
" description=(\n",
" \"Controls whether or not Riva should attempt to filter \"\n",
" \"profanity out of the transcribed text.\"\n",
" ),\n",
" )\n",
" enable_automatic_punctuation: bool = Field(\n",
" True,\n",
" description=(\n",
" \"Controls whether Riva should attempt to correct \"\n",
"            \"sentence punctuation in the transcribed text.\"\n",
" ),\n",
" )\n",
"```\n",
"\n",
"When this runnable is invoked, it consumes the input audio stream, which acts as a queue, and concatenates the transcription as chunks are returned. After the response is fully generated, a string is returned. \n",
"* Note that because the LLM requires a full query, the ASR transcription is concatenated rather than streamed token by token.\n",
"\n",
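"Under the assumption of a simple queue-based stream, this accumulate-then-return behavior can be sketched with the stdlib `queue` module standing in for `AudioStream` (`transcribe_chunk` is a hypothetical stand-in for the Riva service call, not the real API):\n",
"\n",
"```python\n",
"import queue\n",
"\n",
"\n",
"def transcribe_chunk(chunk: bytes) -> str:\n",
"    \"\"\"Hypothetical stand-in for one Riva ASR partial transcript.\"\"\"\n",
"    return chunk.decode(\"utf-8\")\n",
"\n",
"\n",
"def run_asr(audio_queue: queue.Queue) -> str:\n",
"    \"\"\"Concatenate partial transcripts until the stream closes (None sentinel).\"\"\"\n",
"    parts = []\n",
"    while (chunk := audio_queue.get()) is not None:\n",
"        parts.append(transcribe_chunk(chunk))\n",
"    return \" \".join(parts)\n",
"\n",
"\n",
"stream = queue.Queue()\n",
"for piece in [b\"what is\", b\"the capital\", b\"of France?\"]:\n",
"    stream.put(piece)\n",
"stream.put(None)  # sentinel marking the end of the stream\n",
"print(run_asr(stream))  # -> what is the capital of France?\n",
"```\n",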
"\n",
"### b. RivaTTS\n",
"The [**RivaTTS**](https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/utilities/nvidia_riva.py#L511) runnable converts text output to audio bytes. \n",
"\n",
"It's useful for processing the streamed textual response from an LLM by converting the text to audio bytes that sound like a natural human voice when played back to the user. \n",
"\n",
"```\n",
"TTSInputType = Union[str, AnyMessage, PromptValue]\n",
"TTSOutputType = bytes\n",
"\n",
"class RivaTTS(\n",
" RivaAuthMixin,\n",
" RivaCommonConfigMixin,\n",
" RunnableSerializable[TTSInputType, TTSOutputType],\n",
"):\n",
" \"\"\"A runnable that performs Text-to-Speech (TTS) with NVIDIA Riva.\"\"\"\n",
"\n",
" name: str = \"nvidia_riva_tts\"\n",
" description: str = (\n",
" \"A tool for converting text to speech.\"\n",
" \"This is useful for converting LLM output into audio bytes.\"\n",
" )\n",
"\n",
" # riva options\n",
" voice_name: str = Field(\n",
" \"English-US.Female-1\",\n",
" description=(\n",
" \"The voice model in Riva to use for speech. \"\n",
" \"Pre-trained models are documented in \"\n",
" \"[the Riva documentation]\"\n",
" \"(https://docs.nvidia.com/deeplearning/riva/user-guide/docs/tts/tts-overview.html).\"\n",
" ),\n",
" )\n",
" output_directory: Optional[str] = Field(\n",
" None,\n",
" description=(\n",
" \"The directory where all audio files should be saved. \"\n",
" \"A null value indicates that wave files should not be saved. \"\n",
" \"This is useful for debugging purposes.\"\n",
"        ),\n",
"    )\n",
"```\n",
"\n",
"When this runnable is called on an input, it takes iterable text chunks and streams them into output audio bytes that are either written to a `.wav` file or played out loud."
]
},
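{
 "cell_type": "markdown",
 "id": "d4f1a2b3",
 "metadata": {},
 "source": [
  "The write-to-`.wav` step can be sketched with the standard-library `wave` module. This is a minimal sketch assuming 16-bit linear PCM chunks (the real runnable also handles other encodings and names the files itself):\n",
  "\n",
  "```python\n",
  "import wave\n",
  "\n",
  "\n",
  "def write_chunks_to_wav(path, chunks, sample_rate=8000, channels=1):\n",
  "    \"\"\"Write an iterable of PCM audio byte chunks into a single .wav file.\"\"\"\n",
  "    with wave.open(path, \"wb\") as wav:\n",
  "        wav.setnchannels(channels)\n",
  "        wav.setsampwidth(2)  # 16-bit linear PCM\n",
  "        wav.setframerate(sample_rate)\n",
  "        for chunk in chunks:\n",
  "            wav.writeframes(chunk)\n",
  "\n",
  "\n",
  "# 0.25 s of silence streamed as two chunks\n",
  "silence_chunks = [b\"\\x00\\x00\" * 1000, b\"\\x00\\x00\" * 1000]\n",
  "write_chunks_to_wav(\"demo_output.wav\", silence_chunks)\n",
  "```"
 ]
},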
{
"cell_type": "markdown",
"id": "f2be90a9",
"metadata": {},
"source": [
"## 2. Installation"
]
},
{
"cell_type": "markdown",
"id": "1ef87a40",
"metadata": {},
"source": [
"The NVIDIA Riva client library must be installed."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "70410821",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install --upgrade --quiet nvidia-riva-client"
]
},
{
"cell_type": "markdown",
"id": "ccff689e",
"metadata": {
"id": "ccff689e"
},
"source": [
"## 3. Setup\n",
"\n",
"**To get started with NVIDIA Riva:**\n",
"\n",
"1. Follow the Riva Quick Start setup instructions for [Local Deployment Using Quick Start Scripts](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/quick-start-guide.html#local-deployment-using-quick-start-scripts)."
]
},
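{
 "cell_type": "markdown",
 "id": "b2f7e4a8",
 "metadata": {},
 "source": [
  "Once the server is running, you can optionally confirm that it is reachable before wiring up the chain. A small stdlib sketch (the default gRPC port 50051 is an assumption; adjust to match your deployment):\n",
  "\n",
  "```python\n",
  "import socket\n",
  "\n",
  "\n",
  "def riva_reachable(host=\"localhost\", port=50051, timeout=2.0) -> bool:\n",
  "    \"\"\"Return True if a TCP connection to the Riva server can be opened.\"\"\"\n",
  "    try:\n",
  "        with socket.create_connection((host, port), timeout=timeout):\n",
  "            return True\n",
  "    except OSError:\n",
  "        return False\n",
  "```"
 ]
},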
{
"cell_type": "markdown",
"id": "57b6741b",
"metadata": {},
"source": [
"## 4. Import and Inspect Runnables\n",
"Import the RivaASR and RivaTTS runnables and inspect their schemas to understand their fields. "
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2d6fa641",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"from langchain_community.utilities.nvidia_riva import (\n",
" RivaASR,\n",
" RivaTTS,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "0e6dd656",
"metadata": {},
"source": [
"Let's view the schemas."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "69460762",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"title\": \"RivaASR\",\n",
" \"description\": \"A runnable that performs Automatic Speech Recognition (ASR) using NVIDIA Riva.\",\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"name\": {\n",
" \"title\": \"Name\",\n",
" \"default\": \"nvidia_riva_asr\",\n",
" \"type\": \"string\"\n",
" },\n",
" \"encoding\": {\n",
" \"description\": \"The encoding on the audio stream.\",\n",
" \"default\": \"LINEAR_PCM\",\n",
" \"allOf\": [\n",
" {\n",
" \"$ref\": \"#/definitions/RivaAudioEncoding\"\n",
" }\n",
" ]\n",
" },\n",
" \"sample_rate_hertz\": {\n",
" \"title\": \"Sample Rate Hertz\",\n",
" \"description\": \"The sample rate frequency of audio stream.\",\n",
" \"default\": 8000,\n",
" \"type\": \"integer\"\n",
" },\n",
" \"language_code\": {\n",
" \"title\": \"Language Code\",\n",
" \"description\": \"The [BCP-47 language code](https://www.rfc-editor.org/rfc/bcp/bcp47.txt) for the target language.\",\n",
" \"default\": \"en-US\",\n",
" \"type\": \"string\"\n",
" },\n",
" \"url\": {\n",
" \"title\": \"Url\",\n",
" \"description\": \"The full URL where the Riva service can be found.\",\n",
" \"default\": \"http://localhost:50051\",\n",
" \"examples\": [\n",
" \"http://localhost:50051\",\n",
" \"https://user@pass:riva.example.com\"\n",
" ],\n",
" \"anyOf\": [\n",
" {\n",
" \"type\": \"string\",\n",
" \"minLength\": 1,\n",
" \"maxLength\": 65536,\n",
" \"format\": \"uri\"\n",
" },\n",
" {\n",
" \"type\": \"string\"\n",
" }\n",
" ]\n",
" },\n",
" \"ssl_cert\": {\n",
" \"title\": \"Ssl Cert\",\n",
" \"description\": \"A full path to the file where Riva's public ssl key can be read.\",\n",
" \"type\": \"string\"\n",
" },\n",
" \"description\": {\n",
" \"title\": \"Description\",\n",
" \"default\": \"A Runnable for converting audio bytes to a string.This is useful for feeding an audio stream into a chain andpreprocessing that audio to create an LLM prompt.\",\n",
" \"type\": \"string\"\n",
" },\n",
" \"audio_channel_count\": {\n",
" \"title\": \"Audio Channel Count\",\n",
" \"description\": \"The number of audio channels in the input audio stream.\",\n",
" \"default\": 1,\n",
" \"type\": \"integer\"\n",
" },\n",
" \"profanity_filter\": {\n",
" \"title\": \"Profanity Filter\",\n",
" \"description\": \"Controls whether or not Riva should attempt to filter profanity out of the transcribed text.\",\n",
" \"default\": true,\n",
" \"type\": \"boolean\"\n",
" },\n",
" \"enable_automatic_punctuation\": {\n",
" \"title\": \"Enable Automatic Punctuation\",\n",
" \"description\": \"Controls whether Riva should attempt to correct senetence puncuation in the transcribed text.\",\n",
" \"default\": true,\n",
" \"type\": \"boolean\"\n",
" }\n",
" },\n",
" \"definitions\": {\n",
" \"RivaAudioEncoding\": {\n",
" \"title\": \"RivaAudioEncoding\",\n",
" \"description\": \"An enum of the possible choices for Riva audio encoding.\\n\\nThe list of types exposed by the Riva GRPC Protobuf files can be found\\nwith the following commands:\\n```python\\nimport riva.client\\nprint(riva.client.AudioEncoding.keys()) # noqa: T201\\n```\",\n",
" \"enum\": [\n",
" \"ALAW\",\n",
" \"ENCODING_UNSPECIFIED\",\n",
" \"FLAC\",\n",
" \"LINEAR_PCM\",\n",
" \"MULAW\",\n",
" \"OGGOPUS\"\n",
" ],\n",
" \"type\": \"string\"\n",
" }\n",
" }\n",
"}\n",
"{\n",
" \"title\": \"RivaTTS\",\n",
" \"description\": \"A runnable that performs Text-to-Speech (TTS) with NVIDIA Riva.\",\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"name\": {\n",
" \"title\": \"Name\",\n",
" \"default\": \"nvidia_riva_tts\",\n",
" \"type\": \"string\"\n",
" },\n",
" \"encoding\": {\n",
" \"description\": \"The encoding on the audio stream.\",\n",
" \"default\": \"LINEAR_PCM\",\n",
" \"allOf\": [\n",
" {\n",
" \"$ref\": \"#/definitions/RivaAudioEncoding\"\n",
" }\n",
" ]\n",
" },\n",
" \"sample_rate_hertz\": {\n",
" \"title\": \"Sample Rate Hertz\",\n",
" \"description\": \"The sample rate frequency of audio stream.\",\n",
" \"default\": 8000,\n",
" \"type\": \"integer\"\n",
" },\n",
" \"language_code\": {\n",
" \"title\": \"Language Code\",\n",
" \"description\": \"The [BCP-47 language code](https://www.rfc-editor.org/rfc/bcp/bcp47.txt) for the target language.\",\n",
" \"default\": \"en-US\",\n",
" \"type\": \"string\"\n",
" },\n",
" \"url\": {\n",
" \"title\": \"Url\",\n",
" \"description\": \"The full URL where the Riva service can be found.\",\n",
" \"default\": \"http://localhost:50051\",\n",
" \"examples\": [\n",
" \"http://localhost:50051\",\n",
" \"https://user@pass:riva.example.com\"\n",
" ],\n",
" \"anyOf\": [\n",
" {\n",
" \"type\": \"string\",\n",
" \"minLength\": 1,\n",
" \"maxLength\": 65536,\n",
" \"format\": \"uri\"\n",
" },\n",
" {\n",
" \"type\": \"string\"\n",
" }\n",
" ]\n",
" },\n",
" \"ssl_cert\": {\n",
" \"title\": \"Ssl Cert\",\n",
" \"description\": \"A full path to the file where Riva's public ssl key can be read.\",\n",
" \"type\": \"string\"\n",
" },\n",
" \"description\": {\n",
" \"title\": \"Description\",\n",
" \"default\": \"A tool for converting text to speech.This is useful for converting LLM output into audio bytes.\",\n",
" \"type\": \"string\"\n",
" },\n",
" \"voice_name\": {\n",
" \"title\": \"Voice Name\",\n",
" \"description\": \"The voice model in Riva to use for speech. Pre-trained models are documented in [the Riva documentation](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/tts/tts-overview.html).\",\n",
" \"default\": \"English-US.Female-1\",\n",
" \"type\": \"string\"\n",
" },\n",
" \"output_directory\": {\n",
" \"title\": \"Output Directory\",\n",
" \"description\": \"The directory where all audio files should be saved. A null value indicates that wave files should not be saved. This is useful for debugging purposes.\",\n",
" \"type\": \"string\"\n",
" }\n",
" },\n",
" \"definitions\": {\n",
" \"RivaAudioEncoding\": {\n",
" \"title\": \"RivaAudioEncoding\",\n",
" \"description\": \"An enum of the possible choices for Riva audio encoding.\\n\\nThe list of types exposed by the Riva GRPC Protobuf files can be found\\nwith the following commands:\\n```python\\nimport riva.client\\nprint(riva.client.AudioEncoding.keys()) # noqa: T201\\n```\",\n",
" \"enum\": [\n",
" \"ALAW\",\n",
" \"ENCODING_UNSPECIFIED\",\n",
" \"FLAC\",\n",
" \"LINEAR_PCM\",\n",
" \"MULAW\",\n",
" \"OGGOPUS\"\n",
" ],\n",
" \"type\": \"string\"\n",
" }\n",
" }\n",
"}\n"
]
}
],
"source": [
"print(json.dumps(RivaASR.schema(), indent=2))\n",
"print(json.dumps(RivaTTS.schema(), indent=2))"
]
},
{
"cell_type": "markdown",
"id": "2f128f27",
"metadata": {},
"source": [
"## 5. Declare Riva ASR and Riva TTS Runnables\n",
"\n",
"For this example, a single-channel, mulaw-encoded `.wav` audio file is used.\n",
"\n",
"You will need a running Riva speech server; if you don't have one, see [Setup](#3-setup).\n",
"\n",
"### a. Set Audio Parameters\n",
"Some audio parameters can be inferred from the wave file's header, but others must be set explicitly.\n",
"\n",
"Replace `audio_file` with the path of your audio file."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "5c75995a",
"metadata": {},
"outputs": [],
"source": [
"import pywav # pywav is used instead of built-in wave because of mulaw support\n",
"from langchain_community.utilities.nvidia_riva import RivaAudioEncoding\n",
"\n",
"audio_file = \"./audio_files/en-US_sample2.wav\"\n",
"wav_file = pywav.WavRead(audio_file)\n",
"audio_data = wav_file.getdata()\n",
"audio_encoding = RivaAudioEncoding.from_wave_format_code(wav_file.getaudioformat())\n",
"sample_rate = wav_file.getsamplerate()\n",
"delay_time = 1 / 4  # seconds of audio per streamed chunk\n",
"chunk_size = int(sample_rate * delay_time)\n",
"delay_time = 1 / 8  # not referenced again in this example\n",
"num_channels = wav_file.getnumofchannels()"
]
},
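{
 "cell_type": "markdown",
 "id": "c3e8f5d0",
 "metadata": {},
 "source": [
  "For reference, the format code, channel count, and sample rate that `pywav` reads all live in the RIFF/WAVE `fmt ` header, which can be parsed with just the stdlib. A sketch assuming a canonical 44-byte header (mulaw files carry format code 7):\n",
  "\n",
  "```python\n",
  "import struct\n",
  "\n",
  "\n",
  "def read_wav_header(data: bytes):\n",
  "    \"\"\"Return (format_code, channels, sample_rate) from a RIFF/WAVE header.\"\"\"\n",
  "    assert data[:4] == b\"RIFF\" and data[8:12] == b\"WAVE\"\n",
  "    return struct.unpack_from(\"<HHI\", data, 20)\n",
  "\n",
  "\n",
  "# synthetic mulaw header: format code 7, mono, 8000 Hz\n",
  "header = (\n",
  "    b\"RIFF\" + struct.pack(\"<I\", 36) + b\"WAVE\"\n",
  "    + b\"fmt \" + struct.pack(\"<IHHIIHH\", 16, 7, 1, 8000, 8000, 1, 8)\n",
  "    + b\"data\" + struct.pack(\"<I\", 0)\n",
  ")\n",
  "print(read_wav_header(header))  # -> (7, 1, 8000)\n",
  "```"
 ]
},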
{
"cell_type": "code",
"execution_count": 5,
"id": "a3b29f36",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import IPython\n",
"\n",
"IPython.display.Audio(audio_file)"
]
},
{
"cell_type": "markdown",
"id": "fb294e19",
"metadata": {},
"source": [
"### b. Set the Speech Server and Declare Riva LangChain Runnables\n",
"\n",
"Be sure to set `RIVA_SPEECH_URL` to the URI of your Riva speech server.\n",
"\n",
"The runnables act as clients to the speech server. Many of the fields set in this example are configured based on the sample audio data. "
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "cf1108af",
"metadata": {},
"outputs": [],
"source": [
"RIVA_SPEECH_URL = \"http://localhost:50051/\"\n",
"\n",
"riva_asr = RivaASR(\n",
" url=RIVA_SPEECH_URL, # the location of the Riva ASR server\n",
" encoding=audio_encoding,\n",
" audio_channel_count=num_channels,\n",
" sample_rate_hertz=sample_rate,\n",
" profanity_filter=True,\n",
" enable_automatic_punctuation=True,\n",
" language_code=\"en-US\",\n",
")\n",
"\n",
"riva_tts = RivaTTS(\n",
" url=RIVA_SPEECH_URL, # the location of the Riva TTS server\n",
" output_directory=\"./scratch\", # location of the output .wav files\n",
" language_code=\"en-US\",\n",
" voice_name=\"English-US.Female-1\",\n",
")"
]
},
{
"cell_type": "markdown",
"id": "f12049a2",
"metadata": {},
"source": [
"## 6. Create Additional Chain Components\n",
"As usual, declare the other parts of the chain. In this case, it's just a prompt template and an LLM.\n",
"\n",
"LangChain compatible NVIDIA LLMs from [NVIDIA AI Foundation Endpoints](https://www.nvidia.com/en-us/ai-data-science/foundation-models/) can also be used by following these [instructions](https://python.langchain.com/docs/integrations/chat/nvidia_ai_endpoints). "
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a6deb471",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts import PromptTemplate\n",
"from langchain_openai import OpenAI\n",
"\n",
"prompt = PromptTemplate.from_template(\"{user_input}\")\n",
"llm = OpenAI(openai_api_key=\"sk-xxx\")"
]
},
{
"cell_type": "markdown",
"id": "5cca78f1",
"metadata": {},
"source": [
"Now, tie together all the parts of the chain including RivaASR and RivaTTS."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c8de3b75",
"metadata": {},
"outputs": [],
"source": [
"chain = {\"user_input\": riva_asr} | prompt | llm | riva_tts"
]
},
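{
 "cell_type": "markdown",
 "id": "e5c2b9a1",
 "metadata": {},
 "source": [
  "In LCEL, the dict on the left is coerced into a runnable whose output fills the `user_input` prompt variable, and `|` composes stages left to right. The data flow can be sketched with plain functions standing in for each stage (all of the stand-ins below are hypothetical, not the real runnables):\n",
  "\n",
  "```python\n",
  "from functools import reduce\n",
  "\n",
  "\n",
  "def pipe(*stages):\n",
  "    \"\"\"Compose callables left to right, like LCEL's | operator.\"\"\"\n",
  "    return lambda x: reduce(lambda acc, stage: stage(acc), stages, x)\n",
  "\n",
  "\n",
  "asr = lambda audio: audio.decode(\"utf-8\")       # audio bytes -> text\n",
  "prompt = lambda text: {\"user_input\": text}      # text -> prompt variables\n",
  "llm = lambda p: f\"You said: {p['user_input']}\"  # prompt -> response text\n",
  "tts = lambda text: text.encode(\"utf-8\")         # text -> audio bytes\n",
  "\n",
  "sketch_chain = pipe(asr, prompt, llm, tts)\n",
  "print(sketch_chain(b\"hello\"))  # -> b'You said: hello'\n",
  "```"
 ]
},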
{
"cell_type": "markdown",
"id": "84c2c6dc",
"metadata": {},
"source": [
"## 7. Run the Chain with Streamed Inputs and Outputs\n",
"\n",
"### a. Mimic Audio Streaming\n",
"To mimic streaming, first convert the processed audio data to iterable chunks of audio bytes. \n",
"\n",
"Two coroutines, `producer` and `consumer`, asynchronously pass audio data into the chain and consume audio data out of the chain, respectively.\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "745ee427",
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"\n",
"from langchain_community.utilities.nvidia_riva import AudioStream\n",
"\n",
"audio_chunks = [\n",
" audio_data[0 + i : chunk_size + i] for i in range(0, len(audio_data), chunk_size)\n",
"]\n",
"\n",
"\n",
"async def producer(input_stream) -> None:\n",
" \"\"\"Produces audio chunk bytes into an AudioStream as streaming audio input.\"\"\"\n",
" for chunk in audio_chunks:\n",
" await input_stream.aput(chunk)\n",
" input_stream.close()\n",
"\n",
"\n",
"async def consumer(input_stream, output_stream) -> None:\n",
" \"\"\"\n",
"    Consumes audio chunks from the input stream and passes them through the\n",
"    chain (ASR -> text prompt for an LLM -> TTS), putting the synthesized\n",
"    voice of the LLM response into an output stream.\n",
" \"\"\"\n",
" while not input_stream.complete:\n",
" async for chunk in chain.astream(input_stream):\n",
" await output_stream.put(\n",
" chunk\n",
" ) # for production code don't forget to add a timeout\n",
"\n",
"\n",
"input_stream = AudioStream(maxsize=1000)\n",
"output_stream = asyncio.Queue()\n",
"\n",
"# send data into the chain\n",
"producer_task = asyncio.create_task(producer(input_stream))\n",
"# get data out of the chain\n",
"consumer_task = asyncio.create_task(consumer(input_stream, output_stream))\n",
"\n",
"while not consumer_task.done():\n",
" try:\n",
"        generated_audio = await asyncio.wait_for(\n",
"            output_stream.get(), timeout=2\n",
"        )  # drain synthesized audio as it becomes available\n",
" except asyncio.TimeoutError:\n",
" continue\n",
"\n",
"await producer_task\n",
"await consumer_task"
]
},
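{
 "cell_type": "markdown",
 "id": "a9d8c7b6",
 "metadata": {},
 "source": [
  "The same producer/consumer shape can be reproduced with nothing but `asyncio.Queue`, which may help when adapting the pattern. A minimal sketch with a `None` sentinel standing in for `AudioStream.close()` (run as a standalone script; inside Jupyter use `await main()` instead of `asyncio.run`):\n",
  "\n",
  "```python\n",
  "import asyncio\n",
  "\n",
  "\n",
  "async def produce(q: asyncio.Queue, items) -> None:\n",
  "    \"\"\"Feed items into the queue, then signal completion.\"\"\"\n",
  "    for item in items:\n",
  "        await q.put(item)\n",
  "    await q.put(None)  # sentinel marking the end of the stream\n",
  "\n",
  "\n",
  "async def consume(q: asyncio.Queue) -> list:\n",
  "    \"\"\"Drain the queue until the sentinel arrives.\"\"\"\n",
  "    out = []\n",
  "    while (item := await q.get()) is not None:\n",
  "        out.append(item)\n",
  "    return out\n",
  "\n",
  "\n",
  "async def main():\n",
  "    q = asyncio.Queue(maxsize=8)\n",
  "    _, results = await asyncio.gather(produce(q, [b\"a\", b\"b\", b\"c\"]), consume(q))\n",
  "    return results\n",
  "\n",
  "\n",
  "print(asyncio.run(main()))  # -> [b'a', b'b', b'c']\n",
  "```"
 ]
},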
{
"cell_type": "markdown",
"id": "76b8f175",
"metadata": {},
"source": [
"## 8. Listen to Voice Response\n",
"\n",
"The audio response is written to `./scratch` and should contain an audio clip that is a response to the input audio."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "8f41b939",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import glob\n",
"import os\n",
"\n",
"output_path = os.path.join(os.getcwd(), \"scratch\")\n",
"file_type = \"*.wav\"\n",
"files_path = os.path.join(output_path, file_type)\n",
"files = sorted(glob.glob(files_path), key=os.path.getmtime)  # oldest to newest\n",
"\n",
"IPython.display.Audio(files[-1])  # play the most recently generated response"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}