mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 07:09:31 +00:00
community: [PebbloRetrievalQA] Implemented Async support for prompt APIs (#25748)
- **Description:** PebbloRetrievalQA: Implemented Async support for prompt APIs (classification and governance) - **Issue:** NA - **Dependencies:** NA
This commit is contained in:
parent
6703d795c5
commit
58a98c7d8a
@ -149,6 +149,7 @@ class PebbloRetrievalQA(Chain):
|
||||
res = indexqa({'query': 'This is my query'})
|
||||
answer, docs = res['result'], res['source_documents']
|
||||
"""
|
||||
prompt_time = datetime.datetime.now().isoformat()
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
question = inputs[self.input_key]
|
||||
auth_context = inputs.get(self.auth_context_key)
|
||||
@ -157,7 +158,7 @@ class PebbloRetrievalQA(Chain):
|
||||
"run_manager" in inspect.signature(self._aget_docs).parameters
|
||||
)
|
||||
|
||||
_, prompt_entities = self.pb_client.check_prompt_validity(question)
|
||||
_, prompt_entities = await self.pb_client.acheck_prompt_validity(question)
|
||||
|
||||
if accepts_run_manager:
|
||||
docs = await self._aget_docs(
|
||||
@ -169,6 +170,18 @@ class PebbloRetrievalQA(Chain):
|
||||
input_documents=docs, question=question, callbacks=_run_manager.get_child()
|
||||
)
|
||||
|
||||
await self.pb_client.asend_prompt(
|
||||
self.app_name,
|
||||
self.retriever,
|
||||
question,
|
||||
answer,
|
||||
auth_context,
|
||||
docs,
|
||||
prompt_entities,
|
||||
prompt_time,
|
||||
self.enable_prompt_gov,
|
||||
)
|
||||
|
||||
if self.return_source_documents:
|
||||
return {self.output_key: answer, "source_documents": docs}
|
||||
else:
|
||||
|
@ -6,6 +6,8 @@ from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import ClientTimeout
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.env import get_runtime_environment
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
@ -202,6 +204,68 @@ class PebbloRetrievalAPIWrapper(BaseModel):
|
||||
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
|
||||
raise NameError("API key is missing for sending prompt to Pebblo cloud.")
|
||||
|
||||
async def asend_prompt(
    self,
    app_name: str,
    retriever: VectorStoreRetriever,
    question: str,
    answer: str,
    auth_context: Optional[AuthContext],
    docs: List[Document],
    prompt_entities: Dict[str, Any],
    prompt_time: str,
    prompt_gov_enabled: bool = False,
) -> None:
    """
    Asynchronously report a prompt/answer exchange to Pebblo.

    A payload is built once via ``build_prompt_qa_payload``. When the
    classifier location is "local" the payload is POSTed to the local
    classifier's prompt route. When an API key is configured the payload is
    then POSTed to the cloud prompt route; for a local classifier the payload
    is first passed through ``update_cloud_payload`` together with the local
    classifier's reply.

    Args:
        app_name (str): Name of the app.
        retriever (VectorStoreRetriever): Retriever instance.
        question (str): Question asked in the prompt.
        answer (str): Answer generated by the model.
        auth_context (Optional[AuthContext]): Authentication context.
        docs (List[Document]): List of documents retrieved.
        prompt_entities (Dict[str, Any]): Entities present in the prompt.
        prompt_time (str): Time when the prompt was generated.
        prompt_gov_enabled (bool): Whether prompt governance is enabled.

    Raises:
        NameError: If no API key is set while the classifier location is
            "pebblo-cloud".
    """
    payload = self.build_prompt_qa_payload(
        app_name,
        retriever,
        question,
        answer,
        auth_context,
        docs,
        prompt_entities,
        prompt_time,
        prompt_gov_enabled,
    )

    uses_local_classifier = self.classifier_location == "local"
    local_reply = None
    if uses_local_classifier:
        # Send prompt to local classifier
        local_headers = self._make_headers()
        local_url = f"{self.classifier_url}{Routes.prompt.value}"
        local_reply = await self.amake_request(
            "POST", local_url, local_headers, payload
        )

    if not self.api_key:
        # Without an API key nothing goes to the cloud; that is only an error
        # when the cloud is the configured classifier.
        if self.classifier_location == "pebblo-cloud":
            logger.warning("API key is missing for sending prompt to Pebblo cloud.")
            raise NameError("API key is missing for sending prompt to Pebblo cloud.")
        return

    if uses_local_classifier:
        # Response, context and prompt are taken from the local classifier's
        # reply and replaced in the payload before it goes to the cloud.
        self.update_cloud_payload(payload, local_reply)

    cloud_headers = self._make_headers(cloud_request=True)
    cloud_url = f"{self.cloud_url}{Routes.prompt.value}"
    await self.amake_request("POST", cloud_url, cloud_headers, payload)
|
||||
|
||||
def check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
|
||||
"""
|
||||
Check the validity of the given prompt using a remote classification service.
|
||||
@ -227,13 +291,45 @@ class PebbloRetrievalAPIWrapper(BaseModel):
|
||||
"POST", prompt_gov_api_url, headers, prompt_payload
|
||||
)
|
||||
if pebblo_resp:
|
||||
logger.debug(f"pebblo_resp.json() {pebblo_resp.json()}")
|
||||
prompt_entities["entities"] = pebblo_resp.json().get("entities", {})
|
||||
prompt_entities["entityCount"] = pebblo_resp.json().get(
|
||||
"entityCount", 0
|
||||
)
|
||||
return is_valid_prompt, prompt_entities
|
||||
|
||||
async def acheck_prompt_validity(
    self, question: str
) -> Tuple[bool, Dict[str, Any]]:
    """
    Asynchronously collect the entities a classifier finds in a prompt.

    When the classifier location is "local", the prompt is POSTed to the
    classifier's prompt-governance route and the "entities" and
    "entityCount" fields of the reply are copied into the result. For any
    other classifier location no request is made and the empty defaults are
    returned.

    Args:
        question (str): The prompt question to be checked.

    Returns:
        Tuple[bool, Dict[str, Any]]: A validity flag and the entities found.
            The flag is always True in this implementation; the dict always
            has the keys "entities" and "entityCount".
    """
    prompt_entities: Dict[str, Any] = {"entities": {}, "entityCount": 0}
    is_valid_prompt: bool = True

    if self.classifier_location != "local":
        return is_valid_prompt, prompt_entities

    prompt_payload = {"prompt": question}
    headers = self._make_headers()
    prompt_gov_api_url = f"{self.classifier_url}{Routes.prompt_governance.value}"
    pebblo_resp = await self.amake_request(
        "POST", prompt_gov_api_url, headers, prompt_payload
    )

    # amake_request hands back whatever JSON the server sent (or None on
    # failure). Only a JSON object can carry the expected fields; anything
    # else is ignored so a malformed reply cannot abort the caller.
    if isinstance(pebblo_resp, dict) and pebblo_resp:
        prompt_entities["entities"] = pebblo_resp.get("entities", {})
        prompt_entities["entityCount"] = pebblo_resp.get("entityCount", 0)

    return is_valid_prompt, prompt_entities
