Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-18 16:16:33 +00:00
community[minor]: Update Azure Cognitive Services to Azure AI Services (#19488)
This is a follow up to #18371. These are the changes:
- New **Azure AI Services** toolkit and tools to replace those of **Azure Cognitive Services**.
- Updated documentation for the Microsoft platform.
- The image analysis tool has been rewritten to use the new package `azure-ai-vision-imageanalysis`, doing a proper replacement of `azure-ai-vision`.

These changes:
- Update outdated naming from "Azure Cognitive Services" to "Azure AI Services".
- Update documentation to use non-deprecated methods to create and use agents.
- Remove the need to depend on the yanked Python package (`azure-ai-vision`).

There is one new dependency needed as a replacement for `azure-ai-vision`: `azure-ai-vision-imageanalysis`. It is optional and imported within a function.

There is a new `azure_ai_services.ipynb` notebook showing usage. Changes have been linted and formatted.

I am leaving the addition of deprecation notices and the future removal of Azure Cognitive Services to the LangChain team, as I am not sure what the current practice around this is.

---

If this PR makes it, my handle is @galo@mastodon.social

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
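For a quick sense of how the new tools are invoked (the `azure_ai_services.ipynb` notebook has the full walkthrough), here is a minimal sketch of direct tool usage. It assumes `AZURE_AI_SERVICES_KEY` and `AZURE_AI_SERVICES_ENDPOINT` are set in the environment and `azure-ai-vision-imageanalysis` is installed; the image URL is a placeholder.

```python
from langchain_community.tools.azure_ai_services import AzureAiServicesImageAnalysisTool

# Credentials are picked up from the environment by the tool's root validator
# (AZURE_AI_SERVICES_KEY / AZURE_AI_SERVICES_ENDPOINT); they can also be passed
# explicitly via azure_ai_services_key= and azure_ai_services_endpoint=.
image_tool = AzureAiServicesImageAnalysisTool()

# Input is a URL (or local path) to an image; the output is a formatted string
# with caption, objects, tags, and any text read from the image.
print(image_tool.run("https://example.com/sample-image.jpg"))  # placeholder URL
```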
@@ -16,6 +16,7 @@ tool for the job.

    CallbackManagerForToolRun, AsyncCallbackManagerForToolRun
"""

import importlib
from typing import Any

@@ -31,6 +32,11 @@ _module_lookup = {
    "AIPluginTool": "langchain_community.tools.plugin",
    "APIOperation": "langchain_community.tools.openapi.utils.api_models",
    "ArxivQueryRun": "langchain_community.tools.arxiv.tool",
    "AzureAiServicesDocumentIntelligenceTool": "langchain_community.tools.azure_ai_services",  # noqa: E501
    "AzureAiServicesImageAnalysisTool": "langchain_community.tools.azure_ai_services",
    "AzureAiServicesSpeechToTextTool": "langchain_community.tools.azure_ai_services",
    "AzureAiServicesTextToSpeechTool": "langchain_community.tools.azure_ai_services",
    "AzureAiServicesTextAnalyticsForHealthTool": "langchain_community.tools.azure_ai_services",  # noqa: E501
    "AzureCogsFormRecognizerTool": "langchain_community.tools.azure_cognitive_services",
    "AzureCogsImageAnalysisTool": "langchain_community.tools.azure_cognitive_services",
    "AzureCogsSpeech2TextTool": "langchain_community.tools.azure_cognitive_services",
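With the `_module_lookup` entries above, the new tool classes become importable from the top level of `langchain_community.tools`, mirroring the existing Azure Cognitive Services entries. A small sketch of what this registration enables, assuming `azure-cognitiveservices-speech` is installed; the key and region values are placeholders:

```python
# Lazy top-level import via the registry above: the azure_ai_services module
# is only loaded when one of these attributes is first accessed.
from langchain_community.tools import (
    AzureAiServicesSpeechToTextTool,
    AzureAiServicesTextToSpeechTool,
)

# Configuration can be passed explicitly instead of via environment variables.
tts = AzureAiServicesTextToSpeechTool(
    azure_ai_services_key="<key>",        # placeholder
    azure_ai_services_region="<region>",  # placeholder
)
```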
@@ -0,0 +1,25 @@
"""Azure AI Services Tools."""

from langchain_community.tools.azure_ai_services.document_intelligence import (
    AzureAiServicesDocumentIntelligenceTool,
)
from langchain_community.tools.azure_ai_services.image_analysis import (
    AzureAiServicesImageAnalysisTool,
)
from langchain_community.tools.azure_ai_services.speech_to_text import (
    AzureAiServicesSpeechToTextTool,
)
from langchain_community.tools.azure_ai_services.text_analytics_for_health import (
    AzureAiServicesTextAnalyticsForHealthTool,
)
from langchain_community.tools.azure_ai_services.text_to_speech import (
    AzureAiServicesTextToSpeechTool,
)

__all__ = [
    "AzureAiServicesDocumentIntelligenceTool",
    "AzureAiServicesImageAnalysisTool",
    "AzureAiServicesSpeechToTextTool",
    "AzureAiServicesTextToSpeechTool",
    "AzureAiServicesTextAnalyticsForHealthTool",
]
@@ -0,0 +1,145 @@
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import root_validator
from langchain_core.tools import BaseTool
from langchain_core.utils import get_from_dict_or_env

from langchain_community.tools.azure_ai_services.utils import (
    detect_file_src_type,
)

logger = logging.getLogger(__name__)


class AzureAiServicesDocumentIntelligenceTool(BaseTool):
    """Tool that queries the Azure AI Services Document Intelligence API.

    In order to set this up, follow instructions at:
    https://learn.microsoft.com/en-us/azure/ai-services/document-intelligence/quickstarts/get-started-sdks-rest-api?view=doc-intel-4.0.0&pivots=programming-language-python
    """

    azure_ai_services_key: str = ""  #: :meta private:
    azure_ai_services_endpoint: str = ""  #: :meta private:
    doc_analysis_client: Any  #: :meta private:

    name: str = "azure_ai_services_document_intelligence"
    description: str = (
        "A wrapper around Azure AI Services Document Intelligence. "
        "Useful for when you need to "
        "extract text, tables, and key-value pairs from documents. "
        "Input should be a url to a document."
    )

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and endpoint exists in environment."""
        azure_ai_services_key = get_from_dict_or_env(
            values, "azure_ai_services_key", "AZURE_AI_SERVICES_KEY"
        )

        azure_ai_services_endpoint = get_from_dict_or_env(
            values, "azure_ai_services_endpoint", "AZURE_AI_SERVICES_ENDPOINT"
        )

        try:
            from azure.ai.formrecognizer import DocumentAnalysisClient
            from azure.core.credentials import AzureKeyCredential

            values["doc_analysis_client"] = DocumentAnalysisClient(
                endpoint=azure_ai_services_endpoint,
                credential=AzureKeyCredential(azure_ai_services_key),
            )

        except ImportError:
            raise ImportError(
                "azure-ai-formrecognizer is not installed. "
                "Run `pip install azure-ai-formrecognizer` to install."
            )

        return values

    def _parse_tables(self, tables: List[Any]) -> List[Any]:
        result = []
        for table in tables:
            rc, cc = table.row_count, table.column_count
            _table = [["" for _ in range(cc)] for _ in range(rc)]
            for cell in table.cells:
                _table[cell.row_index][cell.column_index] = cell.content
            result.append(_table)
        return result

    def _parse_kv_pairs(self, kv_pairs: List[Any]) -> List[Any]:
        result = []
        for kv_pair in kv_pairs:
            key = kv_pair.key.content if kv_pair.key else ""
            value = kv_pair.value.content if kv_pair.value else ""
            result.append((key, value))
        return result

    def _document_analysis(self, document_path: str) -> Dict:
        document_src_type = detect_file_src_type(document_path)
        if document_src_type == "local":
            with open(document_path, "rb") as document:
                poller = self.doc_analysis_client.begin_analyze_document(
                    "prebuilt-document", document
                )
        elif document_src_type == "remote":
            poller = self.doc_analysis_client.begin_analyze_document_from_url(
                "prebuilt-document", document_path
            )
        else:
            raise ValueError(f"Invalid document path: {document_path}")

        result = poller.result()
        res_dict = {}

        if result.content is not None:
            res_dict["content"] = result.content

        if result.tables is not None:
            res_dict["tables"] = self._parse_tables(result.tables)

        if result.key_value_pairs is not None:
            res_dict["key_value_pairs"] = self._parse_kv_pairs(result.key_value_pairs)

        return res_dict

    def _format_document_analysis_result(self, document_analysis_result: Dict) -> str:
        formatted_result = []
        if "content" in document_analysis_result:
            formatted_result.append(
                f"Content: {document_analysis_result['content']}".replace("\n", " ")
            )

        if "tables" in document_analysis_result:
            for i, table in enumerate(document_analysis_result["tables"]):
                formatted_result.append(f"Table {i}: {table}".replace("\n", " "))

        if "key_value_pairs" in document_analysis_result:
            for kv_pair in document_analysis_result["key_value_pairs"]:
                formatted_result.append(
                    f"{kv_pair[0]}: {kv_pair[1]}".replace("\n", " ")
                )

        return "\n".join(formatted_result)

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        try:
            document_analysis_result = self._document_analysis(query)
            if not document_analysis_result:
                return "No good document analysis result was found"

            return self._format_document_analysis_result(document_analysis_result)
        except Exception as e:
            raise RuntimeError(
                f"Error while running AzureAiServicesDocumentIntelligenceTool: {e}"
            )
@@ -0,0 +1,153 @@
from __future__ import annotations

import logging
from typing import Any, Dict, Optional

from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import root_validator
from langchain_core.tools import BaseTool
from langchain_core.utils import get_from_dict_or_env

from langchain_community.tools.azure_ai_services.utils import (
    detect_file_src_type,
)

logger = logging.getLogger(__name__)


class AzureAiServicesImageAnalysisTool(BaseTool):
    """Tool that queries the Azure AI Services Image Analysis API.

    In order to set this up, follow instructions at:
    https://learn.microsoft.com/en-us/azure/ai-services/computer-vision/quickstarts-sdk/image-analysis-client-library-40
    """

    azure_ai_services_key: str = ""  #: :meta private:
    azure_ai_services_endpoint: str = ""  #: :meta private:
    image_analysis_client: Any  #: :meta private:
    visual_features: Any  #: :meta private:

    name: str = "azure_ai_services_image_analysis"
    description: str = (
        "A wrapper around Azure AI Services Image Analysis. "
        "Useful for when you need to analyze images. "
        "Input should be a url to an image."
    )

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and endpoint exists in environment."""
        azure_ai_services_key = get_from_dict_or_env(
            values, "azure_ai_services_key", "AZURE_AI_SERVICES_KEY"
        )

        azure_ai_services_endpoint = get_from_dict_or_env(
            values, "azure_ai_services_endpoint", "AZURE_AI_SERVICES_ENDPOINT"
        )

        """Validate that azure-ai-vision-imageanalysis is installed."""
        try:
            from azure.ai.vision.imageanalysis import ImageAnalysisClient
            from azure.ai.vision.imageanalysis.models import VisualFeatures
            from azure.core.credentials import AzureKeyCredential
        except ImportError:
            raise ImportError(
                "azure-ai-vision-imageanalysis is not installed. "
                "Run `pip install azure-ai-vision-imageanalysis` to install. "
            )

        """Validate Azure AI Vision Image Analysis client can be initialized."""
        try:
            values["image_analysis_client"] = ImageAnalysisClient(
                endpoint=azure_ai_services_endpoint,
                credential=AzureKeyCredential(azure_ai_services_key),
            )
        except Exception as e:
            raise RuntimeError(
                f"Initialization of Azure AI Vision Image Analysis client failed: {e}"
            )

        values["visual_features"] = [
            VisualFeatures.TAGS,
            VisualFeatures.OBJECTS,
            VisualFeatures.CAPTION,
            VisualFeatures.READ,
        ]

        return values

    def _image_analysis(self, image_path: str) -> Dict:
        try:
            from azure.ai.vision.imageanalysis import ImageAnalysisClient
        except ImportError:
            pass

        self.image_analysis_client: ImageAnalysisClient

        image_src_type = detect_file_src_type(image_path)
        if image_src_type == "local":
            with open(image_path, "rb") as image_file:
                image_data = image_file.read()
            result = self.image_analysis_client.analyze(
                image_data=image_data,
                visual_features=self.visual_features,
            )
        elif image_src_type == "remote":
            result = self.image_analysis_client.analyze_from_url(
                image_url=image_path,
                visual_features=self.visual_features,
            )
        else:
            raise ValueError(f"Invalid image path: {image_path}")

        res_dict = {}
        if result:
            if result.caption is not None:
                res_dict["caption"] = result.caption.text

            if result.objects is not None:
                res_dict["objects"] = [obj.tags[0].name for obj in result.objects.list]

            if result.tags is not None:
                res_dict["tags"] = [tag.name for tag in result.tags.list]

            if result.read is not None and len(result.read.blocks) > 0:
                res_dict["text"] = [line.text for line in result.read.blocks[0].lines]

        return res_dict

    def _format_image_analysis_result(self, image_analysis_result: Dict) -> str:
        formatted_result = []
        if "caption" in image_analysis_result:
            formatted_result.append("Caption: " + image_analysis_result["caption"])

        if (
            "objects" in image_analysis_result
            and len(image_analysis_result["objects"]) > 0
        ):
            formatted_result.append(
                "Objects: " + ", ".join(image_analysis_result["objects"])
            )

        if "tags" in image_analysis_result and len(image_analysis_result["tags"]) > 0:
            formatted_result.append("Tags: " + ", ".join(image_analysis_result["tags"]))

        if "text" in image_analysis_result and len(image_analysis_result["text"]) > 0:
            formatted_result.append("Text: " + ", ".join(image_analysis_result["text"]))

        return "\n".join(formatted_result)

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        try:
            image_analysis_result = self._image_analysis(query)
            if not image_analysis_result:
                return "No good image analysis result was found"

            return self._format_image_analysis_result(image_analysis_result)
        except Exception as e:
            raise RuntimeError(f"Error while running AzureAiImageAnalysisTool: {e}")
@@ -0,0 +1,122 @@
from __future__ import annotations

import logging
import time
from typing import Any, Dict, Optional

from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import root_validator
from langchain_core.tools import BaseTool
from langchain_core.utils import get_from_dict_or_env

from langchain_community.tools.azure_ai_services.utils import (
    detect_file_src_type,
    download_audio_from_url,
)

logger = logging.getLogger(__name__)


class AzureAiServicesSpeechToTextTool(BaseTool):
    """Tool that queries the Azure AI Services Speech to Text API.

    In order to set this up, follow instructions at:
    https://learn.microsoft.com/en-us/azure/ai-services/speech-service/get-started-speech-to-text?pivots=programming-language-python
    """

    azure_ai_services_key: str = ""  #: :meta private:
    azure_ai_services_region: str = ""  #: :meta private:
    speech_language: str = "en-US"  #: :meta private:
    speech_config: Any  #: :meta private:

    name: str = "azure_ai_services_speech_to_text"
    description: str = (
        "A wrapper around Azure AI Services Speech to Text. "
        "Useful for when you need to transcribe audio to text. "
        "Input should be a url to an audio file."
    )

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and endpoint exists in environment."""
        azure_ai_services_key = get_from_dict_or_env(
            values, "azure_ai_services_key", "AZURE_AI_SERVICES_KEY"
        )

        azure_ai_services_region = get_from_dict_or_env(
            values, "azure_ai_services_region", "AZURE_AI_SERVICES_REGION"
        )

        try:
            import azure.cognitiveservices.speech as speechsdk

            values["speech_config"] = speechsdk.SpeechConfig(
                subscription=azure_ai_services_key, region=azure_ai_services_region
            )
        except ImportError:
            raise ImportError(
                "azure-cognitiveservices-speech is not installed. "
                "Run `pip install azure-cognitiveservices-speech` to install."
            )

        return values

    def _continuous_recognize(self, speech_recognizer: Any) -> str:
        done = False
        text = ""

        def stop_cb(evt: Any) -> None:
            """callback that stop continuous recognition"""
            speech_recognizer.stop_continuous_recognition_async()
            nonlocal done
            done = True

        def retrieve_cb(evt: Any) -> None:
            """callback that retrieves the intermediate recognition results"""
            nonlocal text
            text += evt.result.text

        # retrieve text on recognized events
        speech_recognizer.recognized.connect(retrieve_cb)
        # stop continuous recognition on either session stopped or canceled events
        speech_recognizer.session_stopped.connect(stop_cb)
        speech_recognizer.canceled.connect(stop_cb)

        # Start continuous speech recognition
        speech_recognizer.start_continuous_recognition_async()
        while not done:
            time.sleep(0.5)
        return text

    def _speech_to_text(self, audio_path: str, speech_language: str) -> str:
        try:
            import azure.cognitiveservices.speech as speechsdk
        except ImportError:
            pass

        audio_src_type = detect_file_src_type(audio_path)
        if audio_src_type == "local":
            audio_config = speechsdk.AudioConfig(filename=audio_path)
        elif audio_src_type == "remote":
            tmp_audio_path = download_audio_from_url(audio_path)
            audio_config = speechsdk.AudioConfig(filename=tmp_audio_path)
        else:
            raise ValueError(f"Invalid audio path: {audio_path}")

        self.speech_config.speech_recognition_language = speech_language
        speech_recognizer = speechsdk.SpeechRecognizer(self.speech_config, audio_config)
        return self._continuous_recognize(speech_recognizer)

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        try:
            text = self._speech_to_text(query, self.speech_language)
            return text
        except Exception as e:
            raise RuntimeError(
                f"Error while running AzureAiServicesSpeechToTextTool: {e}"
            )
@@ -0,0 +1,104 @@
from __future__ import annotations

import logging
from typing import Any, Dict, Optional

from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import root_validator
from langchain_core.tools import BaseTool
from langchain_core.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class AzureAiServicesTextAnalyticsForHealthTool(BaseTool):
    """Tool that queries the Azure AI Services Text Analytics for Health API.

    In order to set this up, follow instructions at:
    https://learn.microsoft.com/en-us/azure/ai-services/language-service/text-analytics-for-health/quickstart?pivots=programming-language-python
    """

    azure_ai_services_key: str = ""  #: :meta private:
    azure_ai_services_endpoint: str = ""  #: :meta private:
    text_analytics_client: Any  #: :meta private:

    name: str = "azure_ai_services_text_analytics_for_health"
    description: str = (
        "A wrapper around Azure AI Services Text Analytics for Health. "
        "Useful for when you need to identify entities in healthcare data. "
        "Input should be text."
    )

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and endpoint exists in environment."""
        azure_ai_services_key = get_from_dict_or_env(
            values, "azure_ai_services_key", "AZURE_AI_SERVICES_KEY"
        )

        azure_ai_services_endpoint = get_from_dict_or_env(
            values, "azure_ai_services_endpoint", "AZURE_AI_SERVICES_ENDPOINT"
        )

        try:
            import azure.ai.textanalytics as sdk
            from azure.core.credentials import AzureKeyCredential

            values["text_analytics_client"] = sdk.TextAnalyticsClient(
                endpoint=azure_ai_services_endpoint,
                credential=AzureKeyCredential(azure_ai_services_key),
            )

        except ImportError:
            raise ImportError(
                "azure-ai-textanalytics is not installed. "
                "Run `pip install azure-ai-textanalytics` to install."
            )

        return values

    def _text_analysis(self, text: str) -> Dict:
        poller = self.text_analytics_client.begin_analyze_healthcare_entities(
            [{"id": "1", "language": "en", "text": text}]
        )

        result = poller.result()

        res_dict = {}

        docs = [doc for doc in result if not doc.is_error]

        if docs is not None:
            res_dict["entities"] = [
                f"{x.text} is a healthcare entity of type {x.category}"
                for y in docs
                for x in y.entities
            ]

        return res_dict

    def _format_text_analysis_result(self, text_analysis_result: Dict) -> str:
        formatted_result = []
        if "entities" in text_analysis_result:
            formatted_result.append(
                f"""The text contains the following healthcare entities: {
                    ', '.join(text_analysis_result['entities'])
                }""".replace("\n", " ")
            )

        return "\n".join(formatted_result)

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        try:
            text_analysis_result = self._text_analysis(query)

            return self._format_text_analysis_result(text_analysis_result)
        except Exception as e:
            raise RuntimeError(
                f"Error while running AzureAiServicesTextAnalyticsForHealthTool: {e}"
            )
@@ -0,0 +1,105 @@
from __future__ import annotations

import logging
import tempfile
from typing import Any, Dict, Optional

from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import root_validator
from langchain_core.tools import BaseTool
from langchain_core.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class AzureAiServicesTextToSpeechTool(BaseTool):
    """Tool that queries the Azure AI Services Text to Speech API.

    In order to set this up, follow instructions at:
    https://learn.microsoft.com/en-us/azure/ai-services/speech-service/get-started-text-to-speech?pivots=programming-language-python
    """

    name: str = "azure_ai_services_text_to_speech"
    description: str = (
        "A wrapper around Azure AI Services Text to Speech API. "
        "Useful for when you need to convert text to speech. "
    )
    return_direct: bool = True

    azure_ai_services_key: str = ""  #: :meta private:
    azure_ai_services_region: str = ""  #: :meta private:
    speech_language: str = "en-US"  #: :meta private:
    speech_config: Any  #: :meta private:

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and endpoint exists in environment."""
        azure_ai_services_key = get_from_dict_or_env(
            values, "azure_ai_services_key", "AZURE_AI_SERVICES_KEY"
        )

        azure_ai_services_region = get_from_dict_or_env(
            values, "azure_ai_services_region", "AZURE_AI_SERVICES_REGION"
        )

        try:
            import azure.cognitiveservices.speech as speechsdk

            values["speech_config"] = speechsdk.SpeechConfig(
                subscription=azure_ai_services_key, region=azure_ai_services_region
            )
        except ImportError:
            raise ImportError(
                "azure-cognitiveservices-speech is not installed. "
                "Run `pip install azure-cognitiveservices-speech` to install."
            )

        return values

    def _text_to_speech(self, text: str, speech_language: str) -> str:
        try:
            import azure.cognitiveservices.speech as speechsdk
        except ImportError:
            pass

        self.speech_config.speech_synthesis_language = speech_language
        speech_synthesizer = speechsdk.SpeechSynthesizer(
            speech_config=self.speech_config, audio_config=None
        )
        result = speech_synthesizer.speak_text(text)

        if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted:
            stream = speechsdk.AudioDataStream(result)
            with tempfile.NamedTemporaryFile(
                mode="wb", suffix=".wav", delete=False
            ) as f:
                stream.save_to_wav_file(f.name)

            return f.name

        elif result.reason == speechsdk.ResultReason.Canceled:
            cancellation_details = result.cancellation_details
            logger.debug(f"Speech synthesis canceled: {cancellation_details.reason}")
            if cancellation_details.reason == speechsdk.CancellationReason.Error:
                raise RuntimeError(
                    f"Speech synthesis error: {cancellation_details.error_details}"
                )

            return "Speech synthesis canceled."

        else:
            return f"Speech synthesis failed: {result.reason}"

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        try:
            speech_file = self._text_to_speech(query, self.speech_language)
            return speech_file
        except Exception as e:
            raise RuntimeError(
                f"Error while running AzureAiServicesTextToSpeechTool: {e}"
            )
@@ -0,0 +1,29 @@
import os
import tempfile
from urllib.parse import urlparse

import requests


def detect_file_src_type(file_path: str) -> str:
    """Detect if the file is local or remote."""
    if os.path.isfile(file_path):
        return "local"

    parsed_url = urlparse(file_path)
    if parsed_url.scheme and parsed_url.netloc:
        return "remote"

    return "invalid"


def download_audio_from_url(audio_url: str) -> str:
    """Download audio from url to local."""
    ext = audio_url.split(".")[-1]
    response = requests.get(audio_url, stream=True)
    response.raise_for_status()
    with tempfile.NamedTemporaryFile(mode="wb", suffix=f".{ext}", delete=False) as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)

    return f.name
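As a quick illustration of how `detect_file_src_type` routes tool inputs, here is a minimal sketch (the example URL is a placeholder):

```python
from langchain_community.tools.azure_ai_services.utils import detect_file_src_type

# Existing local files resolve to "local", well-formed URLs to "remote",
# and anything else to "invalid" (the tools raise ValueError on "invalid").
print(detect_file_src_type(__file__))                         # "local"
print(detect_file_src_type("https://example.com/audio.wav"))  # "remote"
print(detect_file_src_type("not-a-path-or-url"))              # "invalid"
```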