Mirror of https://github.com/hwchase17/langchain.git, synced 2025-08-18 17:11:25 +00:00
community[minor]: AWS Athena Document Loader (#15625)
- **Description:** Adds a document loader for [AWS Athena](https://aws.amazon.com/athena/), a serverless, interactive analytics service.
- **Dependencies:** Adds `boto3` as a dependency.
This commit is contained in:
parent 93da18b667
commit 584b647b96
110 docs/docs/integrations/document_loaders/athena.ipynb Normal file
@@ -0,0 +1,110 @@
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MwTWzDxYgbrR"
      },
      "source": [
        "# Athena\n",
        "\n",
        "This notebook goes over how to load documents from AWS Athena."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F0zaLR3xgWmO"
      },
      "outputs": [],
      "source": [
        "! pip install boto3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "076NLjfngoWJ"
      },
      "outputs": [],
      "source": [
        "from langchain_community.document_loaders.athena import AthenaLoader"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XpMRQwU9gu44"
      },
      "outputs": [],
      "source": [
        "database_name = \"my_database\"\n",
        "s3_output_path = \"s3://my_bucket/query_results/\"\n",
        "query = \"SELECT * FROM my_table\"\n",
        "profile_name = \"my_profile\"\n",
        "\n",
        "loader = AthenaLoader(\n",
        "    query=query,\n",
        "    database=database_name,\n",
        "    s3_output_uri=s3_output_path,\n",
        "    profile_name=profile_name,\n",
        ")\n",
        "\n",
        "documents = loader.load()\n",
        "print(documents)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5IBapL3ejoEt"
      },
      "source": [
        "Example with metadata columns:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wMx6nI1qjryD"
      },
      "outputs": [],
      "source": [
        "database_name = \"my_database\"\n",
        "s3_output_path = \"s3://my_bucket/query_results/\"\n",
        "query = \"SELECT * FROM my_table\"\n",
        "profile_name = \"my_profile\"\n",
        "metadata_columns = [\"_row\", \"_created_at\"]\n",
        "\n",
        "loader = AthenaLoader(\n",
        "    query=query,\n",
        "    database=database_name,\n",
        "    s3_output_uri=s3_output_path,\n",
        "    profile_name=profile_name,\n",
        "    metadata_columns=metadata_columns,\n",
        ")\n",
        "\n",
        "documents = loader.load()\n",
        "print(documents)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
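The loader also implements lazy_load() (see athena.py below), so result rows can be streamed one Document at a time rather than materialized all at once with load(). A minimal sketch, reusing the placeholder names from the notebook above:

# Streams one Document per result row instead of building the full list.
for doc in loader.lazy_load():
    print(doc.page_content)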
libs/community/langchain_community/document_loaders/__init__.py
@@ -36,6 +36,7 @@ from langchain_community.document_loaders.assemblyai import (
 )
 from langchain_community.document_loaders.astradb import AstraDBLoader
 from langchain_community.document_loaders.async_html import AsyncHtmlLoader
+from langchain_community.document_loaders.athena import AthenaLoader
 from langchain_community.document_loaders.azlyrics import AZLyricsLoader
 from langchain_community.document_loaders.azure_ai_data import (
     AzureAIDataLoader,
@@ -257,6 +258,7 @@ __all__ = [
     "AssemblyAIAudioTranscriptLoader",
     "AstraDBLoader",
     "AsyncHtmlLoader",
+    "AthenaLoader",
     "AzureAIDataLoader",
     "AzureAIDocumentIntelligenceLoader",
     "AzureBlobStorageContainerLoader",
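With `AthenaLoader` added to `__all__`, the class is importable from the package root as well as from its module:

# Equivalent to the module-level import used in the notebook.
from langchain_community.document_loaders import AthenaLoader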
167 libs/community/langchain_community/document_loaders/athena.py Normal file
@@ -0,0 +1,167 @@
from __future__ import annotations

import io
import json
import time
from typing import Any, Dict, Iterator, List, Optional, Tuple

from langchain_core.documents import Document

from langchain_community.document_loaders.base import BaseLoader


class AthenaLoader(BaseLoader):
    """Load documents from `AWS Athena`.

    Each document represents one row of the result.
    - By default, all columns are written into the `page_content` of the document
      and none into the `metadata` of the document.
    - If `metadata_columns` are provided then these columns are written
      into the `metadata` of the document while the rest of the columns
      are written into the `page_content` of the document.

    To authenticate, the AWS client uses this method to automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Amazon Athena service.
    """

    def __init__(
        self,
        query: str,
        database: str,
        s3_output_uri: str,
        profile_name: Optional[str] = None,
        metadata_columns: Optional[List[str]] = None,
    ):
        """Initialize Athena document loader.

        Args:
            query: The query to run in Athena.
            database: Athena database.
            s3_output_uri: Athena output path.
            profile_name: Optional. AWS credential profile to use; falls back
                to the default boto3 credential chain when omitted.
            metadata_columns: Optional. Columns written to Document `metadata`.
        """
        self.query = query
        self.database = database
        self.s3_output_uri = s3_output_uri
        self.metadata_columns = metadata_columns if metadata_columns is not None else []

        try:
            import boto3
        except ImportError:
            raise ModuleNotFoundError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )

        try:
            session = (
                boto3.Session(profile_name=profile_name)
                if profile_name is not None
                else boto3.Session()
            )
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

        self.athena_client = session.client("athena")
        self.s3_client = session.client("s3")

    def _execute_query(self) -> List[Dict[str, Any]]:
        response = self.athena_client.start_query_execution(
            QueryString=self.query,
            QueryExecutionContext={"Database": self.database},
            ResultConfiguration={"OutputLocation": self.s3_output_uri},
        )
        query_execution_id = response["QueryExecutionId"]
        print(f"Query : {self.query}")
        # Poll Athena until the query reaches a terminal state.
        while True:
            response = self.athena_client.get_query_execution(
                QueryExecutionId=query_execution_id
            )
            state = response["QueryExecution"]["Status"]["State"]
            if state == "SUCCEEDED":
                print(f"State : {state}")
                break
            elif state == "FAILED":
                resp_status = response["QueryExecution"]["Status"]
                state_change_reason = resp_status["StateChangeReason"]
                err = f"Query Failed: {state_change_reason}"
                raise Exception(err)
            elif state == "CANCELLED":
                raise Exception("Query was cancelled by the user.")
            else:
                print(f"State : {state}")
            time.sleep(1)

        result_set = self._get_result_set(query_execution_id)
        return json.loads(result_set.to_json(orient="records"))

    def _remove_suffix(self, input_string: str, suffix: str) -> str:
        if suffix and input_string.endswith(suffix):
            return input_string[: -len(suffix)]
        return input_string

    def _remove_prefix(self, input_string: str, prefix: str) -> str:
        if prefix and input_string.startswith(prefix):
            return input_string[len(prefix) :]
        return input_string

    def _get_result_set(self, query_execution_id: str) -> Any:
        try:
            import pandas as pd
        except ImportError:
            raise ModuleNotFoundError(
                "Could not import pandas python package. "
                "Please install it with `pip install pandas`."
            )

        # Athena writes the result CSV to <s3_output_uri>/<query_execution_id>.csv.
        output_uri = self.s3_output_uri
        tokens = self._remove_prefix(
            self._remove_suffix(output_uri, "/"), "s3://"
        ).split("/")
        bucket = tokens[0]
        key = "/".join(tokens[1:]) + "/" + query_execution_id + ".csv"

        obj = self.s3_client.get_object(Bucket=bucket, Key=key)
        df = pd.read_csv(io.BytesIO(obj["Body"].read()), encoding="utf8")
        return df

    def _get_columns(
        self, query_result: List[Dict[str, Any]]
    ) -> Tuple[List[str], List[str]]:
        content_columns = []
        metadata_columns = []
        if not query_result:
            return content_columns, metadata_columns
        all_columns = list(query_result[0].keys())
        for key in all_columns:
            if key in self.metadata_columns:
                metadata_columns.append(key)
            else:
                content_columns.append(key)

        return content_columns, metadata_columns

    def lazy_load(self) -> Iterator[Document]:
        query_result = self._execute_query()
        content_columns, metadata_columns = self._get_columns(query_result)
        for row in query_result:
            # Content columns become "column: value" lines; metadata columns
            # (when non-null) go into the Document metadata dict.
            page_content = "\n".join(
                f"{k}: {v}" for k, v in row.items() if k in content_columns
            )
            metadata = {
                k: v for k, v in row.items() if k in metadata_columns and v is not None
            }
            doc = Document(page_content=page_content, metadata=metadata)
            yield doc

    def load(self) -> List[Document]:
        """Load data into document objects."""
        return list(self.lazy_load())
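To make the row-to-Document mapping described in the class docstring concrete, a small illustrative sketch (the column names and values are hypothetical, standing in for one row of an Athena result):

# Assuming metadata_columns=["_row"]; "title" and "body" are content columns.
row = {"_row": 1, "title": "foo", "body": "bar"}
page_content = "\n".join(f"{k}: {v}" for k, v in row.items() if k != "_row")
metadata = {"_row": 1}
# lazy_load() would yield:
# Document(page_content="title: foo\nbody: bar", metadata={"_row": 1})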
libs/community/tests/unit_tests/document_loaders/test_imports.py
@@ -23,6 +23,7 @@ EXPECTED_ALL = [
     "AssemblyAIAudioTranscriptLoader",
     "AstraDBLoader",
     "AsyncHtmlLoader",
+    "AthenaLoader",
     "AzureAIDataLoader",
     "AzureAIDocumentIntelligenceLoader",
     "AzureBlobStorageContainerLoader",