diff --git a/libs/community/langchain_community/document_loaders/doc_intelligence.py b/libs/community/langchain_community/document_loaders/doc_intelligence.py index d51fa575604..ee126b12e7b 100644 --- a/libs/community/langchain_community/document_loaders/doc_intelligence.py +++ b/libs/community/langchain_community/document_loaders/doc_intelligence.py @@ -1,4 +1,6 @@ -from typing import Iterator, List, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING, Iterator, List, Optional from langchain_core.documents import Document @@ -8,6 +10,9 @@ from langchain_community.document_loaders.parsers import ( AzureAIDocumentIntelligenceParser, ) +if TYPE_CHECKING: + from azure.core.credentials import TokenCredential + class AzureAIDocumentIntelligenceLoader(BaseLoader): """Load a PDF with Azure Document Intelligence.""" @@ -15,7 +20,7 @@ class AzureAIDocumentIntelligenceLoader(BaseLoader): def __init__( self, api_endpoint: str, - api_key: str, + api_key: Optional[str] = None, file_path: Optional[str] = None, url_path: Optional[str] = None, bytes_source: Optional[bytes] = None, @@ -24,6 +29,7 @@ class AzureAIDocumentIntelligenceLoader(BaseLoader): mode: str = "markdown", *, analysis_features: Optional[List[str]] = None, + azure_credential: Optional["TokenCredential"] = None, ) -> None: """ Initialize the object for file processing with Azure Document Intelligence @@ -63,6 +69,9 @@ class AzureAIDocumentIntelligenceLoader(BaseLoader): List of optional analysis features, each feature should be passed as a str that conforms to the enum `DocumentAnalysisFeature` in `azure-ai-documentintelligence` package. Default value is None. + azure_credential: Optional[TokenCredential] + The credentials to use for DocumentIntelligenceClient construction, when + using credentials other than api_key (like AD). Examples: --------- @@ -79,6 +88,15 @@ class AzureAIDocumentIntelligenceLoader(BaseLoader): assert ( file_path is not None or url_path is not None or bytes_source is not None ), "file_path, url_path or bytes_source must be provided" + + assert ( + api_key is not None or azure_credential is not None + ), "Either api_key or azure_credential must be provided." + + assert ( + api_key is None or azure_credential is None + ), "Only one of api_key or azure_credential should be provided." + self.file_path = file_path self.url_path = url_path self.bytes_source = bytes_source @@ -90,6 +108,7 @@ class AzureAIDocumentIntelligenceLoader(BaseLoader): api_model=api_model, mode=mode, analysis_features=analysis_features, + azure_credential=azure_credential, ) def lazy_load( diff --git a/libs/community/langchain_community/document_loaders/parsers/doc_intelligence.py b/libs/community/langchain_community/document_loaders/parsers/doc_intelligence.py index 78ff0223595..3bcbec6d9a4 100644 --- a/libs/community/langchain_community/document_loaders/parsers/doc_intelligence.py +++ b/libs/community/langchain_community/document_loaders/parsers/doc_intelligence.py @@ -1,11 +1,16 @@ +from __future__ import annotations + import logging -from typing import Any, Iterator, List, Optional +from typing import TYPE_CHECKING, Any, Iterator, List, Optional from langchain_core.documents import Document from langchain_community.document_loaders.base import BaseBlobParser from langchain_community.document_loaders.blob_loaders import Blob +if TYPE_CHECKING: + from azure.core.credentials import TokenCredential + logger = logging.getLogger(__name__) @@ -16,17 +21,27 @@ class AzureAIDocumentIntelligenceParser(BaseBlobParser): def __init__( self, api_endpoint: str, - api_key: str, + api_key: Optional[str] = None, api_version: Optional[str] = None, api_model: str = "prebuilt-layout", mode: str = "markdown", analysis_features: Optional[List[str]] = None, + azure_credential: Optional["TokenCredential"] = None, ): from azure.ai.documentintelligence import DocumentIntelligenceClient from azure.ai.documentintelligence.models import DocumentAnalysisFeature from azure.core.credentials import AzureKeyCredential kwargs = {} + + if api_key is None and azure_credential is None: + raise ValueError("Either api_key or azure_credential must be provided.") + + if api_key and azure_credential: + raise ValueError( + "Only one of api_key or azure_credential should be provided." + ) + if api_version is not None: kwargs["api_version"] = api_version @@ -49,7 +64,7 @@ class AzureAIDocumentIntelligenceParser(BaseBlobParser): self.client = DocumentIntelligenceClient( endpoint=api_endpoint, - credential=AzureKeyCredential(api_key), + credential=azure_credential or AzureKeyCredential(api_key), headers={"x-ms-useragent": "langchain-parser/1.0.0"}, features=analysis_features, **kwargs,