From 4ff2f4499e700748e560641ad14c828535f89b8e Mon Sep 17 00:00:00 2001 From: Rajendra Kadam Date: Thu, 22 Aug 2024 21:21:21 +0530 Subject: [PATCH] community: Refactor PebbloRetrievalQA (#25583) **Refactor PebbloRetrievalQA** - Created `APIWrapper` and moved API logic into it. - Created smaller functions/methods for better readability. - Properly read environment variables. - Removed unused code. - Updated models **Issue:** NA **Dependencies:** NA **tests**: NA --- .../chains/pebblo_retrieval/base.py | 358 +++--------------- .../chains/pebblo_retrieval/models.py | 10 +- .../chains/pebblo_retrieval/utilities.py | 340 ++++++++++++++++- 3 files changed, 391 insertions(+), 317 deletions(-) diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/base.py b/libs/community/langchain_community/chains/pebblo_retrieval/base.py index 61eb638e442..93314301b4d 100644 --- a/libs/community/langchain_community/chains/pebblo_retrieval/base.py +++ b/libs/community/langchain_community/chains/pebblo_retrieval/base.py @@ -5,12 +5,9 @@ against a vector database. import datetime import inspect -import json import logging -from http import HTTPStatus -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional -import requests # type: ignore from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain_core.callbacks import ( @@ -29,16 +26,14 @@ from langchain_community.chains.pebblo_retrieval.enforcement_filters import ( from langchain_community.chains.pebblo_retrieval.models import ( App, AuthContext, - Qa, + ChainInfo, + Model, SemanticContext, + VectorDB, ) from langchain_community.chains.pebblo_retrieval.utilities import ( - APP_DISCOVER_URL, - CLASSIFIER_URL, - PEBBLO_CLOUD_URL, PLUGIN_VERSION, - PROMPT_GOV_URL, - PROMPT_URL, + PebbloRetrievalAPIWrapper, get_runtime, ) @@ -72,16 +67,18 @@ class PebbloRetrievalQA(Chain): """Description of app.""" api_key: Optional[str] = None #: :meta private: """Pebblo cloud API key for app.""" - classifier_url: str = CLASSIFIER_URL #: :meta private: + classifier_url: Optional[str] = None #: :meta private: """Classifier endpoint.""" classifier_location: str = "local" #: :meta private: """Classifier location. It could be either of 'local' or 'pebblo-cloud'.""" _discover_sent: bool = False #: :meta private: """Flag to check if discover payload has been sent.""" - _prompt_sent: bool = False #: :meta private: - """Flag to check if prompt payload has been sent.""" enable_prompt_gov: bool = True #: :meta private: """Flag to check if prompt governance is enabled or not""" + pb_client: PebbloRetrievalAPIWrapper = Field( + default_factory=PebbloRetrievalAPIWrapper + ) + """Pebblo Retrieval API client""" def _call( self, @@ -100,12 +97,11 @@ class PebbloRetrievalQA(Chain): answer, docs = res['result'], res['source_documents'] """ prompt_time = datetime.datetime.now().isoformat() - PebbloRetrievalQA.set_prompt_sent(value=False) _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() question = inputs[self.input_key] - auth_context = inputs.get(self.auth_context_key, {}) - semantic_context = inputs.get(self.semantic_context_key, {}) - _, prompt_entities = self._check_prompt_validity(question) + auth_context = inputs.get(self.auth_context_key) + semantic_context = inputs.get(self.semantic_context_key) + _, prompt_entities = self.pb_client.check_prompt_validity(question) accepts_run_manager = ( "run_manager" in inspect.signature(self._get_docs).parameters @@ -120,43 +116,17 @@ class PebbloRetrievalQA(Chain): input_documents=docs, question=question, callbacks=_run_manager.get_child() ) - qa = { - "name": self.app_name, - "context": [ - { - "retrieved_from": doc.metadata.get( - "full_path", doc.metadata.get("source") - ), - "doc": doc.page_content, - "vector_db": self.retriever.vectorstore.__class__.__name__, - **( - {"pb_checksum": doc.metadata.get("pb_checksum")} - if doc.metadata.get("pb_checksum") - else {} - ), - } - for doc in docs - if isinstance(doc, Document) - ], - "prompt": { - "data": question, - "entities": prompt_entities.get("entities", {}), - "entityCount": prompt_entities.get("entityCount", 0), - "prompt_gov_enabled": self.enable_prompt_gov, - }, - "response": { - "data": answer, - }, - "prompt_time": prompt_time, - "user": auth_context.user_id if auth_context else "unknown", - "user_identities": auth_context.user_auth - if auth_context and hasattr(auth_context, "user_auth") - else [], - "classifier_location": self.classifier_location, - } - - qa_payload = Qa(**qa) - self._send_prompt(qa_payload) + self.pb_client.send_prompt( + self.app_name, + self.retriever, + question, + answer, + auth_context, + docs, + prompt_entities, + prompt_time, + self.enable_prompt_gov, + ) if self.return_source_documents: return {self.output_key: answer, "source_documents": docs} @@ -187,7 +157,7 @@ class PebbloRetrievalQA(Chain): "run_manager" in inspect.signature(self._aget_docs).parameters ) - _, prompt_entities = self._check_prompt_validity(question) + _, prompt_entities = self.pb_client.check_prompt_validity(question) if accepts_run_manager: docs = await self._aget_docs( @@ -243,7 +213,7 @@ class PebbloRetrievalQA(Chain): chain_type: str = "stuff", chain_type_kwargs: Optional[dict] = None, api_key: Optional[str] = None, - classifier_url: str = CLASSIFIER_URL, + classifier_url: Optional[str] = None, classifier_location: str = "local", **kwargs: Any, ) -> "PebbloRetrievalQA": @@ -263,14 +233,14 @@ class PebbloRetrievalQA(Chain): llm=llm, **kwargs, ) - - PebbloRetrievalQA._send_discover( - app, + # initialize Pebblo API client + pb_client = PebbloRetrievalAPIWrapper( api_key=api_key, - classifier_url=classifier_url, classifier_location=classifier_location, + classifier_url=classifier_url, ) - + # send app discovery request + pb_client.send_app_discover(app) return cls( combine_documents_chain=combine_documents_chain, app_name=app_name, @@ -279,6 +249,7 @@ class PebbloRetrievalQA(Chain): api_key=api_key, classifier_url=classifier_url, classifier_location=classifier_location, + pb_client=pb_client, **kwargs, ) @@ -346,259 +317,36 @@ class PebbloRetrievalQA(Chain): ) return app - @staticmethod - def _send_discover( - app: App, - api_key: Optional[str], - classifier_url: str, - classifier_location: str, - ) -> None: # type: ignore - """Send app discovery payload to pebblo-server. Internal method.""" - headers = { - "Accept": "application/json", - "Content-Type": "application/json", - } - payload = app.dict(exclude_unset=True) - if classifier_location == "local": - app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}" - try: - pebblo_resp = requests.post( - app_discover_url, headers=headers, json=payload, timeout=20 - ) - logger.debug("discover-payload: %s", payload) - logger.debug( - "send_discover[local]: request url %s, body %s len %s\ - response status %s body %s", - pebblo_resp.request.url, - str(pebblo_resp.request.body), - str( - len( - pebblo_resp.request.body if pebblo_resp.request.body else [] - ) - ), - str(pebblo_resp.status_code), - pebblo_resp.json(), - ) - if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]: - PebbloRetrievalQA.set_discover_sent() - else: - logger.warning( - "Received unexpected HTTP response code:" - + f"{pebblo_resp.status_code}" - ) - except requests.exceptions.RequestException: - logger.warning("Unable to reach pebblo server.") - except Exception as e: - logger.warning("An Exception caught in _send_discover: local %s", e) - - if api_key: - try: - headers.update({"x-api-key": api_key}) - pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{APP_DISCOVER_URL}" - pebblo_cloud_response = requests.post( - pebblo_cloud_url, headers=headers, json=payload, timeout=20 - ) - - logger.debug( - "send_discover[cloud]: request url %s, body %s len %s\ - response status %s body %s", - pebblo_cloud_response.request.url, - str(pebblo_cloud_response.request.body), - str( - len( - pebblo_cloud_response.request.body - if pebblo_cloud_response.request.body - else [] - ) - ), - str(pebblo_cloud_response.status_code), - pebblo_cloud_response.json(), - ) - except requests.exceptions.RequestException: - logger.warning("Unable to reach Pebblo cloud server.") - except Exception as e: - logger.warning("An Exception caught in _send_discover: cloud %s", e) - @classmethod def set_discover_sent(cls) -> None: cls._discover_sent = True @classmethod - def set_prompt_sent(cls, value: bool = True) -> None: - cls._prompt_sent = value - - def _send_prompt(self, qa_payload: Qa) -> None: - headers = { - "Accept": "application/json", - "Content-Type": "application/json", - } - app_discover_url = f"{self.classifier_url}{PROMPT_URL}" - pebblo_resp = None - payload = qa_payload.dict(exclude_unset=True) - if self.classifier_location == "local": - try: - pebblo_resp = requests.post( - app_discover_url, - headers=headers, - json=payload, - timeout=20, - ) - logger.debug("prompt-payload: %s", payload) - logger.debug( - "send_prompt[local]: request url %s, body %s len %s\ - response status %s body %s", - pebblo_resp.request.url, - str(pebblo_resp.request.body), - str( - len( - pebblo_resp.request.body if pebblo_resp.request.body else [] - ) - ), - str(pebblo_resp.status_code), - pebblo_resp.json(), - ) - if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]: - PebbloRetrievalQA.set_prompt_sent() - else: - logger.warning( - "Received unexpected HTTP response code:" - + f"{pebblo_resp.status_code}" - ) - except requests.exceptions.RequestException: - logger.warning("Unable to reach pebblo server.") - except Exception as e: - logger.warning("An Exception caught in _send_discover: local %s", e) - - # If classifier location is local, then response, context and prompt - # should be fetched from pebblo_resp and replaced in payload. - if self.api_key: - if self.classifier_location == "local": - if pebblo_resp: - resp = json.loads(pebblo_resp.text) - if resp: - payload["response"].update( - resp.get("retrieval_data", {}).get("response", {}) - ) - payload["response"].pop("data") - payload["prompt"].update( - resp.get("retrieval_data", {}).get("prompt", {}) - ) - payload["prompt"].pop("data") - context = payload["context"] - for context_data in context: - context_data.pop("doc") - payload["context"] = context - else: - payload["response"] = {} - payload["prompt"] = {} - payload["context"] = [] - headers.update({"x-api-key": self.api_key}) - pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}" - try: - pebblo_cloud_response = requests.post( - pebblo_cloud_url, - headers=headers, - json=payload, - timeout=20, - ) - - logger.debug( - "send_prompt[cloud]: request url %s, body %s len %s\ - response status %s body %s", - pebblo_cloud_response.request.url, - str(pebblo_cloud_response.request.body), - str( - len( - pebblo_cloud_response.request.body - if pebblo_cloud_response.request.body - else [] - ) - ), - str(pebblo_cloud_response.status_code), - pebblo_cloud_response.json(), - ) - except requests.exceptions.RequestException: - logger.warning("Unable to reach Pebblo cloud server.") - except Exception as e: - logger.warning("An Exception caught in _send_prompt: cloud %s", e) - elif self.classifier_location == "pebblo-cloud": - logger.warning("API key is missing for sending prompt to Pebblo cloud.") - raise NameError("API key is missing for sending prompt to Pebblo cloud.") - - def _check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]: + def get_chain_details( + cls, llm: BaseLanguageModel, **kwargs: Any + ) -> List[ChainInfo]: """ - Check the validity of the given prompt using a remote classification service. - - This method sends a prompt to a remote classifier service and return entities - present in prompt or not. + Get chain details. Args: - question (str): The prompt question to be validated. + llm (BaseLanguageModel): Language model instance. + **kwargs: Additional keyword arguments. Returns: - bool: True if the prompt is valid (does not contain deny list entities), - False otherwise. - dict: The entities present in the prompt + List[ChainInfo]: Chain details. """ - - headers = { - "Accept": "application/json", - "Content-Type": "application/json", - } - prompt_payload = {"prompt": question} - is_valid_prompt: bool = True - prompt_gov_api_url = f"{self.classifier_url}{PROMPT_GOV_URL}" - pebblo_resp = None - prompt_entities: dict = {"entities": {}, "entityCount": 0} - if self.classifier_location == "local": - try: - pebblo_resp = requests.post( - prompt_gov_api_url, - headers=headers, - json=prompt_payload, - timeout=20, - ) - - logger.debug("prompt-payload: %s", prompt_payload) - logger.debug( - "send_prompt[local]: request url %s, body %s len %s\ - response status %s body %s", - pebblo_resp.request.url, - str(pebblo_resp.request.body), - str( - len( - pebblo_resp.request.body if pebblo_resp.request.body else [] - ) - ), - str(pebblo_resp.status_code), - pebblo_resp.json(), - ) - logger.debug(f"pebblo_resp.json() {pebblo_resp.json()}") - prompt_entities["entities"] = pebblo_resp.json().get("entities", {}) - prompt_entities["entityCount"] = pebblo_resp.json().get( - "entityCount", 0 - ) - - except requests.exceptions.RequestException: - logger.warning("Unable to reach pebblo server.") - except Exception as e: - logger.warning("An Exception caught in _send_discover: local %s", e) - return is_valid_prompt, prompt_entities - - @classmethod - def get_chain_details(cls, llm: BaseLanguageModel, **kwargs): # type: ignore llm_dict = llm.__dict__ - chain = [ - { - "name": cls.__name__, - "model": { - "name": llm_dict.get("model_name", llm_dict.get("model")), - "vendor": llm.__class__.__name__, - }, - "vector_dbs": [ - { - "name": kwargs["retriever"].vectorstore.__class__.__name__, - "embedding_model": str( + chains = [ + ChainInfo( + name=cls.__name__, + model=Model( + name=llm_dict.get("model_name", llm_dict.get("model")), + vendor=llm.__class__.__name__, + ), + vector_dbs=[ + VectorDB( + name=kwargs["retriever"].vectorstore.__class__.__name__, + embedding_model=str( kwargs["retriever"].vectorstore._embeddings.model ) if hasattr(kwargs["retriever"].vectorstore, "_embeddings") @@ -607,8 +355,8 @@ class PebbloRetrievalQA(Chain): if hasattr(kwargs["retriever"].vectorstore, "_embedding") else None ), - } + ) ], - }, + ), ] - return chain + return chains diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/models.py b/libs/community/langchain_community/chains/pebblo_retrieval/models.py index e4fd7c64963..d5693404214 100644 --- a/libs/community/langchain_community/chains/pebblo_retrieval/models.py +++ b/libs/community/langchain_community/chains/pebblo_retrieval/models.py @@ -109,7 +109,7 @@ class VectorDB(BaseModel): embedding_model: Optional[str] = None -class Chains(BaseModel): +class ChainInfo(BaseModel): name: str model: Optional[Model] vector_dbs: Optional[List[VectorDB]] @@ -121,7 +121,7 @@ class App(BaseModel): description: Optional[str] runtime: Runtime framework: Framework - chains: List[Chains] + chains: List[ChainInfo] plugin_version: str @@ -134,9 +134,9 @@ class Context(BaseModel): class Prompt(BaseModel): data: Optional[Union[list, str]] - entityCount: Optional[int] - entities: Optional[dict] - prompt_gov_enabled: Optional[bool] + entityCount: Optional[int] = None + entities: Optional[dict] = None + prompt_gov_enabled: Optional[bool] = None class Qa(BaseModel): diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/utilities.py b/libs/community/langchain_community/chains/pebblo_retrieval/utilities.py index 86218ad07b0..568fc560c0f 100644 --- a/libs/community/langchain_community/chains/pebblo_retrieval/utilities.py +++ b/libs/community/langchain_community/chains/pebblo_retrieval/utilities.py @@ -1,22 +1,43 @@ +import json import logging import os import platform -from typing import Tuple +from enum import Enum +from http import HTTPStatus +from typing import Any, Dict, List, Optional, Tuple +from langchain_core.documents import Document from langchain_core.env import get_runtime_environment +from langchain_core.pydantic_v1 import BaseModel +from langchain_core.utils import get_from_dict_or_env +from langchain_core.vectorstores import VectorStoreRetriever +from requests import Response, request +from requests.exceptions import RequestException -from langchain_community.chains.pebblo_retrieval.models import Framework, Runtime +from langchain_community.chains.pebblo_retrieval.models import ( + App, + AuthContext, + Context, + Framework, + Prompt, + Qa, + Runtime, +) logger = logging.getLogger(__name__) PLUGIN_VERSION = "0.1.1" -CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000") -PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai") +_DEFAULT_CLASSIFIER_URL = "http://localhost:8000" +_DEFAULT_PEBBLO_CLOUD_URL = "https://api.daxa.ai" -PROMPT_URL = "/v1/prompt" -PROMPT_GOV_URL = "/v1/prompt/governance" -APP_DISCOVER_URL = "/v1/app/discover" + +class Routes(str, Enum): + """Routes available for the Pebblo API as enumerator.""" + + retrieval_app_discover = "/v1/app/discover" + prompt = "/v1/prompt" + prompt_governance = "/v1/prompt/governance" def get_runtime() -> Tuple[Framework, Runtime]: @@ -64,3 +85,308 @@ def get_ip() -> str: except Exception: public_ip = socket.gethostbyname("localhost") return public_ip + + +class PebbloRetrievalAPIWrapper(BaseModel): + """Wrapper for Pebblo Retrieval API.""" + + api_key: Optional[str] # Use SecretStr + """API key for Pebblo Cloud""" + classifier_location: str = "local" + """Location of the classifier, local or cloud. Defaults to 'local'""" + classifier_url: Optional[str] + """URL of the Pebblo Classifier""" + cloud_url: Optional[str] + """URL of the Pebblo Cloud""" + + def __init__(self, **kwargs: Any): + """Validate that api key in environment.""" + kwargs["api_key"] = get_from_dict_or_env( + kwargs, "api_key", "PEBBLO_API_KEY", "" + ) + kwargs["classifier_url"] = get_from_dict_or_env( + kwargs, "classifier_url", "PEBBLO_CLASSIFIER_URL", _DEFAULT_CLASSIFIER_URL + ) + kwargs["cloud_url"] = get_from_dict_or_env( + kwargs, "cloud_url", "PEBBLO_CLOUD_URL", _DEFAULT_PEBBLO_CLOUD_URL + ) + super().__init__(**kwargs) + + def send_app_discover(self, app: App) -> None: + """ + Send app discovery request to Pebblo server & cloud. + + Args: + app (App): App instance to be discovered. + """ + pebblo_resp = None + payload = app.dict(exclude_unset=True) + + if self.classifier_location == "local": + # Send app details to local classifier + headers = self._make_headers() + app_discover_url = f"{self.classifier_url}{Routes.retrieval_app_discover}" + pebblo_resp = self.make_request("POST", app_discover_url, headers, payload) + + if self.api_key: + # Send app details to Pebblo cloud if api_key is present + headers = self._make_headers(cloud_request=True) + if pebblo_resp: + pebblo_server_version = json.loads(pebblo_resp.text).get( + "pebblo_server_version" + ) + payload.update({"pebblo_server_version": pebblo_server_version}) + + payload.update({"pebblo_client_version": PLUGIN_VERSION}) + pebblo_cloud_url = f"{self.cloud_url}{Routes.retrieval_app_discover}" + _ = self.make_request("POST", pebblo_cloud_url, headers, payload) + + def send_prompt( + self, + app_name: str, + retriever: VectorStoreRetriever, + question: str, + answer: str, + auth_context: Optional[AuthContext], + docs: List[Document], + prompt_entities: Dict[str, Any], + prompt_time: str, + prompt_gov_enabled: bool = False, + ) -> None: + """ + Send prompt to Pebblo server for classification. + Then send prompt to Daxa cloud(If api_key is present). + + Args: + app_name (str): Name of the app. + retriever (VectorStoreRetriever): Retriever instance. + question (str): Question asked in the prompt. + answer (str): Answer generated by the model. + auth_context (Optional[AuthContext]): Authentication context. + docs (List[Document]): List of documents retrieved. + prompt_entities (Dict[str, Any]): Entities present in the prompt. + prompt_time (str): Time when the prompt was generated. + prompt_gov_enabled (bool): Whether prompt governance is enabled. + """ + pebblo_resp = None + payload = self.build_prompt_qa_payload( + app_name, + retriever, + question, + answer, + auth_context, + docs, + prompt_entities, + prompt_time, + prompt_gov_enabled, + ) + + if self.classifier_location == "local": + # Send prompt to local classifier + headers = self._make_headers() + prompt_url = f"{self.classifier_url}{Routes.prompt}" + pebblo_resp = self.make_request("POST", prompt_url, headers, payload) + + if self.api_key: + # Send prompt to Pebblo cloud if api_key is present + if self.classifier_location == "local": + # If classifier location is local, then response, context and prompt + # should be fetched from pebblo_resp and replaced in payload. + pebblo_resp = pebblo_resp.json() if pebblo_resp else None + self.update_cloud_payload(payload, pebblo_resp) + + headers = self._make_headers(cloud_request=True) + pebblo_cloud_prompt_url = f"{self.cloud_url}{Routes.prompt}" + _ = self.make_request("POST", pebblo_cloud_prompt_url, headers, payload) + elif self.classifier_location == "pebblo-cloud": + logger.warning("API key is missing for sending prompt to Pebblo cloud.") + raise NameError("API key is missing for sending prompt to Pebblo cloud.") + + def check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]: + """ + Check the validity of the given prompt using a remote classification service. + + This method sends a prompt to a remote classifier service and return entities + present in prompt or not. + + Args: + question (str): The prompt question to be validated. + + Returns: + bool: True if the prompt is valid (does not contain deny list entities), + False otherwise. + dict: The entities present in the prompt + """ + prompt_payload = {"prompt": question} + prompt_entities: dict = {"entities": {}, "entityCount": 0} + is_valid_prompt: bool = True + if self.classifier_location == "local": + headers = self._make_headers() + prompt_gov_api_url = f"{self.classifier_url}{Routes.prompt_governance}" + pebblo_resp = self.make_request( + "POST", prompt_gov_api_url, headers, prompt_payload + ) + if pebblo_resp: + logger.debug(f"pebblo_resp.json() {pebblo_resp.json()}") + prompt_entities["entities"] = pebblo_resp.json().get("entities", {}) + prompt_entities["entityCount"] = pebblo_resp.json().get( + "entityCount", 0 + ) + return is_valid_prompt, prompt_entities + + def _make_headers(self, cloud_request: bool = False) -> dict: + """ + Generate headers for the request. + + args: + cloud_request (bool): flag indicating whether the request is for Pebblo + cloud. + returns: + dict: Headers for the request. + + """ + headers = { + "Accept": "application/json", + "Content-Type": "application/json", + } + if cloud_request: + # Add API key for Pebblo cloud request + if self.api_key: + headers.update({"x-api-key": self.api_key}) + else: + logger.warning("API key is missing for Pebblo cloud request.") + return headers + + @staticmethod + def make_request( + method: str, + url: str, + headers: dict, + payload: Optional[dict] = None, + timeout: int = 20, + ) -> Optional[Response]: + """ + Make a request to the Pebblo server/cloud API. + + Args: + method (str): HTTP method (GET, POST, PUT, DELETE, etc.). + url (str): URL for the request. + headers (dict): Headers for the request. + payload (Optional[dict]): Payload for the request (for POST, PUT, etc.). + timeout (int): Timeout for the request in seconds. + + Returns: + Optional[Response]: Response object if the request is successful. + """ + try: + response = request( + method=method, url=url, headers=headers, json=payload, timeout=timeout + ) + logger.debug( + "Request: method %s, url %s, len %s response status %s", + method, + response.request.url, + str(len(response.request.body if response.request.body else [])), + str(response.status_code), + ) + + if response.status_code >= HTTPStatus.INTERNAL_SERVER_ERROR: + logger.warning(f"Pebblo Server: Error {response.status_code}") + elif response.status_code >= HTTPStatus.BAD_REQUEST: + logger.warning(f"Pebblo received an invalid payload: {response.text}") + elif response.status_code != HTTPStatus.OK: + logger.warning( + f"Pebblo returned an unexpected response code: " + f"{response.status_code}" + ) + + return response + except RequestException: + logger.warning("Unable to reach server %s", url) + except Exception as e: + logger.warning("An Exception caught in make_request: %s", e) + return None + + @staticmethod + def update_cloud_payload(payload: dict, pebblo_resp: Optional[dict]) -> None: + """ + Update the payload with response, prompt and context from Pebblo response. + + Args: + payload (dict): Payload to be updated. + pebblo_resp (Optional[dict]): Response from Pebblo server. + """ + if pebblo_resp: + # Update response, prompt and context from pebblo response + response = payload.get("response", {}) + response.update(pebblo_resp.get("retrieval_data", {}).get("response", {})) + response.pop("data", None) + prompt = payload.get("prompt", {}) + prompt.update(pebblo_resp.get("retrieval_data", {}).get("prompt", {})) + prompt.pop("data", None) + context = payload.get("context", []) + for context_data in context: + context_data.pop("doc", None) + else: + payload["response"] = {} + payload["prompt"] = {} + payload["context"] = [] + + def build_prompt_qa_payload( + self, + app_name: str, + retriever: VectorStoreRetriever, + question: str, + answer: str, + auth_context: Optional[AuthContext], + docs: List[Document], + prompt_entities: Dict[str, Any], + prompt_time: str, + prompt_gov_enabled: bool = False, + ) -> dict: + """ + Build the QA payload for the prompt. + + Args: + app_name (str): Name of the app. + retriever (VectorStoreRetriever): Retriever instance. + question (str): Question asked in the prompt. + answer (str): Answer generated by the model. + auth_context (Optional[AuthContext]): Authentication context. + docs (List[Document]): List of documents retrieved. + prompt_entities (Dict[str, Any]): Entities present in the prompt. + prompt_time (str): Time when the prompt was generated. + prompt_gov_enabled (bool): Whether prompt governance is enabled. + + Returns: + dict: The QA payload for the prompt. + """ + qa = Qa( + name=app_name, + context=[ + Context( + retrieved_from=doc.metadata.get( + "full_path", doc.metadata.get("source") + ), + doc=doc.page_content, + vector_db=retriever.vectorstore.__class__.__name__, + pb_checksum=doc.metadata.get("pb_checksum"), + ) + for doc in docs + if isinstance(doc, Document) + ], + prompt=Prompt( + data=question, + entities=prompt_entities.get("entities", {}), + entityCount=prompt_entities.get("entityCount", 0), + prompt_gov_enabled=prompt_gov_enabled, + ), + response=Prompt(data=answer), + prompt_time=prompt_time, + user=auth_context.user_id if auth_context else "unknown", + user_identities=auth_context.user_auth + if auth_context and hasattr(auth_context, "user_auth") + else [], + classifier_location=self.classifier_location, + ) + return qa.dict(exclude_unset=True)