diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/base.py b/libs/community/langchain_community/chains/pebblo_retrieval/base.py
index 93314301b4d..a29bda567c2 100644
--- a/libs/community/langchain_community/chains/pebblo_retrieval/base.py
+++ b/libs/community/langchain_community/chains/pebblo_retrieval/base.py
@@ -149,6 +149,7 @@ class PebbloRetrievalQA(Chain):
         res = indexqa({'query': 'This is my query'})
         answer, docs = res['result'], res['source_documents']
         """
+        prompt_time = datetime.datetime.now().isoformat()
         _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
         question = inputs[self.input_key]
         auth_context = inputs.get(self.auth_context_key)
@@ -157,7 +158,7 @@ class PebbloRetrievalQA(Chain):
             "run_manager" in inspect.signature(self._aget_docs).parameters
         )
 
-        _, prompt_entities = self.pb_client.check_prompt_validity(question)
+        _, prompt_entities = await self.pb_client.acheck_prompt_validity(question)
 
         if accepts_run_manager:
             docs = await self._aget_docs(
@@ -169,6 +170,18 @@ class PebbloRetrievalQA(Chain):
             input_documents=docs, question=question, callbacks=_run_manager.get_child()
         )
 
+        await self.pb_client.asend_prompt(
+            self.app_name,
+            self.retriever,
+            question,
+            answer,
+            auth_context,
+            docs,
+            prompt_entities,
+            prompt_time,
+            self.enable_prompt_gov,
+        )
+
         if self.return_source_documents:
             return {self.output_key: answer, "source_documents": docs}
         else:
diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/utilities.py b/libs/community/langchain_community/chains/pebblo_retrieval/utilities.py
index 568fc560c0f..e7e6772890b 100644
--- a/libs/community/langchain_community/chains/pebblo_retrieval/utilities.py
+++ b/libs/community/langchain_community/chains/pebblo_retrieval/utilities.py
@@ -6,6 +6,8 @@ from enum import Enum
 from http import HTTPStatus
 from typing import Any, Dict, List, Optional, Tuple
 
+import aiohttp
+from aiohttp import ClientTimeout
 from langchain_core.documents import Document
 from langchain_core.env import get_runtime_environment
 from langchain_core.pydantic_v1 import BaseModel
@@ -202,6 +204,68 @@ class PebbloRetrievalAPIWrapper(BaseModel):
             logger.warning("API key is missing for sending prompt to Pebblo cloud.")
             raise NameError("API key is missing for sending prompt to Pebblo cloud.")
 
+    async def asend_prompt(
+        self,
+        app_name: str,
+        retriever: VectorStoreRetriever,
+        question: str,
+        answer: str,
+        auth_context: Optional[AuthContext],
+        docs: List[Document],
+        prompt_entities: Dict[str, Any],
+        prompt_time: str,
+        prompt_gov_enabled: bool = False,
+    ) -> None:
+        """
+        Send prompt to Pebblo server for classification.
+        Then send prompt to Daxa cloud(If api_key is present).
+
+        Args:
+            app_name (str): Name of the app.
+            retriever (VectorStoreRetriever): Retriever instance.
+            question (str): Question asked in the prompt.
+            answer (str): Answer generated by the model.
+            auth_context (Optional[AuthContext]): Authentication context.
+            docs (List[Document]): List of documents retrieved.
+            prompt_entities (Dict[str, Any]): Entities present in the prompt.
+            prompt_time (str): Time when the prompt was generated.
+            prompt_gov_enabled (bool): Whether prompt governance is enabled.
+        """
+        pebblo_resp = None
+        payload = self.build_prompt_qa_payload(
+            app_name,
+            retriever,
+            question,
+            answer,
+            auth_context,
+            docs,
+            prompt_entities,
+            prompt_time,
+            prompt_gov_enabled,
+        )
+
+        if self.classifier_location == "local":
+            # Send prompt to local classifier
+            headers = self._make_headers()
+            prompt_url = f"{self.classifier_url}{Routes.prompt.value}"
+            pebblo_resp = await self.amake_request("POST", prompt_url, headers, payload)
+
+        if self.api_key:
+            # Send prompt to Pebblo cloud if api_key is present
+            if self.classifier_location == "local":
+                # If classifier location is local, then response, context and prompt
+                # should be fetched from pebblo_resp and replaced in payload.
+                self.update_cloud_payload(payload, pebblo_resp)
+
+            headers = self._make_headers(cloud_request=True)
+            pebblo_cloud_prompt_url = f"{self.cloud_url}{Routes.prompt.value}"
+            _ = await self.amake_request(
+                "POST", pebblo_cloud_prompt_url, headers, payload
+            )
+        elif self.classifier_location == "pebblo-cloud":
+            logger.warning("API key is missing for sending prompt to Pebblo cloud.")
+            raise NameError("API key is missing for sending prompt to Pebblo cloud.")
+
     def check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
         """
         Check the validity of the given prompt using a remote classification service.
