community: [PebbloRetrievalQA] Implemented Async support for prompt APIs (#25748)

- **Description:** PebbloRetrievalQA: Implemented Async support for
prompt APIs (classification and governance)
- **Issue:** NA
- **Dependencies:** NA
This commit is contained in:
Rajendra Kadam 2024-08-26 16:57:05 +05:30 committed by GitHub
parent 6703d795c5
commit 58a98c7d8a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 161 additions and 2 deletions

View File

@ -149,6 +149,7 @@ class PebbloRetrievalQA(Chain):
res = indexqa({'query': 'This is my query'}) res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents'] answer, docs = res['result'], res['source_documents']
""" """
prompt_time = datetime.datetime.now().isoformat()
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key] question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key) auth_context = inputs.get(self.auth_context_key)
@ -157,7 +158,7 @@ class PebbloRetrievalQA(Chain):
"run_manager" in inspect.signature(self._aget_docs).parameters "run_manager" in inspect.signature(self._aget_docs).parameters
) )
_, prompt_entities = self.pb_client.check_prompt_validity(question) _, prompt_entities = await self.pb_client.acheck_prompt_validity(question)
if accepts_run_manager: if accepts_run_manager:
docs = await self._aget_docs( docs = await self._aget_docs(
@ -169,6 +170,18 @@ class PebbloRetrievalQA(Chain):
input_documents=docs, question=question, callbacks=_run_manager.get_child() input_documents=docs, question=question, callbacks=_run_manager.get_child()
) )
await self.pb_client.asend_prompt(
self.app_name,
self.retriever,
question,
answer,
auth_context,
docs,
prompt_entities,
prompt_time,
self.enable_prompt_gov,
)
if self.return_source_documents: if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs} return {self.output_key: answer, "source_documents": docs}
else: else:

View File

