FEAT: Add azure cognitive health tool (#13448)

- **Description:** This change adds a tool to the Azure Cognitive Services
toolkit for identifying healthcare entities
  - **Dependencies:** azure-ai-textanalytics (Optional)
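  - **Usage (illustrative sketch):** assuming `AZURE_COGS_KEY` and `AZURE_COGS_ENDPOINT` are set for a valid Azure Language resource, the new tool can be exercised directly; the placeholder values and sample sentence below are only examples, not part of this change.

    import os
    from langchain.tools.azure_cognitive_services import AzureCogsTextAnalyticsHealthTool

    os.environ["AZURE_COGS_KEY"] = "<cognitive-services-key>"            # placeholder
    os.environ["AZURE_COGS_ENDPOINT"] = "<cognitive-services-endpoint>"  # placeholder

    tool = AzureCogsTextAnalyticsHealthTool()
    print(tool.run("The patient was prescribed 100mg of ibuprofen, taken twice daily."))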

---------

Co-authored-by: James Beck <James.Beck@sa.gov.au>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Authored by: jwbeck97
Date: 2023-11-20 13:14:01 +10:30
Committed by: GitHub
parent 6bf9b2cb51
commit a93616e972
9 changed files with 213 additions and 10 deletions

View File

@@ -9,6 +9,7 @@ from langchain.tools.azure_cognitive_services import (
AzureCogsImageAnalysisTool,
AzureCogsSpeech2TextTool,
AzureCogsText2SpeechTool,
AzureCogsTextAnalyticsHealthTool,
)
from langchain.tools.base import BaseTool
@@ -23,6 +24,7 @@ class AzureCognitiveServicesToolkit(BaseToolkit):
AzureCogsFormRecognizerTool(),
AzureCogsSpeech2TextTool(),
AzureCogsText2SpeechTool(),
AzureCogsTextAnalyticsHealthTool(),
]
# TODO: Remove check once azure-ai-vision supports MacOS.
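
With the tool registered above, a quick way to confirm it is returned by the toolkit (a sketch; it assumes the optional Azure dependencies are installed and the AZURE_COGS_* environment variables are set):

from langchain.agents.agent_toolkits import AzureCognitiveServicesToolkit

toolkit = AzureCognitiveServicesToolkit()
print([tool.name for tool in toolkit.get_tools()])
# the printed names should now include "azure_cognitive_services_text_analytics_health"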

View File

@@ -84,6 +84,14 @@ def _import_azure_cognitive_services_AzureCogsText2SpeechTool() -> Any:
return AzureCogsText2SpeechTool
def _import_azure_cognitive_services_AzureCogsTextAnalyticsHealthTool() -> Any:
from langchain.tools.azure_cognitive_services import (
AzureCogsTextAnalyticsHealthTool,
)
return AzureCogsTextAnalyticsHealthTool
def _import_bing_search_tool_BingSearchResults() -> Any:
from langchain.tools.bing_search.tool import BingSearchResults
@@ -691,6 +699,8 @@ def __getattr__(name: str) -> Any:
return _import_azure_cognitive_services_AzureCogsSpeech2TextTool()
elif name == "AzureCogsText2SpeechTool":
return _import_azure_cognitive_services_AzureCogsText2SpeechTool()
elif name == "AzureCogsTextAnalyticsHealthTool":
return _import_azure_cognitive_services_AzureCogsTextAnalyticsHealthTool()
elif name == "BingSearchResults":
return _import_bing_search_tool_BingSearchResults()
elif name == "BingSearchRun":
@@ -900,6 +910,7 @@ __all__ = [
"AzureCogsImageAnalysisTool",
"AzureCogsSpeech2TextTool",
"AzureCogsText2SpeechTool",
"AzureCogsTextAnalyticsHealthTool",
"BaseGraphQLTool",
"BaseRequestsTool",
"BaseSQLDatabaseTool",

View File

@@ -12,10 +12,14 @@ from langchain.tools.azure_cognitive_services.speech2text import (
from langchain.tools.azure_cognitive_services.text2speech import (
AzureCogsText2SpeechTool,
)
from langchain.tools.azure_cognitive_services.text_analytics_health import (
AzureCogsTextAnalyticsHealthTool,
)
__all__ = [
"AzureCogsImageAnalysisTool",
"AzureCogsFormRecognizerTool",
"AzureCogsSpeech2TextTool",
"AzureCogsText2SpeechTool",
"AzureCogsTextAnalyticsHealthTool",
]

View File

@@ -0,0 +1,104 @@
from __future__ import annotations

import logging
from typing import Any, Dict, Optional

from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.pydantic_v1 import root_validator
from langchain.tools.base import BaseTool
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class AzureCogsTextAnalyticsHealthTool(BaseTool):
    """Tool that queries the Azure Cognitive Services Text Analytics for Health API.

    In order to set this up, follow instructions at:
    https://learn.microsoft.com/en-us/azure/ai-services/language-service/text-analytics-for-health/quickstart?tabs=windows&pivots=programming-language-python
    """

    azure_cogs_key: str = ""  #: :meta private:
    azure_cogs_endpoint: str = ""  #: :meta private:
    text_analytics_client: Any  #: :meta private:

    name: str = "azure_cognitive_services_text_analytics_health"
    description: str = (
        "A wrapper around Azure Cognitive Services Text Analytics for Health. "
        "Useful for when you need to identify entities in healthcare data. "
        "Input should be text."
    )

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the api key and endpoint exist in the environment."""
        azure_cogs_key = get_from_dict_or_env(
            values, "azure_cogs_key", "AZURE_COGS_KEY"
        )
        azure_cogs_endpoint = get_from_dict_or_env(
            values, "azure_cogs_endpoint", "AZURE_COGS_ENDPOINT"
        )

        try:
            import azure.ai.textanalytics as sdk
            from azure.core.credentials import AzureKeyCredential

            values["text_analytics_client"] = sdk.TextAnalyticsClient(
                endpoint=azure_cogs_endpoint,
                credential=AzureKeyCredential(azure_cogs_key),
            )
        except ImportError:
            raise ImportError(
                "azure-ai-textanalytics is not installed. "
                "Run `pip install azure-ai-textanalytics` to install."
            )

        return values

    def _text_analysis(self, text: str) -> Dict:
        # Submit the text as a single English document and wait for the
        # long-running healthcare-entities operation to finish.
        poller = self.text_analytics_client.begin_analyze_healthcare_entities(
            [{"id": "1", "language": "en", "text": text}]
        )
        result = poller.result()

        res_dict = {}
        docs = [doc for doc in result if not doc.is_error]
        if docs:
            res_dict["entities"] = [
                f"{x.text} is a healthcare entity of type {x.category}"
                for y in docs
                for x in y.entities
            ]
        return res_dict

    def _format_text_analysis_result(self, text_analysis_result: Dict) -> str:
        formatted_result = []
        if "entities" in text_analysis_result:
            entities = ", ".join(text_analysis_result["entities"])
            formatted_result.append(
                f"The text contains the following healthcare entities: {entities}"
            )
        return "\n".join(formatted_result)

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the tool."""
        try:
            text_analysis_result = self._text_analysis(query)
            return self._format_text_analysis_result(text_analysis_result)
        except Exception as e:
            raise RuntimeError(
                f"Error while running AzureCogsTextAnalyticsHealthTool: {e}"
            )
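
Beyond direct invocation, the tool can be handed to an agent like the other Azure Cognitive Services tools. A hedged sketch (assumes an OpenAI API key plus the AZURE_COGS_* variables; the agent type and prompt are illustrative, not part of this change):

from langchain.agents import AgentType, initialize_agent
from langchain.agents.agent_toolkits import AzureCognitiveServicesToolkit
from langchain.llms import OpenAI

llm = OpenAI(temperature=0)
agent = initialize_agent(
    tools=AzureCognitiveServicesToolkit().get_tools(),
    llm=llm,
    agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)
agent.run(
    "Which healthcare entities appear in this note: "
    "'The patient was prescribed 100mg of ibuprofen, taken twice daily.'?"
)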

View File

@@ -693,6 +693,23 @@ azure-core = ">=1.23.0,<2.0.0"
msrest = ">=0.6.21"
typing-extensions = ">=4.0.1"
[[package]]
name = "azure-ai-textanalytics"
version = "5.3.0"
description = "Microsoft Azure Text Analytics Client Library for Python"
optional = true
python-versions = ">=3.7"
files = [
{file = "azure-ai-textanalytics-5.3.0.zip", hash = "sha256:4f7d067d5bb08422599ca6175510d39b0911c711301647e5f18e904a5027bf58"},
{file = "azure_ai_textanalytics-5.3.0-py3-none-any.whl", hash = "sha256:69bb736d93de81060e9075d42b6f0b92c25be0fb106da5cb6a6d30e772168221"},
]
[package.dependencies]
azure-common = ">=1.1,<2.0"
azure-core = ">=1.24.0,<2.0.0"
isodate = ">=0.6.1,<1.0.0"
typing-extensions = ">=4.0.1"
[[package]]
name = "azure-ai-vision"
version = "0.11.1b1"
@@ -2889,7 +2906,7 @@ files = [
{file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:0b72b802496cccbd9b31acea72b6f87e7771ccfd7f7927437d592e5c92ed703c"},
{file = "greenlet-3.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:527cd90ba3d8d7ae7dceb06fda619895768a46a1b4e423bdb24c1969823b8362"},
{file = "greenlet-3.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:37f60b3a42d8b5499be910d1267b24355c495064f271cfe74bf28b17b099133c"},
{file = "greenlet-3.0.0-cp311-universal2-macosx_10_9_universal2.whl", hash = "sha256:c3692ecf3fe754c8c0f2c95ff19626584459eab110eaab66413b1e7425cd84e9"},
{file = "greenlet-3.0.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1482fba7fbed96ea7842b5a7fc11d61727e8be75a077e603e8ab49d24e234383"},
{file = "greenlet-3.0.0-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:be557119bf467d37a8099d91fbf11b2de5eb1fd5fc5b91598407574848dc910f"},
{file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73b2f1922a39d5d59cc0e597987300df3396b148a9bd10b76a058a2f2772fc04"},
{file = "greenlet-3.0.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1e22c22f7826096ad503e9bb681b05b8c1f5a8138469b255eb91f26a76634f2"},
@@ -2899,7 +2916,6 @@ files = [
{file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:952256c2bc5b4ee8df8dfc54fc4de330970bf5d79253c863fb5e6761f00dda35"},
{file = "greenlet-3.0.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:269d06fa0f9624455ce08ae0179430eea61085e3cf6457f05982b37fd2cefe17"},
{file = "greenlet-3.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:9adbd8ecf097e34ada8efde9b6fec4dd2a903b1e98037adf72d12993a1c80b51"},
{file = "greenlet-3.0.0-cp312-universal2-macosx_10_9_universal2.whl", hash = "sha256:553d6fb2324e7f4f0899e5ad2c427a4579ed4873f42124beba763f16032959af"},
{file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6b5ce7f40f0e2f8b88c28e6691ca6806814157ff05e794cdd161be928550f4c"},
{file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ecf94aa539e97a8411b5ea52fc6ccd8371be9550c4041011a091eb8b3ca1d810"},
{file = "greenlet-3.0.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:80dcd3c938cbcac986c5c92779db8e8ce51a89a849c135172c88ecbdc8c056b7"},
@@ -11030,8 +11046,8 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
cffi = ["cffi (>=1.11)"]
[extras]
all = ["O365", "aleph-alpha-client", "amadeus", "arxiv", "atlassian-python-api", "awadb", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clarifai", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "esprima", "faiss-cpu", "google-api-python-client", "google-auth", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jinja2", "jq", "lancedb", "langkit", "lark", "librosa", "lxml", "manifest-ml", "marqo", "momento", "nebula3-python", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pymongo", "pyowm", "pypdf", "pytesseract", "python-arango", "pyvespa", "qdrant-client", "rdflib", "redis", "requests-toolbelt", "sentence-transformers", "singlestoredb", "tensorflow-text", "tigrisdb", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-core", "azure-cosmos", "azure-identity", "azure-search-documents", "openai"]
all = ["O365", "aleph-alpha-client", "amadeus", "arxiv", "atlassian-python-api", "awadb", "azure-ai-formrecognizer", "azure-ai-textanalytics", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clarifai", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "esprima", "faiss-cpu", "google-api-python-client", "google-auth", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jinja2", "jq", "lancedb", "langkit", "lark", "librosa", "lxml", "manifest-ml", "marqo", "momento", "nebula3-python", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pymongo", "pyowm", "pypdf", "pytesseract", "python-arango", "pyvespa", "qdrant-client", "rdflib", "redis", "requests-toolbelt", "sentence-transformers", "singlestoredb", "tensorflow-text", "tigrisdb", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
azure = ["azure-ai-formrecognizer", "azure-ai-textanalytics", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-core", "azure-cosmos", "azure-identity", "azure-search-documents", "openai"]
clarifai = ["clarifai"]
cli = ["typer"]
cohere = ["cohere"]
@@ -11047,4 +11063,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "964150ae1b32757b20d1eafa4cdf46f1dc0273fd8879f2ee3229642339df2cd5"
content-hash = "db9fcf70a90810fa88259ed5b7c8fd878d18c2979282591c2a2fd83a6476ae8a"

View File

@@ -138,6 +138,7 @@ anthropic = {version = "^0.3.11", optional = true}
aiosqlite = {version = "^0.19.0", optional = true}
rspace_client = {version = "^2.5.0", optional = true}
upstash-redis = {version = "^0.15.0", optional = true}
azure-ai-textanalytics = {version = "^5.3.0", optional = true}
google-cloud-documentai = {version = "^2.20.1", optional = true}
fireworks-ai = {version = "^0.6.0", optional = true, python = ">=3.9,<4.0"}
javelin-sdk = {version = "^0.1.8", optional = true}
@@ -233,6 +234,7 @@ azure = [
"azure-ai-vision",
"azure-cognitiveservices-speech",
"azure-search-documents",
"azure-ai-textanalytics",
]
all = [
"clarifai",
@@ -297,6 +299,7 @@ all = [
"azure-ai-formrecognizer",
"azure-ai-vision",
"azure-cognitiveservices-speech",
"azure-ai-textanalytics",
"momento",
"singlestoredb",
"tigrisdb",

View File

@@ -13,6 +13,7 @@ EXPECTED_ALL = [
"AzureCogsImageAnalysisTool",
"AzureCogsSpeech2TextTool",
"AzureCogsText2SpeechTool",
"AzureCogsTextAnalyticsHealthTool",
"BaseGraphQLTool",
"BaseRequestsTool",
"BaseSQLDatabaseTool",

View File

@@ -14,6 +14,7 @@ _EXPECTED = [
"AzureCogsImageAnalysisTool",
"AzureCogsSpeech2TextTool",
"AzureCogsText2SpeechTool",
"AzureCogsTextAnalyticsHealthTool",
"BaseGraphQLTool",
"BaseRequestsTool",
"BaseSQLDatabaseTool",