mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-30 19:49:09 +00:00
**Please tag this issue with `nvidia_genai`** - **Description:** Added new Runnables for integrating NVIDIA Riva into LCEL chains for Automatic Speech Recognition (ASR) and Text To Speech (TTS). - **Issue:** N/A - **Dependencies:** To use these runnables, the NVIDIA Riva client libraries are required. If they are not installed, an error will be raised instructing how to install them. The Runnables can be safely imported without the Riva client libraries. - **Twitter handle:** N/A All of the Riva Runnables are inside a single folder in the Utilities module. In this folder are four files: - common.py - Contains all code that is common to both TTS and ASR - stream.py - Contains a class representing an audio stream that allows the end user to put data into the stream like a queue. - asr.py - Contains the RivaASR runnable - tts.py - Contains the RivaTTS runnable The following Python function is an example of creating a chain that makes use of both of these Runnables: ```python def create( config: Configuration, audio_encoding: RivaAudioEncoding, sample_rate: int, audio_channels: int = 1, ) -> Runnable[ASRInputType, TTSOutputType]: """Create a new instance of the chain.""" _LOGGER.info("Instantiating the chain.") # create the riva asr client riva_asr = RivaASR( url=str(config.riva_asr.service.url), ssl_cert=config.riva_asr.service.ssl_cert, encoding=audio_encoding, audio_channel_count=audio_channels, sample_rate_hertz=sample_rate, profanity_filter=config.riva_asr.profanity_filter, enable_automatic_punctuation=config.riva_asr.enable_automatic_punctuation, language_code=config.riva_asr.language_code, ) # create the prompt template prompt = PromptTemplate.from_template("{user_input}") # model = ChatOpenAI() model = ChatNVIDIA(model="mixtral_8x7b") # type: ignore # create the riva tts client riva_tts = RivaTTS( url=str(config.riva_asr.service.url), ssl_cert=config.riva_asr.service.ssl_cert, output_directory=config.riva_tts.output_directory, 
language_code=config.riva_tts.language_code, voice_name=config.riva_tts.voice_name, ) # construct and return the chain return {"user_input": riva_asr} | prompt | model | riva_tts # type: ignore ``` The following code is an example of creating a new audio stream for Riva: ```python input_stream = AudioStream(maxsize=1000) # Send bytes into the stream for chunk in audio_chunks: await input_stream.aput(chunk) input_stream.close() ``` The following code is an example of how to execute the chain with RivaASR and RivaTTS: ```python output_stream = asyncio.Queue() while not input_stream.complete: async for chunk in chain.astream(input_stream): await output_stream.put(chunk) ``` Everything should be async safe and thread safe. Audio data can be put into the input stream while the chain is running without interruptions. --------- Co-authored-by: Hayden Wolff <hwolff@nvidia.com> Co-authored-by: Hayden Wolff <hwolff@Haydens-Laptop.local> Co-authored-by: Hayden Wolff <haydenwolff99@gmail.com> Co-authored-by: Erick Friis <erick@langchain.dev>
58 lines
1.4 KiB
Python
58 lines
1.4 KiB
Python
from langchain_community.utilities import __all__
|
|
|
|
# Public names that ``langchain_community.utilities.__all__`` is expected to
# export. ``test_all_imports`` compares this against ``__all__`` as sets, so
# the ordering here has no effect on the assertion. The entries are kept in
# plain ``sorted()`` order (case-sensitive, so e.g. "NVIDIA..." precedes
# "Nasa..." and "SQLDatabase" precedes "SceneXplain...") so that new names
# have one obvious place to go and misplaced additions stand out in review.
EXPECTED_ALL = [
    "AlphaVantageAPIWrapper",
    "ApifyWrapper",
    "ArceeWrapper",
    "ArxivAPIWrapper",
    "BibtexparserWrapper",
    "BingSearchAPIWrapper",
    "BraveSearchWrapper",
    "DuckDuckGoSearchAPIWrapper",
    "GoldenQueryAPIWrapper",
    "GoogleFinanceAPIWrapper",
    "GoogleJobsAPIWrapper",
    "GoogleLensAPIWrapper",
    "GooglePlacesAPIWrapper",
    "GoogleScholarAPIWrapper",
    "GoogleSearchAPIWrapper",
    "GoogleSerperAPIWrapper",
    "GoogleTrendsAPIWrapper",
    "GraphQLAPIWrapper",
    "JiraAPIWrapper",
    "LambdaWrapper",
    "MaxComputeAPIWrapper",
    "MerriamWebsterAPIWrapper",
    "MetaphorSearchAPIWrapper",
    "NVIDIARivaASR",
    "NVIDIARivaStream",
    "NVIDIARivaTTS",
    "NasaAPIWrapper",
    "OpenWeatherMapAPIWrapper",
    "OutlineAPIWrapper",
    "Portkey",
    "PowerBIDataset",
    "PubMedAPIWrapper",
    "PythonREPL",
    "Requests",
    "RequestsWrapper",
    "SQLDatabase",
    "SceneXplainAPIWrapper",
    "SearchApiAPIWrapper",
    "SearxSearchWrapper",
    "SerpAPIWrapper",
    "SparkSQL",
    "StackExchangeAPIWrapper",
    "SteamWebAPIWrapper",
    "TensorflowDatasets",
    "TextRequestsWrapper",
    "TwilioAPIWrapper",
    "WikipediaAPIWrapper",
    "WolframAlphaAPIWrapper",
    "ZapierNLAWrapper",
]
|
|
|
|
|
|
def test_all_imports() -> None:
    """Verify the utilities package exports exactly the expected public names.

    Both ``__all__`` and ``EXPECTED_ALL`` are reduced to sets before the
    comparison, so neither ordering nor repeated entries affect the result.
    """
    exported_names = set(__all__)
    expected_names = set(EXPECTED_ALL)
    assert exported_names == expected_names
|