mirror of https://github.com/hwchase17/langchain.git
synced 2025-06-21 06:14:37 +00:00
community: Refactor PebbloSafeLoader (#25582)
**Refactor PebbloSafeLoader**

- Created `APIWrapper` and moved API logic into it.
- Moved helper functions to the utility file.
- Created smaller functions and methods for better readability.
- Properly read environment variables.
- Removed unused code.

**Issue:** NA
**Dependencies:** NA
**Tests:** Updated
This commit is contained in:
parent 5e3a321f71
commit 1f1679e960
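For orientation before the diff: a minimal usage sketch of the refactored loader. The public constructor is unchanged; what moves is the plumbing, which now lives on a `PebbloLoaderAPIWrapper` exposed as `loader.pb_client`. The `CSVLoader`, path, and metadata values below are illustrative assumptions, not part of this commit.

```python
from langchain_community.document_loaders import CSVLoader
from langchain_community.document_loaders.pebblo import PebbloSafeLoader

# Wrap any LangChain document loader; the CSV path is a placeholder.
loader = PebbloSafeLoader(
    CSVLoader("data/corpus.csv"),
    name="acme-rag-app",            # required app name
    owner="Jane Doe",               # optional metadata
    description="Support RAG app",  # optional metadata
)
docs = loader.load()

# After this refactor, API state lives on the wrapper, not the loader:
print(loader.pb_client.classifier_url)
```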
--- a/libs/community/langchain_community/document_loaders/pebblo.py
+++ b/libs/community/langchain_community/document_loaders/pebblo.py
@@ -1,31 +1,25 @@
 """Pebblo's safe dataloader is a wrapper for document loaders"""

-import json
 import logging
 import os
 import uuid
-from http import HTTPStatus
-from typing import Any, Dict, Iterator, List, Optional
+from typing import Dict, Iterator, List, Optional

-import requests  # type: ignore
 from langchain_core.documents import Document

 from langchain_community.document_loaders.base import BaseLoader
 from langchain_community.utilities.pebblo import (
-    APP_DISCOVER_URL,
     BATCH_SIZE_BYTES,
-    CLASSIFIER_URL,
-    LOADER_DOC_URL,
-    PEBBLO_CLOUD_URL,
     PLUGIN_VERSION,
     App,
-    Doc,
     IndexedDocument,
+    PebbloLoaderAPIWrapper,
     generate_size_based_batches,
     get_full_path,
     get_loader_full_path,
     get_loader_type,
     get_runtime,
+    get_source_size,
 )

 logger = logging.getLogger(__name__)
@@ -37,7 +31,6 @@ class PebbloSafeLoader(BaseLoader):
     """

     _discover_sent: bool = False
-    _loader_sent: bool = False

     def __init__(
         self,
@@ -54,22 +47,17 @@ class PebbloSafeLoader(BaseLoader):
         if not name or not isinstance(name, str):
             raise NameError("Must specify a valid name.")
         self.app_name = name
-        self.api_key = os.environ.get("PEBBLO_API_KEY") or api_key
         self.load_id = str(uuid.uuid4())
         self.loader = langchain_loader
         self.load_semantic = os.environ.get("PEBBLO_LOAD_SEMANTIC") or load_semantic
         self.owner = owner
         self.description = description
         self.source_path = get_loader_full_path(self.loader)
-        self.source_owner = PebbloSafeLoader.get_file_owner_from_path(self.source_path)
         self.docs: List[Document] = []
         self.docs_with_id: List[IndexedDocument] = []
         loader_name = str(type(self.loader)).split(".")[-1].split("'")[0]
         self.source_type = get_loader_type(loader_name)
-        self.source_path_size = self.get_source_size(self.source_path)
-        self.source_aggregate_size = 0
-        self.classifier_url = classifier_url or CLASSIFIER_URL
-        self.classifier_location = classifier_location
+        self.source_path_size = get_source_size(self.source_path)
         self.batch_size = BATCH_SIZE_BYTES
         self.loader_details = {
             "loader": loader_name,
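As the commit message notes, helper functions move to the utility file: `get_source_size` (and `get_file_owner_from_path`) become module-level functions in `langchain_community.utilities.pebblo` rather than loader methods, as the second file in this diff shows. A quick sketch of the new call site, with a placeholder path:

```python
from langchain_community.utilities.pebblo import get_source_size

# Works for a single file or a whole directory tree; returns a size in bytes.
size_bytes = get_source_size("/data/corpus")  # placeholder path
```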
@@ -83,7 +71,13 @@ class PebbloSafeLoader(BaseLoader):
         }
         # generate app
         self.app = self._get_app_details()
-        self._send_discover()
+        # initialize Pebblo Loader API client
+        self.pb_client = PebbloLoaderAPIWrapper(
+            api_key=api_key,
+            classifier_location=classifier_location,
+            classifier_url=classifier_url,
+        )
+        self.pb_client.send_loader_discover(self.app)

     def load(self) -> List[Document]:
         """Load Documents.
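The "properly read environment variables" item lands here: instead of `os.environ.get(...) or arg` patterns in the loader, the constructor arguments are forwarded to the wrapper, whose `__init__` (later in this diff) resolves each setting with `get_from_dict_or_env`: explicit argument first, then environment variable, then default. A small sketch of that precedence; the URL values are illustrative:

```python
import os

from langchain_core.utils import get_from_dict_or_env

os.environ["PEBBLO_CLASSIFIER_URL"] = "http://pebblo.internal:8000"  # example value

# A truthy kwarg wins; a missing/None kwarg falls back to the env var,
# and the default applies only when neither is set.
url = get_from_dict_or_env(
    {"classifier_url": None},  # as when the caller passed classifier_url=None
    "classifier_url",
    "PEBBLO_CLASSIFIER_URL",
    "http://localhost:8000",
)
assert url == "http://pebblo.internal:8000"
```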
@@ -113,7 +107,12 @@ class PebbloSafeLoader(BaseLoader):
             is_last_batch: bool = i == total_batches - 1
             self.docs = batch
             self.docs_with_id = self._index_docs()
-            classified_docs = self._classify_doc(loading_end=is_last_batch)
+            classified_docs = self.pb_client.classify_documents(
+                self.docs_with_id,
+                self.app,
+                self.loader_details,
+                loading_end=is_last_batch,
+            )
             self._add_pebblo_specific_metadata(classified_docs)
             if self.load_semantic:
                 batch_processed_docs = self._add_semantic_to_docs(classified_docs)
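`classify_documents` returns a dict keyed by each document's `pb_id`, which `_add_pebblo_specific_metadata` then joins back onto the loaded documents. An illustrative, not normative, shape with made-up values:

```python
# Hypothetical return value of classify_documents; all values are invented.
classified_docs = {
    "0a1b2c": {
        "pb_id": "0a1b2c",
        "pb_checksum": "d41d8cd9...",
        "entities": {"us-ssn": 1},
        "topics": {"financial-report": 1},
    },
}
```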
@@ -147,7 +146,9 @@ class PebbloSafeLoader(BaseLoader):
                 break
             self.docs = list((doc,))
             self.docs_with_id = self._index_docs()
-            classified_doc = self._classify_doc()
+            classified_doc = self.pb_client.classify_documents(
+                self.docs_with_id, self.app, self.loader_details
+            )
             self._add_pebblo_specific_metadata(classified_doc)
             if self.load_semantic:
                 self.docs = self._add_semantic_to_docs(classified_doc)
@@ -159,263 +160,6 @@ class PebbloSafeLoader(BaseLoader):
     def set_discover_sent(cls) -> None:
         cls._discover_sent = True

-    @classmethod
-    def set_loader_sent(cls) -> None:
-        cls._loader_sent = True
-
-    def _classify_doc(self, loading_end: bool = False) -> dict:
-        """Send documents fetched from loader to pebblo-server. Then send
-        classified documents to Daxa cloud(If api_key is present). Internal method.
-
-        Args:
-
-            loading_end (bool, optional): Flag indicating the halt of data
-                loading by loader. Defaults to False.
-        """
-        headers = {
-            "Accept": "application/json",
-            "Content-Type": "application/json",
-        }
-        if loading_end is True:
-            PebbloSafeLoader.set_loader_sent()
-        doc_content = [doc.dict() for doc in self.docs_with_id]
-        docs = []
-        for doc in doc_content:
-            doc_metadata = doc.get("metadata", {})
-            doc_authorized_identities = doc_metadata.get("authorized_identities", [])
-            doc_source_path = get_full_path(
-                doc_metadata.get(
-                    "full_path", doc_metadata.get("source", self.source_path)
-                )
-            )
-            doc_source_owner = doc_metadata.get(
-                "owner", PebbloSafeLoader.get_file_owner_from_path(doc_source_path)
-            )
-            doc_source_size = doc_metadata.get(
-                "size", self.get_source_size(doc_source_path)
-            )
-            page_content = str(doc.get("page_content"))
-            page_content_size = self.calculate_content_size(page_content)
-            self.source_aggregate_size += page_content_size
-            doc_id = doc.get("pb_id", None) or 0
-            docs.append(
-                {
-                    "doc": page_content,
-                    "source_path": doc_source_path,
-                    "pb_id": doc_id,
-                    "last_modified": doc.get("metadata", {}).get("last_modified"),
-                    "file_owner": doc_source_owner,
-                    **(
-                        {"authorized_identities": doc_authorized_identities}
-                        if doc_authorized_identities
-                        else {}
-                    ),
-                    **(
-                        {"source_path_size": doc_source_size}
-                        if doc_source_size is not None
-                        else {}
-                    ),
-                }
-            )
-        payload: Dict[str, Any] = {
-            "name": self.app_name,
-            "owner": self.owner,
-            "docs": docs,
-            "plugin_version": PLUGIN_VERSION,
-            "load_id": self.load_id,
-            "loader_details": self.loader_details,
-            "loading_end": "false",
-            "source_owner": self.source_owner,
-            "classifier_location": self.classifier_location,
-        }
-        if loading_end is True:
-            payload["loading_end"] = "true"
-            if "loader_details" in payload:
-                payload["loader_details"]["source_aggregate_size"] = (
-                    self.source_aggregate_size
-                )
-        payload = Doc(**payload).dict(exclude_unset=True)
-        classified_docs = {}
-        # Raw payload to be sent to classifier
-        if self.classifier_location == "local":
-            load_doc_url = f"{self.classifier_url}{LOADER_DOC_URL}"
-            try:
-                pebblo_resp = requests.post(
-                    load_doc_url, headers=headers, json=payload, timeout=300
-                )
-
-                # Updating the structure of pebblo response docs for efficient searching
-                for classified_doc in json.loads(pebblo_resp.text).get("docs", []):
-                    classified_docs.update({classified_doc["pb_id"]: classified_doc})
-                if pebblo_resp.status_code not in [
-                    HTTPStatus.OK,
-                    HTTPStatus.BAD_GATEWAY,
-                ]:
-                    logger.warning(
-                        "Received unexpected HTTP response code: %s",
-                        pebblo_resp.status_code,
-                    )
-                logger.debug(
-                    "send_loader_doc[local]: request url %s, body %s len %s\
-                        response status %s body %s",
-                    pebblo_resp.request.url,
-                    str(pebblo_resp.request.body),
-                    str(
-                        len(
-                            pebblo_resp.request.body if pebblo_resp.request.body else []
-                        )
-                    ),
-                    str(pebblo_resp.status_code),
-                    pebblo_resp.json(),
-                )
-            except requests.exceptions.RequestException:
-                logger.warning("Unable to reach pebblo server.")
-            except Exception as e:
-                logger.warning("An Exception caught in _send_loader_doc: local %s", e)
-
-        if self.api_key:
-            if self.classifier_location == "local":
-                docs = payload["docs"]
-                for doc_data in docs:
-                    classified_data = classified_docs.get(doc_data["pb_id"], {})
-                    doc_data.update(
-                        {
-                            "pb_checksum": classified_data.get("pb_checksum", None),
-                            "loader_source_path": classified_data.get(
-                                "loader_source_path", None
-                            ),
-                            "entities": classified_data.get("entities", {}),
-                            "topics": classified_data.get("topics", {}),
-                        }
-                    )
-                    doc_data.pop("doc")
-
-            headers.update({"x-api-key": self.api_key})
-            pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{LOADER_DOC_URL}"
-            try:
-                pebblo_cloud_response = requests.post(
-                    pebblo_cloud_url, headers=headers, json=payload, timeout=20
-                )
-                logger.debug(
-                    "send_loader_doc[cloud]: request url %s, body %s len %s\
-                        response status %s body %s",
-                    pebblo_cloud_response.request.url,
-                    str(pebblo_cloud_response.request.body),
-                    str(
-                        len(
-                            pebblo_cloud_response.request.body
-                            if pebblo_cloud_response.request.body
-                            else []
-                        )
-                    ),
-                    str(pebblo_cloud_response.status_code),
-                    pebblo_cloud_response.json(),
-                )
-            except requests.exceptions.RequestException:
-                logger.warning("Unable to reach Pebblo cloud server.")
-            except Exception as e:
-                logger.warning("An Exception caught in _send_loader_doc: cloud %s", e)
-        elif self.classifier_location == "pebblo-cloud":
-            logger.warning("API key is missing for sending docs to Pebblo cloud.")
-            raise NameError("API key is missing for sending docs to Pebblo cloud.")
-
-        return classified_docs
-
-    @staticmethod
-    def calculate_content_size(page_content: str) -> int:
-        """Calculate the content size in bytes:
-        - Encode the string to bytes using a specific encoding (e.g., UTF-8)
-        - Get the length of the encoded bytes.
-
-        Args:
-            page_content (str): Data string.
-
-        Returns:
-            int: Size of string in bytes.
-        """
-
-        # Encode the content to bytes using UTF-8
-        encoded_content = page_content.encode("utf-8")
-        size = len(encoded_content)
-        return size
-
-    def _send_discover(self) -> None:
-        """Send app discovery payload to pebblo-server. Internal method."""
-        pebblo_resp = None
-        headers = {
-            "Accept": "application/json",
-            "Content-Type": "application/json",
-        }
-        payload = self.app.dict(exclude_unset=True)
-        # Raw discover payload to be sent to classifier
-        if self.classifier_location == "local":
-            app_discover_url = f"{self.classifier_url}{APP_DISCOVER_URL}"
-            try:
-                pebblo_resp = requests.post(
-                    app_discover_url, headers=headers, json=payload, timeout=20
-                )
-                logger.debug(
-                    "send_discover[local]: request url %s, body %s len %s\
-                        response status %s body %s",
-                    pebblo_resp.request.url,
-                    str(pebblo_resp.request.body),
-                    str(
-                        len(
-                            pebblo_resp.request.body if pebblo_resp.request.body else []
-                        )
-                    ),
-                    str(pebblo_resp.status_code),
-                    pebblo_resp.json(),
-                )
-                if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
-                    PebbloSafeLoader.set_discover_sent()
-                else:
-                    logger.warning(
-                        f"Received unexpected HTTP response code:\
-                            {pebblo_resp.status_code}"
-                    )
-            except requests.exceptions.RequestException:
-                logger.warning("Unable to reach pebblo server.")
-            except Exception as e:
-                logger.warning("An Exception caught in _send_discover: local %s", e)
-
-        if self.api_key:
-            try:
-                headers.update({"x-api-key": self.api_key})
-                # If the pebblo_resp is None,
-                # then the pebblo server version is not available
-                if pebblo_resp:
-                    pebblo_server_version = json.loads(pebblo_resp.text).get(
-                        "pebblo_server_version"
-                    )
-                    payload.update({"pebblo_server_version": pebblo_server_version})
-
-                payload.update({"pebblo_client_version": PLUGIN_VERSION})
-                pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{APP_DISCOVER_URL}"
-                pebblo_cloud_response = requests.post(
-                    pebblo_cloud_url, headers=headers, json=payload, timeout=20
-                )
-
-                logger.debug(
-                    "send_discover[cloud]: request url %s, body %s len %s\
-                        response status %s body %s",
-                    pebblo_cloud_response.request.url,
-                    str(pebblo_cloud_response.request.body),
-                    str(
-                        len(
-                            pebblo_cloud_response.request.body
-                            if pebblo_cloud_response.request.body
-                            else []
-                        )
-                    ),
-                    str(pebblo_cloud_response.status_code),
-                    pebblo_cloud_response.json(),
-                )
-            except requests.exceptions.RequestException:
-                logger.warning("Unable to reach Pebblo cloud server.")
-            except Exception as e:
-                logger.warning("An Exception caught in _send_discover: cloud %s", e)
-
     def _get_app_details(self) -> App:
         """Fetch app details. Internal method.

@@ -434,49 +178,6 @@ class PebbloSafeLoader(BaseLoader):
         )
         return app

-    @staticmethod
-    def get_file_owner_from_path(file_path: str) -> str:
-        """Fetch owner of local file path.
-
-        Args:
-            file_path (str): Local file path.
-
-        Returns:
-            str: Name of owner.
-        """
-        try:
-            import pwd
-
-            file_owner_uid = os.stat(file_path).st_uid
-            file_owner_name = pwd.getpwuid(file_owner_uid).pw_name
-        except Exception:
-            file_owner_name = "unknown"
-        return file_owner_name
-
-    def get_source_size(self, source_path: str) -> int:
-        """Fetch size of source path. Source can be a directory or a file.
-
-        Args:
-            source_path (str): Local path of data source.
-
-        Returns:
-            int: Source size in bytes.
-        """
-        if not source_path:
-            return 0
-        size = 0
-        if os.path.isfile(source_path):
-            size = os.path.getsize(source_path)
-        elif os.path.isdir(source_path):
-            total_size = 0
-            for dirpath, _, filenames in os.walk(source_path):
-                for f in filenames:
-                    fp = os.path.join(dirpath, f)
-                    if not os.path.islink(fp):
-                        total_size += os.path.getsize(fp)
-            size = total_size
-        return size
-
     def _index_docs(self) -> List[IndexedDocument]:
         """
         Indexes the documents and returns a list of IndexedDocument objects.
--- a/libs/community/langchain_community/utilities/pebblo.py
+++ b/libs/community/langchain_community/utilities/pebblo.py
@@ -1,25 +1,29 @@
 from __future__ import annotations

+import json
 import logging
 import os
 import pathlib
 import platform
-from typing import List, Optional, Tuple
+from enum import Enum
+from http import HTTPStatus
+from typing import Any, Dict, List, Optional, Tuple

 from langchain_core.documents import Document
 from langchain_core.env import get_runtime_environment
 from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.utils import get_from_dict_or_env
+from requests import Response, request
+from requests.exceptions import RequestException

 from langchain_community.document_loaders.base import BaseLoader

 logger = logging.getLogger(__name__)

 PLUGIN_VERSION = "0.1.1"
-CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000")
-PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")

-LOADER_DOC_URL = "/v1/loader/doc"
-APP_DISCOVER_URL = "/v1/app/discover"
+_DEFAULT_CLASSIFIER_URL = "http://localhost:8000"
+_DEFAULT_PEBBLO_CLOUD_URL = "https://api.daxa.ai"
 BATCH_SIZE_BYTES = 100 * 1024  # 100 KB

 # Supported loaders for Pebblo safe data loading
@@ -59,9 +63,15 @@ LOADER_TYPE_MAPPING = {
     "cloud-folder": cloud_folder,
 }

-SUPPORTED_LOADERS = (*file_loader, *dir_loader, *in_memory)
-
-logger = logging.getLogger(__name__)
+
+class Routes(str, Enum):
+    """Routes available for the Pebblo API as enumerator."""
+
+    loader_doc = "/v1/loader/doc"
+    loader_app_discover = "/v1/app/discover"
+    retrieval_app_discover = "/v1/app/discover"
+    prompt = "/v1/prompt"
+    prompt_governance = "/v1/prompt/governance"


 class IndexedDocument(Document):
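Because `Routes` mixes `str` into `Enum`, members compare equal to their path strings. One portability note: f-string formatting of mixed-in enums changed in Python 3.11 (it now includes the class name), so `.value` is the version-stable way to splice a route into a URL. A trimmed sketch:

```python
from enum import Enum


class Routes(str, Enum):
    """Trimmed copy of the enum above, for illustration."""

    loader_doc = "/v1/loader/doc"


# str-mixin members compare equal to their values:
assert Routes.loader_doc == "/v1/loader/doc"
# .value is stable across Python versions when building URLs:
url = "http://localhost:8000" + Routes.loader_doc.value
```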
@@ -342,3 +352,386 @@ def generate_size_based_batches(
     batches.append(current_batch)

     return batches
+
+
+def get_file_owner_from_path(file_path: str) -> str:
+    """Fetch owner of local file path.
+
+    Args:
+        file_path (str): Local file path.
+
+    Returns:
+        str: Name of owner.
+    """
+    try:
+        import pwd
+
+        file_owner_uid = os.stat(file_path).st_uid
+        file_owner_name = pwd.getpwuid(file_owner_uid).pw_name
+    except Exception:
+        file_owner_name = "unknown"
+    return file_owner_name
+
+
+def get_source_size(source_path: str) -> int:
+    """Fetch size of source path. Source can be a directory or a file.
+
+    Args:
+        source_path (str): Local path of data source.
+
+    Returns:
+        int: Source size in bytes.
+    """
+    if not source_path:
+        return 0
+    size = 0
+    if os.path.isfile(source_path):
+        size = os.path.getsize(source_path)
+    elif os.path.isdir(source_path):
+        total_size = 0
+        for dirpath, _, filenames in os.walk(source_path):
+            for f in filenames:
+                fp = os.path.join(dirpath, f)
+                if not os.path.islink(fp):
+                    total_size += os.path.getsize(fp)
+        size = total_size
+    return size
+
+
+def calculate_content_size(data: str) -> int:
+    """Calculate the content size in bytes:
+    - Encode the string to bytes using a specific encoding (e.g., UTF-8)
+    - Get the length of the encoded bytes.
+
+    Args:
+        data (str): Data string.
+
+    Returns:
+        int: Size of string in bytes.
+    """
+    encoded_content = data.encode("utf-8")
+    size = len(encoded_content)
+    return size
+
+
+class PebbloLoaderAPIWrapper(BaseModel):
+    """Wrapper for Pebblo Loader API."""
+
+    api_key: Optional[str]  # Use SecretStr
+    """API key for Pebblo Cloud"""
+    classifier_location: str = "local"
+    """Location of the classifier, local or cloud. Defaults to 'local'"""
+    classifier_url: Optional[str]
+    """URL of the Pebblo Classifier"""
+    cloud_url: Optional[str]
+    """URL of the Pebblo Cloud"""
+
+    def __init__(self, **kwargs: Any):
+        """Validate that api key in environment."""
+        kwargs["api_key"] = get_from_dict_or_env(
+            kwargs, "api_key", "PEBBLO_API_KEY", ""
+        )
+        kwargs["classifier_url"] = get_from_dict_or_env(
+            kwargs, "classifier_url", "PEBBLO_CLASSIFIER_URL", _DEFAULT_CLASSIFIER_URL
+        )
+        kwargs["cloud_url"] = get_from_dict_or_env(
+            kwargs, "cloud_url", "PEBBLO_CLOUD_URL", _DEFAULT_PEBBLO_CLOUD_URL
+        )
+        super().__init__(**kwargs)
+
+    def send_loader_discover(self, app: App) -> None:
+        """
+        Send app discovery request to Pebblo server & cloud.
+
+        Args:
+            app (App): App instance to be discovered.
+        """
+        pebblo_resp = None
+        payload = app.dict(exclude_unset=True)
+
+        if self.classifier_location == "local":
+            # Send app details to local classifier
+            headers = self._make_headers()
+            app_discover_url = f"{self.classifier_url}{Routes.loader_app_discover}"
+            pebblo_resp = self.make_request("POST", app_discover_url, headers, payload)
+
+        if self.api_key:
+            # Send app details to Pebblo cloud if api_key is present
+            headers = self._make_headers(cloud_request=True)
+            if pebblo_resp:
+                pebblo_server_version = json.loads(pebblo_resp.text).get(
+                    "pebblo_server_version"
+                )
+                payload.update({"pebblo_server_version": pebblo_server_version})
+
+            payload.update({"pebblo_client_version": PLUGIN_VERSION})
+            pebblo_cloud_url = f"{self.cloud_url}{Routes.loader_app_discover}"
+            _ = self.make_request("POST", pebblo_cloud_url, headers, payload)
+
+    def classify_documents(
+        self,
+        docs_with_id: List[IndexedDocument],
+        app: App,
+        loader_details: dict,
+        loading_end: bool = False,
+    ) -> dict:
+        """
+        Send documents to Pebblo server for classification.
+        Then send classified documents to Daxa cloud(If api_key is present).
+
+        Args:
+            docs_with_id (List[IndexedDocument]): List of documents to be classified.
+            app (App): App instance.
+            loader_details (dict): Loader details.
+            loading_end (bool): Boolean, indicating the halt of data loading by loader.
+        """
+        source_path = loader_details.get("source_path", "")
+        source_owner = get_file_owner_from_path(source_path)
+        # Prepare docs for classification
+        docs, source_aggregate_size = self.prepare_docs_for_classification(
+            docs_with_id, source_path
+        )
+        # Build payload for classification
+        payload = self.build_classification_payload(
+            app, docs, loader_details, source_owner, source_aggregate_size, loading_end
+        )
+
+        classified_docs = {}
+        if self.classifier_location == "local":
+            # Send docs to local classifier
+            headers = self._make_headers()
+            load_doc_url = f"{self.classifier_url}{Routes.loader_doc}"
+            try:
+                pebblo_resp = self.make_request(
+                    "POST", load_doc_url, headers, payload, 300
+                )
+
+                if pebblo_resp:
+                    # Updating structure of pebblo response docs for efficient searching
+                    for classified_doc in json.loads(pebblo_resp.text).get("docs", []):
+                        classified_docs.update(
+                            {classified_doc["pb_id"]: classified_doc}
+                        )
+            except Exception as e:
+                logger.warning("An Exception caught in classify_documents: local %s", e)
+
+        if self.api_key:
+            # Send docs to Pebblo cloud if api_key is present
+            if self.classifier_location == "local":
+                # If local classifier is used add the classified information
+                # and remove doc content
+                self.update_doc_data(payload["docs"], classified_docs)
+            self.send_docs_to_pebblo_cloud(payload)
+        elif self.classifier_location == "pebblo-cloud":
+            logger.warning("API key is missing for sending docs to Pebblo cloud.")
+            raise NameError("API key is missing for sending docs to Pebblo cloud.")
+
+        return classified_docs
+
+    def send_docs_to_pebblo_cloud(self, payload: dict) -> None:
+        """
+        Send documents to Pebblo cloud.
+
+        Args:
+            payload (dict): The payload containing documents to be sent.
+        """
+        headers = self._make_headers(cloud_request=True)
+        pebblo_cloud_url = f"{self.cloud_url}{Routes.loader_doc}"
+        try:
+            _ = self.make_request("POST", pebblo_cloud_url, headers, payload)
+        except Exception as e:
+            logger.warning("An Exception caught in classify_documents: cloud %s", e)
+
+    def _make_headers(self, cloud_request: bool = False) -> dict:
+        """
+        Generate headers for the request.
+
+        args:
+            cloud_request (bool): flag indicating whether the request is for Pebblo
+                cloud.
+        returns:
+            dict: Headers for the request.
+
+        """
+        headers = {
+            "Accept": "application/json",
+            "Content-Type": "application/json",
+        }
+        if cloud_request:
+            # Add API key for Pebblo cloud request
+            if self.api_key:
+                headers.update({"x-api-key": self.api_key})
+            else:
+                logger.warning("API key is missing for Pebblo cloud request.")
+        return headers
+
+    def build_classification_payload(
+        self,
+        app: App,
+        docs: List[dict],
+        loader_details: dict,
+        source_owner: str,
+        source_aggregate_size: int,
+        loading_end: bool,
+    ) -> dict:
+        """
+        Build the payload for document classification.
+
+        Args:
+            app (App): App instance.
+            docs (List[dict]): List of documents to be classified.
+            loader_details (dict): Loader details.
+            source_owner (str): Owner of the source.
+            source_aggregate_size (int): Aggregate size of the source.
+            loading_end (bool): Boolean indicating the halt of data loading by loader.
+
+        Returns:
+            dict: Payload for document classification.
+        """
+        payload: Dict[str, Any] = {
+            "name": app.name,
+            "owner": app.owner,
+            "docs": docs,
+            "plugin_version": PLUGIN_VERSION,
+            "load_id": app.load_id,
+            "loader_details": loader_details,
+            "loading_end": "false",
+            "source_owner": source_owner,
+            "classifier_location": self.classifier_location,
+        }
+        if loading_end is True:
+            payload["loading_end"] = "true"
+            if "loader_details" in payload:
+                payload["loader_details"]["source_aggregate_size"] = (
+                    source_aggregate_size
+                )
+        payload = Doc(**payload).dict(exclude_unset=True)
+        return payload
+
+    @staticmethod
+    def make_request(
+        method: str,
+        url: str,
+        headers: dict,
+        payload: Optional[dict] = None,
+        timeout: int = 20,
+    ) -> Optional[Response]:
+        """
+        Make a request to the Pebblo API
+
+        Args:
+            method (str): HTTP method (GET, POST, PUT, DELETE, etc.).
+            url (str): URL for the request.
+            headers (dict): Headers for the request.
+            payload (Optional[dict]): Payload for the request (for POST, PUT, etc.).
+            timeout (int): Timeout for the request in seconds.
+
+        Returns:
+            Optional[Response]: Response object if the request is successful.
+        """
+        try:
+            response = request(
+                method=method, url=url, headers=headers, json=payload, timeout=timeout
+            )
+            logger.debug(
+                "Request: method %s, url %s, len %s response status %s",
+                method,
+                response.request.url,
+                str(len(response.request.body if response.request.body else [])),
+                str(response.status_code),
+            )
+
+            if response.status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
+                logger.warning(f"Pebblo Server: Error {response.status_code}")
+            elif response.status_code >= HTTPStatus.BAD_REQUEST:
+                logger.warning(f"Pebblo received an invalid payload: {response.text}")
+            elif response.status_code != HTTPStatus.OK:
+                logger.warning(
+                    f"Pebblo returned an unexpected response code: "
+                    f"{response.status_code}"
+                )
+
+            return response
+        except RequestException:
+            logger.warning("Unable to reach server %s", url)
+        except Exception as e:
+            logger.warning("An Exception caught in make_request: %s", e)
+        return None
+
+    @staticmethod
+    def prepare_docs_for_classification(
+        docs_with_id: List[IndexedDocument], source_path: str
+    ) -> Tuple[List[dict], int]:
+        """
+        Prepare documents for classification.
+
+        Args:
+            docs_with_id (List[IndexedDocument]): List of documents to be classified.
+            source_path (str): Source path of the documents.
+
+        Returns:
+            Tuple[List[dict], int]: Documents and the aggregate size of the source.
+        """
+        docs = []
+        source_aggregate_size = 0
+        doc_content = [doc.dict() for doc in docs_with_id]
+        for doc in doc_content:
+            doc_metadata = doc.get("metadata", {})
+            doc_authorized_identities = doc_metadata.get("authorized_identities", [])
+            doc_source_path = get_full_path(
+                doc_metadata.get(
+                    "full_path",
+                    doc_metadata.get("source", source_path),
+                )
+            )
+            doc_source_owner = doc_metadata.get(
+                "owner", get_file_owner_from_path(doc_source_path)
+            )
+            doc_source_size = doc_metadata.get("size", get_source_size(doc_source_path))
+            page_content = str(doc.get("page_content"))
+            page_content_size = calculate_content_size(page_content)
+            source_aggregate_size += page_content_size
+            doc_id = doc.get("pb_id", None) or 0
+            docs.append(
+                {
+                    "doc": page_content,
+                    "source_path": doc_source_path,
+                    "pb_id": doc_id,
+                    "last_modified": doc.get("metadata", {}).get("last_modified"),
+                    "file_owner": doc_source_owner,
+                    **(
+                        {"authorized_identities": doc_authorized_identities}
+                        if doc_authorized_identities
+                        else {}
+                    ),
+                    **(
+                        {"source_path_size": doc_source_size}
+                        if doc_source_size is not None
+                        else {}
+                    ),
+                }
+            )
+        return docs, source_aggregate_size
+
+    @staticmethod
+    def update_doc_data(docs: List[dict], classified_docs: dict) -> None:
+        """
+        Update the document data with classified information.
+
+        Args:
+            docs (List[dict]): List of document data to be updated.
+            classified_docs (dict): The dictionary containing classified documents.
+        """
+        for doc_data in docs:
+            classified_data = classified_docs.get(doc_data["pb_id"], {})
+            # Update the document data with classified information
+            doc_data.update(
+                {
+                    "pb_checksum": classified_data.get("pb_checksum"),
+                    "loader_source_path": classified_data.get("loader_source_path"),
+                    "entities": classified_data.get("entities", {}),
+                    "topics": classified_data.get("topics", {}),
+                }
+            )
+            # Remove the document content
+            doc_data.pop("doc")
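The wrapper can also be instantiated on its own. Per the `__init__` above, each setting resolves from the keyword argument, then the environment, then a default; nothing is sent to a Pebblo server until one of the request methods runs. A minimal sketch:

```python
from langchain_community.utilities.pebblo import PebbloLoaderAPIWrapper

pb_client = PebbloLoaderAPIWrapper(
    classifier_location="local",
    # api_key / classifier_url / cloud_url omitted: they fall back to
    # PEBBLO_API_KEY, PEBBLO_CLASSIFIER_URL, PEBBLO_CLOUD_URL, then defaults.
)
print(pb_client.classifier_url)  # http://localhost:8000 unless overridden
```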
--- a/libs/community/tests/unit_tests/document_loaders/test_pebblo.py
+++ b/libs/community/tests/unit_tests/document_loaders/test_pebblo.py
@@ -144,4 +144,5 @@ def test_pebblo_safe_loader_api_key() -> None:
     )

     # Assert
-    assert loader.api_key == api_key
+    assert loader.pb_client.api_key == api_key
+    assert loader.pb_client.classifier_location == "local"