community: [PebbloRetrievalQA] Implemented Async support for prompt APIs (#25748)

- **Description:** PebbloRetrievalQA: Implemented Async support for
prompt APIs (classification and governance)
- **Issue:** NA
- **Dependencies:** NA
This commit is contained in:
Rajendra Kadam 2024-08-26 16:57:05 +05:30 committed by GitHub
parent 6703d795c5
commit 58a98c7d8a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 161 additions and 2 deletions

View File

@ -149,6 +149,7 @@ class PebbloRetrievalQA(Chain):
res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
prompt_time = datetime.datetime.now().isoformat()
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key)
@ -157,7 +158,7 @@ class PebbloRetrievalQA(Chain):
"run_manager" in inspect.signature(self._aget_docs).parameters
)
_, prompt_entities = self.pb_client.check_prompt_validity(question)
_, prompt_entities = await self.pb_client.acheck_prompt_validity(question)
if accepts_run_manager:
docs = await self._aget_docs(
@ -169,6 +170,18 @@ class PebbloRetrievalQA(Chain):
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
await self.pb_client.asend_prompt(
self.app_name,
self.retriever,
question,
answer,
auth_context,
docs,
prompt_entities,
prompt_time,
self.enable_prompt_gov,
)
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:

View File

@ -6,6 +6,8 @@ from enum import Enum
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple
import aiohttp
from aiohttp import ClientTimeout
from langchain_core.documents import Document
from langchain_core.env import get_runtime_environment
from langchain_core.pydantic_v1 import BaseModel
@ -202,6 +204,68 @@ class PebbloRetrievalAPIWrapper(BaseModel):
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
raise NameError("API key is missing for sending prompt to Pebblo cloud.")
async def asend_prompt(
self,
app_name: str,
retriever: VectorStoreRetriever,
question: str,
answer: str,
auth_context: Optional[AuthContext],
docs: List[Document],
prompt_entities: Dict[str, Any],
prompt_time: str,
prompt_gov_enabled: bool = False,
) -> None:
"""
Send prompt to Pebblo server for classification.
Then send prompt to Daxa cloud(If api_key is present).
Args:
app_name (str): Name of the app.
retriever (VectorStoreRetriever): Retriever instance.
question (str): Question asked in the prompt.
answer (str): Answer generated by the model.
auth_context (Optional[AuthContext]): Authentication context.
docs (List[Document]): List of documents retrieved.
prompt_entities (Dict[str, Any]): Entities present in the prompt.
prompt_time (str): Time when the prompt was generated.
prompt_gov_enabled (bool): Whether prompt governance is enabled.
"""
pebblo_resp = None
payload = self.build_prompt_qa_payload(
app_name,
retriever,
question,
answer,
auth_context,
docs,
prompt_entities,
prompt_time,
prompt_gov_enabled,
)
if self.classifier_location == "local":
# Send prompt to local classifier
headers = self._make_headers()
prompt_url = f"{self.classifier_url}{Routes.prompt.value}"
pebblo_resp = await self.amake_request("POST", prompt_url, headers, payload)
if self.api_key:
# Send prompt to Pebblo cloud if api_key is present
if self.classifier_location == "local":
# If classifier location is local, then response, context and prompt
# should be fetched from pebblo_resp and replaced in payload.
self.update_cloud_payload(payload, pebblo_resp)
headers = self._make_headers(cloud_request=True)
pebblo_cloud_prompt_url = f"{self.cloud_url}{Routes.prompt.value}"
_ = await self.amake_request(
"POST", pebblo_cloud_prompt_url, headers, payload
)
elif self.classifier_location == "pebblo-cloud":
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
raise NameError("API key is missing for sending prompt to Pebblo cloud.")
def check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
"""
Check the validity of the given prompt using a remote classification service.
@ -227,13 +291,45 @@ class PebbloRetrievalAPIWrapper(BaseModel):
"POST", prompt_gov_api_url, headers, prompt_payload
)
if pebblo_resp:
logger.debug(f"pebblo_resp.json() {pebblo_resp.json()}")
prompt_entities["entities"] = pebblo_resp.json().get("entities", {})
prompt_entities["entityCount"] = pebblo_resp.json().get(
"entityCount", 0
)
return is_valid_prompt, prompt_entities
async def acheck_prompt_validity(
self, question: str
) -> Tuple[bool, Dict[str, Any]]:
"""
Check the validity of the given prompt using a remote classification service.
This method sends a prompt to a remote classifier service and return entities
present in prompt or not.
Args:
question (str): The prompt question to be validated.
Returns:
bool: True if the prompt is valid (does not contain deny list entities),
False otherwise.
dict: The entities present in the prompt
"""
prompt_payload = {"prompt": question}
prompt_entities: dict = {"entities": {}, "entityCount": 0}
is_valid_prompt: bool = True
if self.classifier_location == "local":
headers = self._make_headers()
prompt_gov_api_url = (
f"{self.classifier_url}{Routes.prompt_governance.value}"
)
pebblo_resp = await self.amake_request(
"POST", prompt_gov_api_url, headers, prompt_payload
)
if pebblo_resp:
prompt_entities["entities"] = pebblo_resp.get("entities", {})
prompt_entities["entityCount"] = pebblo_resp.get("entityCount", 0)
return is_valid_prompt, prompt_entities
def _make_headers(self, cloud_request: bool = False) -> dict:
"""
Generate headers for the request.
@ -332,6 +428,56 @@ class PebbloRetrievalAPIWrapper(BaseModel):
payload["prompt"] = {}
payload["context"] = []
@staticmethod
async def amake_request(
method: str,
url: str,
headers: dict,
payload: Optional[dict] = None,
timeout: int = 20,
) -> Any:
"""
Make a async request to the Pebblo server/cloud API.
Args:
method (str): HTTP method (GET, POST, PUT, DELETE, etc.).
url (str): URL for the request.
headers (dict): Headers for the request.
payload (Optional[dict]): Payload for the request (for POST, PUT, etc.).
timeout (int): Timeout for the request in seconds.
Returns:
Any: Response json if the request is successful.
"""
try:
client_timeout = ClientTimeout(total=timeout)
async with aiohttp.ClientSession() as asession:
async with asession.request(
method=method,
url=url,
json=payload,
headers=headers,
timeout=client_timeout,
) as response:
if response.status >= HTTPStatus.INTERNAL_SERVER_ERROR:
logger.warning(f"Pebblo Server: Error {response.status}")
elif response.status >= HTTPStatus.BAD_REQUEST:
logger.warning(
f"Pebblo received an invalid payload: " f"{response.text}"
)
elif response.status != HTTPStatus.OK:
logger.warning(
f"Pebblo returned an unexpected response code: "
f"{response.status}"
)
response_json = await response.json()
return response_json
except RequestException:
logger.warning("Unable to reach server %s", url)
except Exception as e:
logger.warning("An Exception caught in amake_request: %s", e)
return None
def build_prompt_qa_payload(
self,
app_name: str,