mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-07 14:03:26 +00:00
community[minor]: Added classifier_location parameter in PebbloSafeLoader. (#22565)
Description: Add a classifier_location feature flag. This keyword-only flag selects where Pebblo classifies documents: "local" (the default) or "pebblo-cloud". Unit Tests: N/A Documentation: N/A --------- Signed-off-by: Rahul Tripathi <rauhl.psit.ec@gmail.com> Co-authored-by: Rahul Tripathi <rauhl.psit.ec@gmail.com>
This commit is contained in:
@@ -46,6 +46,8 @@ class PebbloSafeLoader(BaseLoader):
|
||||
api_key: Optional[str] = None,
|
||||
load_semantic: bool = False,
|
||||
classifier_url: Optional[str] = None,
|
||||
*,
|
||||
classifier_location: str = "local",
|
||||
):
|
||||
if not name or not isinstance(name, str):
|
||||
raise NameError("Must specify a valid name.")
|
||||
@@ -65,6 +67,7 @@ class PebbloSafeLoader(BaseLoader):
|
||||
self.source_path_size = self.get_source_size(self.source_path)
|
||||
self.source_aggregate_size = 0
|
||||
self.classifier_url = classifier_url or CLASSIFIER_URL
|
||||
self.classifier_location = classifier_location
|
||||
self.loader_details = {
|
||||
"loader": loader_name,
|
||||
"source_path": self.source_path,
|
||||
@@ -158,6 +161,7 @@ class PebbloSafeLoader(BaseLoader):
|
||||
PebbloSafeLoader.set_loader_sent()
|
||||
doc_content = [doc.dict() for doc in loaded_docs]
|
||||
docs = []
|
||||
classified_docs = []
|
||||
for doc in doc_content:
|
||||
doc_metadata = doc.get("metadata", {})
|
||||
doc_authorized_identities = doc_metadata.get("authorized_identities", [])
|
||||
@@ -204,6 +208,7 @@ class PebbloSafeLoader(BaseLoader):
|
||||
"loader_details": self.loader_details,
|
||||
"loading_end": "false",
|
||||
"source_owner": self.source_owner,
|
||||
"classifier_location": self.classifier_location,
|
||||
}
|
||||
if loading_end is True:
|
||||
payload["loading_end"] = "true"
|
||||
@@ -212,39 +217,46 @@ class PebbloSafeLoader(BaseLoader):
|
||||
"source_aggregate_size"
|
||||
] = self.source_aggregate_size
|
||||
payload = Doc(**payload).dict(exclude_unset=True)
|
||||
load_doc_url = f"{self.classifier_url}{LOADER_DOC_URL}"
|
||||
classified_docs = []
|
||||
try:
|
||||
pebblo_resp = requests.post(
|
||||
load_doc_url, headers=headers, json=payload, timeout=300
|
||||
)
|
||||
classified_docs = json.loads(pebblo_resp.text).get("docs", None)
|
||||
if pebblo_resp.status_code not in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
|
||||
logger.warning(
|
||||
"Received unexpected HTTP response code: %s",
|
||||
pebblo_resp.status_code,
|
||||
)
|
||||
logger.debug(
|
||||
"send_loader_doc[local]: request url %s, body %s len %s\
|
||||
response status %s body %s",
|
||||
pebblo_resp.request.url,
|
||||
str(pebblo_resp.request.body),
|
||||
str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])),
|
||||
str(pebblo_resp.status_code),
|
||||
pebblo_resp.json(),
|
||||
)
|
||||
except requests.exceptions.RequestException:
|
||||
logger.warning("Unable to reach pebblo server.")
|
||||
except Exception as e:
|
||||
logger.warning("An Exception caught in _send_loader_doc: local %s", e)
|
||||
if self.api_key:
|
||||
if not classified_docs:
|
||||
return classified_docs
|
||||
# Raw payload to be sent to classifier
|
||||
if self.classifier_location == "local":
|
||||
load_doc_url = f"{self.classifier_url}{LOADER_DOC_URL}"
|
||||
try:
|
||||
pebblo_resp = requests.post(
|
||||
load_doc_url, headers=headers, json=payload, timeout=300
|
||||
)
|
||||
classified_docs = json.loads(pebblo_resp.text).get("docs", None)
|
||||
if pebblo_resp.status_code not in [
|
||||
HTTPStatus.OK,
|
||||
HTTPStatus.BAD_GATEWAY,
|
||||
]:
|
||||
logger.warning(
|
||||
"Received unexpected HTTP response code: %s",
|
||||
pebblo_resp.status_code,
|
||||
)
|
||||
logger.debug(
|
||||
"send_loader_doc[local]: request url %s, body %s len %s\
|
||||
response status %s body %s",
|
||||
pebblo_resp.request.url,
|
||||
str(pebblo_resp.request.body),
|
||||
str(
|
||||
len(
|
||||
pebblo_resp.request.body if pebblo_resp.request.body else []
|
||||
)
|
||||
),
|
||||
str(pebblo_resp.status_code),
|
||||
pebblo_resp.json(),
|
||||
)
|
||||
except requests.exceptions.RequestException:
|
||||
logger.warning("Unable to reach pebblo server.")
|
||||
except Exception as e:
|
||||
logger.warning("An Exception caught in _send_loader_doc: local %s", e)
|
||||
|
||||
if self.api_key:
|
||||
if self.classifier_location == "local":
|
||||
payload["docs"] = classified_docs
|
||||
payload["classified"] = True
|
||||
headers.update({"x-api-key": self.api_key})
|
||||
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{LOADER_DOC_URL}"
|
||||
headers.update({"x-api-key": self.api_key})
|
||||
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{LOADER_DOC_URL}"
|
||||
try:
|
||||
pebblo_cloud_response = requests.post(
|
||||
pebblo_cloud_url, headers=headers, json=payload, timeout=20
|
||||
)
|
||||
@@ -267,9 +279,10 @@ class PebbloSafeLoader(BaseLoader):
|
||||
logger.warning("Unable to reach Pebblo cloud server.")
|
||||
except Exception as e:
|
||||
logger.warning("An Exception caught in _send_loader_doc: cloud %s", e)
|
||||
elif self.classifier_location == "pebblo-cloud":
|
||||
logger.warning("API key is missing for sending docs to Pebblo cloud.")
|
||||
raise NameError("API key is missing for sending docs to Pebblo cloud.")
|
||||
|
||||
if loading_end is True:
|
||||
PebbloSafeLoader.set_loader_sent()
|
||||
return classified_docs
|
||||
|
||||
@staticmethod
|
||||
@@ -298,45 +311,50 @@ class PebbloSafeLoader(BaseLoader):
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = self.app.dict(exclude_unset=True)
|
||||
app_discover_url = f"{self.classifier_url}{APP_DISCOVER_URL}"
|
||||
try:
|
||||
pebblo_resp = requests.post(
|
||||
app_discover_url, headers=headers, json=payload, timeout=20
|
||||
)
|
||||
logger.debug(
|
||||
"send_discover[local]: request url %s, body %s len %s\
|
||||
response status %s body %s",
|
||||
pebblo_resp.request.url,
|
||||
str(pebblo_resp.request.body),
|
||||
str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])),
|
||||
str(pebblo_resp.status_code),
|
||||
pebblo_resp.json(),
|
||||
)
|
||||
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
|
||||
PebbloSafeLoader.set_discover_sent()
|
||||
else:
|
||||
logger.warning(
|
||||
f"Received unexpected HTTP response code: {pebblo_resp.status_code}"
|
||||
# Raw discover payload to be sent to classifier
|
||||
if self.classifier_location == "local":
|
||||
app_discover_url = f"{self.classifier_url}{APP_DISCOVER_URL}"
|
||||
try:
|
||||
pebblo_resp = requests.post(
|
||||
app_discover_url, headers=headers, json=payload, timeout=20
|
||||
)
|
||||
except requests.exceptions.RequestException:
|
||||
logger.warning("Unable to reach pebblo server.")
|
||||
except Exception as e:
|
||||
logger.warning("An Exception caught in _send_discover: local %s", e)
|
||||
logger.debug(
|
||||
"send_discover[local]: request url %s, body %s len %s\
|
||||
response status %s body %s",
|
||||
pebblo_resp.request.url,
|
||||
str(pebblo_resp.request.body),
|
||||
str(
|
||||
len(
|
||||
pebblo_resp.request.body if pebblo_resp.request.body else []
|
||||
)
|
||||
),
|
||||
str(pebblo_resp.status_code),
|
||||
pebblo_resp.json(),
|
||||
)
|
||||
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
|
||||
PebbloSafeLoader.set_discover_sent()
|
||||
else:
|
||||
logger.warning(
|
||||
f"Received unexpected HTTP response code:\
|
||||
{pebblo_resp.status_code}"
|
||||
)
|
||||
except requests.exceptions.RequestException:
|
||||
logger.warning("Unable to reach pebblo server.")
|
||||
except Exception as e:
|
||||
logger.warning("An Exception caught in _send_discover: local %s", e)
|
||||
|
||||
if self.api_key:
|
||||
try:
|
||||
headers.update({"x-api-key": self.api_key})
|
||||
# If the pebblo_resp is None,
|
||||
# then the pebblo server version is not available
|
||||
if pebblo_resp:
|
||||
pebblo_server_version = json.loads(pebblo_resp.text).get(
|
||||
"pebblo_server_version"
|
||||
)
|
||||
payload.update(
|
||||
{
|
||||
"pebblo_server_version": pebblo_server_version,
|
||||
"pebblo_client_version": payload["plugin_version"],
|
||||
}
|
||||
)
|
||||
payload.pop("plugin_version")
|
||||
payload.update({"pebblo_server_version": pebblo_server_version})
|
||||
|
||||
payload.update({"pebblo_client_version": PLUGIN_VERSION})
|
||||
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{APP_DISCOVER_URL}"
|
||||
pebblo_cloud_response = requests.post(
|
||||
pebblo_cloud_url, headers=headers, json=payload, timeout=20
|
||||
|
Reference in New Issue
Block a user