mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-24 15:43:54 +00:00
community: [PebbloRetrievalQA] Implemented Async support for prompt APIs (#25748)
- **Description:** PebbloRetrievalQA: implemented async support for the prompt APIs (classification and governance).
- **Issue:** NA
- **Dependencies:** NA
This commit is contained in:
parent
6703d795c5
commit
58a98c7d8a
@ -149,6 +149,7 @@ class PebbloRetrievalQA(Chain):
|
|||||||
res = indexqa({'query': 'This is my query'})
|
res = indexqa({'query': 'This is my query'})
|
||||||
answer, docs = res['result'], res['source_documents']
|
answer, docs = res['result'], res['source_documents']
|
||||||
"""
|
"""
|
||||||
|
prompt_time = datetime.datetime.now().isoformat()
|
||||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||||
question = inputs[self.input_key]
|
question = inputs[self.input_key]
|
||||||
auth_context = inputs.get(self.auth_context_key)
|
auth_context = inputs.get(self.auth_context_key)
|
||||||
@ -157,7 +158,7 @@ class PebbloRetrievalQA(Chain):
|
|||||||
"run_manager" in inspect.signature(self._aget_docs).parameters
|
"run_manager" in inspect.signature(self._aget_docs).parameters
|
||||||
)
|
)
|
||||||
|
|
||||||
_, prompt_entities = self.pb_client.check_prompt_validity(question)
|
_, prompt_entities = await self.pb_client.acheck_prompt_validity(question)
|
||||||
|
|
||||||
if accepts_run_manager:
|
if accepts_run_manager:
|
||||||
docs = await self._aget_docs(
|
docs = await self._aget_docs(
|
||||||
@ -169,6 +170,18 @@ class PebbloRetrievalQA(Chain):
|
|||||||
input_documents=docs, question=question, callbacks=_run_manager.get_child()
|
input_documents=docs, question=question, callbacks=_run_manager.get_child()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await self.pb_client.asend_prompt(
|
||||||
|
self.app_name,
|
||||||
|
self.retriever,
|
||||||
|
question,
|
||||||
|
answer,
|
||||||
|
auth_context,
|
||||||
|
docs,
|
||||||
|
prompt_entities,
|
||||||
|
prompt_time,
|
||||||
|
self.enable_prompt_gov,
|
||||||
|
)
|
||||||
|
|
||||||
if self.return_source_documents:
|
if self.return_source_documents:
|
||||||
return {self.output_key: answer, "source_documents": docs}
|
return {self.output_key: answer, "source_documents": docs}
|
||||||
else:
|
else:
|
||||||
|
@ -6,6 +6,8 @@ from enum import Enum
|
|||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import ClientTimeout
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.env import get_runtime_environment
|
from langchain_core.env import get_runtime_environment
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
@ -202,6 +204,68 @@ class PebbloRetrievalAPIWrapper(BaseModel):
|
|||||||
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
|
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
|
||||||
raise NameError("API key is missing for sending prompt to Pebblo cloud.")
|
raise NameError("API key is missing for sending prompt to Pebblo cloud.")
|
||||||
|
|
||||||
|
async def asend_prompt(
    self,
    app_name: str,
    retriever: VectorStoreRetriever,
    question: str,
    answer: str,
    auth_context: Optional[AuthContext],
    docs: List[Document],
    prompt_entities: Dict[str, Any],
    prompt_time: str,
    prompt_gov_enabled: bool = False,
) -> None:
    """
    Asynchronously report a prompt/answer exchange to Pebblo.

    The payload is built once via ``build_prompt_qa_payload``. When the
    classifier location is ``"local"`` it is POSTed to the local classifier
    first. When an API key is configured it is then POSTed to the cloud URL;
    in the local case the payload is first updated from the local
    classifier's reply via ``update_cloud_payload``.

    Args:
        app_name (str): Name of the app.
        retriever (VectorStoreRetriever): Retriever instance.
        question (str): Question asked in the prompt.
        answer (str): Answer generated by the model.
        auth_context (Optional[AuthContext]): Authentication context.
        docs (List[Document]): List of documents retrieved.
        prompt_entities (Dict[str, Any]): Entities present in the prompt.
        prompt_time (str): Time when the prompt was generated.
        prompt_gov_enabled (bool): Whether prompt governance is enabled.

    Raises:
        NameError: If no API key is set while the classifier location is
            ``"pebblo-cloud"``.
    """
    request_body = self.build_prompt_qa_payload(
        app_name,
        retriever,
        question,
        answer,
        auth_context,
        docs,
        prompt_entities,
        prompt_time,
        prompt_gov_enabled,
    )

    uses_local_classifier = self.classifier_location == "local"
    local_reply = None
    if uses_local_classifier:
        # Local classification happens first; its reply may feed the cloud call.
        local_headers = self._make_headers()
        local_url = f"{self.classifier_url}{Routes.prompt.value}"
        local_reply = await self.amake_request(
            "POST", local_url, local_headers, request_body
        )

    if not self.api_key:
        # Without a key nothing can be sent to the cloud; that is only an
        # error when the cloud is the configured classifier.
        if self.classifier_location == "pebblo-cloud":
            logger.warning("API key is missing for sending prompt to Pebblo cloud.")
            raise NameError("API key is missing for sending prompt to Pebblo cloud.")
        return

    if uses_local_classifier:
        # Response, context and prompt are taken from the local classifier's
        # reply and replaced in the payload before it goes to the cloud.
        self.update_cloud_payload(request_body, local_reply)

    cloud_headers = self._make_headers(cloud_request=True)
    cloud_url = f"{self.cloud_url}{Routes.prompt.value}"
    await self.amake_request("POST", cloud_url, cloud_headers, request_body)
|
||||||
|
|
||||||
def check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
|
def check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Check the validity of the given prompt using a remote classification service.
|
Check the validity of the given prompt using a remote classification service.
|
||||||
@ -227,13 +291,45 @@ class PebbloRetrievalAPIWrapper(BaseModel):
|
|||||||
"POST", prompt_gov_api_url, headers, prompt_payload
|
"POST", prompt_gov_api_url, headers, prompt_payload
|
||||||
)
|
)
|
||||||
if pebblo_resp:
|
if pebblo_resp:
|
||||||
logger.debug(f"pebblo_resp.json() {pebblo_resp.json()}")
|
|
||||||
prompt_entities["entities"] = pebblo_resp.json().get("entities", {})
|
prompt_entities["entities"] = pebblo_resp.json().get("entities", {})
|
||||||
prompt_entities["entityCount"] = pebblo_resp.json().get(
|
prompt_entities["entityCount"] = pebblo_resp.json().get(
|
||||||
"entityCount", 0
|
"entityCount", 0
|
||||||
)
|
)
|
||||||
return is_valid_prompt, prompt_entities
|
return is_valid_prompt, prompt_entities
|
||||||
|
|
||||||
|
async def acheck_prompt_validity(
    self, question: str
) -> Tuple[bool, Dict[str, Any]]:
    """
    Asynchronously collect the entities found in a prompt.

    When the classifier location is ``"local"`` the prompt is POSTed to the
    prompt-governance route of the classifier and the ``entities`` and
    ``entityCount`` fields of its reply are returned. For any other
    location, or when the request yields no reply, empty defaults are
    returned.

    Args:
        question (str): The prompt question to be validated.

    Returns:
        bool: Validity flag; this implementation always returns True.
        dict: ``{"entities": ..., "entityCount": ...}`` for the prompt.
    """
    found: dict = {"entities": {}, "entityCount": 0}
    verdict: bool = True

    # Only the local classifier is queried; otherwise return the defaults.
    if self.classifier_location != "local":
        return verdict, found

    request_headers = self._make_headers()
    governance_url = f"{self.classifier_url}{Routes.prompt_governance.value}"
    reply = await self.amake_request(
        "POST", governance_url, request_headers, {"prompt": question}
    )
    if reply:
        found["entities"] = reply.get("entities", {})
        found["entityCount"] = reply.get("entityCount", 0)
    return verdict, found