@ -6,6 +6,8 @@ from enum import Enum
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import aiohttp
from aiohttp import ClientTimeout
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.env import get_runtime_environment from langchain_core.env import get_runtime_environment
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
@ -202,6 +204,68 @@ class PebbloRetrievalAPIWrapper(BaseModel):
logger.warning("API key is missing for sending prompt to Pebblo cloud.") logger.warning("API key is missing for sending prompt to Pebblo cloud.")
raise NameError("API key is missing for sending prompt to Pebblo cloud.") raise NameError("API key is missing for sending prompt to Pebblo cloud.")
async def asend_prompt(
self,
app_name: str,
retriever: VectorStoreRetriever,
question: str,
answer: str,
auth_context: Optional[AuthContext],
docs: List[Document],
prompt_entities: Dict[str, Any],
prompt_time: str,
prompt_gov_enabled: bool = False,
) -> None:
"""
Send prompt to Pebblo server for classification.
Then send prompt to Daxa cloud(If api_key is present).
Args:
app_name (str): Name of the app.
retriever (VectorStoreRetriever): Retriever instance.
question (str): Question asked in the prompt.
answer (str): Answer generated by the model.
auth_context (Optional[AuthContext]): Authentication context.
docs (List[Document]): List of documents retrieved.
prompt_entities (Dict[str, Any]): Entities present in the prompt.
prompt_time (str): Time when the prompt was generated.
prompt_gov_enabled (bool): Whether prompt governance is enabled.
"""
pebblo_resp = None
payload = self.build_prompt_qa_payload(
app_name,
retriever,
question,
answer,
auth_context,
docs,
prompt_entities,
prompt_time,
prompt_gov_enabled,
)
if self.classifier_location == "local":
# Send prompt to local classifier
headers = self._make_headers()
prompt_url = f"{self.classifier_url}{Routes.prompt.value}"
pebblo_resp = await self.amake_request("POST", prompt_url, headers, payload)
if self.api_key:
# Send prompt to Pebblo cloud if api_key is present
if self.classifier_location == "local":
# If classifier location is local, then response, context and prompt
# should be fetched from pebblo_resp and replaced in payload.
self.update_cloud_payload(payload, pebblo_resp)
headers = self._make_headers(cloud_request=True)
pebblo_cloud_prompt_url = f"{self.cloud_url}{Routes.prompt.value}"
_ = await self.amake_request(
"POST", pebblo_cloud_prompt_url, headers, payload
)
elif self.classifier_location == "pebblo-cloud":
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
raise NameError("API key is missing for sending prompt to Pebblo cloud.")
def check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]: def check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
""" """
Check the validity of the given prompt using a remote classification service. Check the validity of the given prompt using a remote classification service.
@ -227,13 +291,45 @@ class PebbloRetrievalAPIWrapper(BaseModel):
"POST", prompt_gov_api_url, headers, prompt_payload "POST", prompt_gov_api_url, headers, prompt_payload
) )
if pebblo_resp: if pebblo_resp:
logger.debug(f"pebblo_resp.json() {pebblo_resp.json()}")
prompt_entities["entities"] = pebblo_resp.json().get("entities", {}) prompt_entities["entities"] = pebblo_resp.json().get("entities", {})
prompt_entities["entityCount"] = pebblo_resp.json().get( prompt_entities["entityCount"] = pebblo_resp.json().get(
"entityCount", 0 "entityCount", 0
) )
return is_valid_prompt, prompt_entities return is_valid_prompt, prompt_entities
async def acheck_prompt_validity(
self, question: str
) -> Tuple[bool, Dict[str, Any]]:
"""
Check the validity of the given prompt using a remote classification service.
This method sends a prompt to a remote classifier service and return entities
present in prompt or not.
Args:
question (str): The prompt question to be validated.
Returns:
bool: True if the prompt is valid (does not contain deny list entities),
False otherwise.
dict: The entities present in the prompt
"""
prompt_payload = {"prompt": question}
prompt_entities: dict = {"entities": {}, "entityCount": 0}
is_valid_prompt: bool = True
if self.classifier_location == "local":
headers = self._make_headers()
prompt_gov_api_url = (
f"{self.classifier_url}{Routes.prompt_governance.value}"
)
pebblo_resp = await self.amake_request(
"POST", prompt_gov_api_url, headers, prompt_payload
)
if pebblo_resp:
prompt_entities["entities"] = pebblo_resp.get("entities", {})
prompt_entities["entityCount"] = pebblo_resp.get("entityCount", 0)
return is_valid_prompt, prompt_entities
def _make_headers(self, cloud_request: bool = False) -> dict: def _make_headers(self, cloud_request: bool = False) -> dict:
""" """
Generate headers for the request. Generate headers for the request.
@ -332,6 +428,56 @@ class PebbloRetrievalAPIWrapper(BaseModel):
payload["prompt"] = {} payload["prompt"] = {}
payload["context"] = [] payload["context"] = []
@staticmethod
async def amake_request(
method: str,
url: str,
headers: dict,
payload: Optional[dict] = None,
timeout: int = 20,
) -> Any:
"""
Make a async request to the Pebblo server/cloud API.
Args:
method (str): HTTP method (GET, POST, PUT, DELETE, etc.).
url (str): URL for the request.
headers (dict): Headers for the request.
payload (Optional[dict]): Payload for the request (for POST, PUT, etc.).
timeout (int): Timeout for the request in seconds.
Returns:
Any: Response json if the request is successful.
"""
try:
client_timeout = ClientTimeout(total=timeout)
async with aiohttp.ClientSession() as asession:
async with asession.request(
method=method,
url=url,
json=payload,
headers=headers,
timeout=client_timeout,
) as response:
if response.status >= HTTPStatus.INTERNAL_SERVER_ERROR:
logger.warning(f"Pebblo Server: Error {response.status}")
elif response.status >= HTTPStatus.BAD_REQUEST:
logger.warning(
f"Pebblo received an invalid payload: " f"{response.text}"
)
elif response.status != HTTPStatus.OK:
logger.warning(
f"Pebblo returned an unexpected response code: "
f"{response.status}"
)
response_json = await response.json()
return response_json
except RequestException:
logger.warning("Unable to reach server %s", url)
except Exception as e:
logger.warning("An Exception caught in amake_request: %s", e)
return None
def build_prompt_qa_payload( def build_prompt_qa_payload(
self, self,
app_name: str, app_name: str,