@@ -227,13 +291,45 @@ class PebbloRetrievalAPIWrapper(BaseModel):
                 "POST", prompt_gov_api_url, headers, prompt_payload
             )
             if pebblo_resp:
-                logger.debug(f"pebblo_resp.json() {pebblo_resp.json()}")
                 prompt_entities["entities"] = pebblo_resp.json().get("entities", {})
                 prompt_entities["entityCount"] = pebblo_resp.json().get(
                     "entityCount", 0
                 )
         return is_valid_prompt, prompt_entities
 
+    async def acheck_prompt_validity(
+        self, question: str
+    ) -> Tuple[bool, Dict[str, Any]]:
+        """
+        Check the validity of the given prompt using a remote classification service.
+
+        This method sends a prompt to a remote classifier service and return entities
+        present in prompt or not.
+
+        Args:
+            question (str): The prompt question to be validated.
+
+        Returns:
+            bool: True if the prompt is valid (does not contain deny list entities),
+            False otherwise.
+            dict: The entities present in the prompt
+        """
+        prompt_payload = {"prompt": question}
+        prompt_entities: dict = {"entities": {}, "entityCount": 0}
+        is_valid_prompt: bool = True
+        if self.classifier_location == "local":
+            headers = self._make_headers()
+            prompt_gov_api_url = (
+                f"{self.classifier_url}{Routes.prompt_governance.value}"
+            )
+            pebblo_resp = await self.amake_request(
+                "POST", prompt_gov_api_url, headers, prompt_payload
+            )
+            if pebblo_resp:
+                prompt_entities["entities"] = pebblo_resp.get("entities", {})
+                prompt_entities["entityCount"] = pebblo_resp.get("entityCount", 0)
+        return is_valid_prompt, prompt_entities
+
     def _make_headers(self, cloud_request: bool = False) -> dict:
         """
         Generate headers for the request.
@@ -332,6 +428,60 @@ class PebbloRetrievalAPIWrapper(BaseModel):
             payload["prompt"] = {}
             payload["context"] = []
 
+    @staticmethod
+    async def amake_request(
+        method: str,
+        url: str,
+        headers: dict,
+        payload: Optional[dict] = None,
+        timeout: int = 20,
+    ) -> Any:
+        """
+        Make an async request to the Pebblo server/cloud API.
+
+        Args:
+            method (str): HTTP method (GET, POST, PUT, DELETE, etc.).
+            url (str): URL for the request.
+            headers (dict): Headers for the request.
+            payload (Optional[dict]): Payload for the request (for POST, PUT, etc.).
+            timeout (int): Timeout for the request in seconds.
+
+        Returns:
+            Any: Response json, or None if the request or JSON decoding failed.
+        """
+        try:
+            client_timeout = ClientTimeout(total=timeout)
+            async with aiohttp.ClientSession() as asession:
+                async with asession.request(
+                    method=method,
+                    url=url,
+                    json=payload,
+                    headers=headers,
+                    timeout=client_timeout,
+                ) as response:
+                    if response.status >= HTTPStatus.INTERNAL_SERVER_ERROR:
+                        logger.warning(f"Pebblo Server: Error {response.status}")
+                    elif response.status >= HTTPStatus.BAD_REQUEST:
+                        # ClientResponse.text is a coroutine method; await it so
+                        # the response body (not the bound method) is logged.
+                        logger.warning(
+                            "Pebblo received an invalid payload: %s",
+                            await response.text(),
+                        )
+                    elif response.status != HTTPStatus.OK:
+                        logger.warning(
+                            f"Pebblo returned an unexpected response code: "
+                            f"{response.status}"
+                        )
+                    response_json = await response.json()
+                    return response_json
+        except aiohttp.ClientConnectionError:
+            # aiohttp never raises requests' RequestException.
+            logger.warning("Unable to reach server %s", url)
+        except Exception as e:
+            logger.warning("An Exception caught in amake_request: %s", e)
+        return None
+
     def build_prompt_qa_payload(
         self,
         app_name: str,