|
||||||
|
|
||||||
def _make_headers(self, cloud_request: bool = False) -> dict:
|
def _make_headers(self, cloud_request: bool = False) -> dict:
|
||||||
"""
|
"""
|
||||||
Generate headers for the request.
|
Generate headers for the request.
|
||||||
@ -332,6 +428,56 @@ class PebbloRetrievalAPIWrapper(BaseModel):
|
|||||||
payload["prompt"] = {}
|
payload["prompt"] = {}
|
||||||
payload["context"] = []
|
payload["context"] = []
|
||||||
|
|
||||||
|
@staticmethod
async def amake_request(
    method: str,
    url: str,
    headers: dict,
    payload: Optional[dict] = None,
    timeout: int = 20,
) -> Any:
    """
    Make an async request to the Pebblo server/cloud API.

    Non-200 statuses are logged as warnings, but the body is still parsed
    as JSON and returned. Any failure (connection error, timeout, non-JSON
    body, ...) is logged and results in ``None``; nothing is raised.

    Args:
        method (str): HTTP method (GET, POST, PUT, DELETE, etc.).
        url (str): URL for the request.
        headers (dict): Headers for the request.
        payload (Optional[dict]): Payload for the request (for POST, PUT, etc.).
        timeout (int): Timeout for the request in seconds.

    Returns:
        Any: Parsed response JSON, or None if the request failed.
    """
    # Function-scope import keeps this block self-contained; only the
    # timeout exception type is needed from asyncio.
    import asyncio

    try:
        client_timeout = ClientTimeout(total=timeout)
        async with aiohttp.ClientSession() as asession:
            async with asession.request(
                method=method,
                url=url,
                json=payload,
                headers=headers,
                timeout=client_timeout,
            ) as response:
                if response.status >= HTTPStatus.INTERNAL_SERVER_ERROR:
                    logger.warning(f"Pebblo Server: Error {response.status}")
                elif response.status >= HTTPStatus.BAD_REQUEST:
                    # aiohttp's ``text`` is a coroutine method; it must be
                    # awaited, otherwise the bound method's repr is logged
                    # instead of the response body.
                    logger.warning(
                        "Pebblo received an invalid payload: %s",
                        await response.text(),
                    )
                elif response.status != HTTPStatus.OK:
                    logger.warning(
                        f"Pebblo returned an unexpected response code: "
                        f"{response.status}"
                    )
                return await response.json()
    except (aiohttp.ClientConnectionError, asyncio.TimeoutError):
        # aiohttp never raises requests' RequestException; connection
        # failures and timeouts surface as these types instead.
        logger.warning("Unable to reach server %s", url)
    except Exception as e:
        logger.warning("An Exception caught in amake_request: %s", e)
    return None
|
||||||
|
|
||||||
def build_prompt_qa_payload(
|
def build_prompt_qa_payload(
|
||||||
self,
|
self,
|
||||||
app_name: str,
|
app_name: str,
|
||||||
|
Loading…
Reference in New Issue
Block a user