|
||||
|
||||
def _make_headers(self, cloud_request: bool = False) -> dict:
|
||||
"""
|
||||
Generate headers for the request.
|
||||
@ -332,6 +428,56 @@ class PebbloRetrievalAPIWrapper(BaseModel):
|
||||
payload["prompt"] = {}
|
||||
payload["context"] = []
|
||||
|
||||
@staticmethod
async def amake_request(
    method: str,
    url: str,
    headers: dict,
    payload: Optional[dict] = None,
    timeout: int = 20,
) -> Any:
    """
    Make an async request to the Pebblo server/cloud API.

    Non-200 statuses are logged as warnings, but the response body is still
    decoded as JSON and returned. Connection problems, timeouts and any
    other exception are logged and reported as ``None`` instead of being
    raised.

    Args:
        method (str): HTTP method (GET, POST, PUT, DELETE, etc.).
        url (str): URL for the request.
        headers (dict): Headers for the request.
        payload (Optional[dict]): Payload for the request (for POST, PUT, etc.).
        timeout (int): Timeout for the request in seconds.

    Returns:
        Any: Decoded JSON body of the response, or None if the request or
            the decoding failed.
    """
    # Local import: only needed to recognise timeout errors below.
    import asyncio

    try:
        client_timeout = ClientTimeout(total=timeout)
        async with aiohttp.ClientSession() as asession:
            async with asession.request(
                method=method,
                url=url,
                json=payload,
                headers=headers,
                timeout=client_timeout,
            ) as response:
                if response.status >= HTTPStatus.INTERNAL_SERVER_ERROR:
                    logger.warning(f"Pebblo Server: Error {response.status}")
                elif response.status >= HTTPStatus.BAD_REQUEST:
                    # On aiohttp responses ``text`` is a coroutine method, not
                    # an attribute; it must be awaited to log the actual body.
                    error_body = await response.text()
                    logger.warning(
                        f"Pebblo received an invalid payload: {error_body}"
                    )
                elif response.status != HTTPStatus.OK:
                    logger.warning(
                        f"Pebblo returned an unexpected response code: "
                        f"{response.status}"
                    )
                return await response.json()
    except (aiohttp.ClientConnectionError, asyncio.TimeoutError):
        # aiohttp reports connection failures and timeouts with these types;
        # the ``requests`` exception hierarchy is never raised here.
        logger.warning("Unable to reach server %s", url)
    except Exception as e:
        logger.warning("An Exception caught in amake_request: %s", e)
    return None
|
||||
|
||||
def build_prompt_qa_payload(
|
||||
self,
|
||||
app_name: str,
|
||||
|
Loading…
Reference in New Issue
Block a user