mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 23:00:00 +00:00
[Community][minor]: Added prompt governance in pebblo_retrieval (#24874)
Title: [pebblo_retrieval] Identifying entities in prompts given in PebbloRetrievalQA leading to prompt governance Description: Implemented identification of entities in the prompt using Pebblo prompt governance API. Issue: NA Dependencies: NA Add tests and docs: NA
This commit is contained in:
parent
a6add89bd4
commit
b00c0fc558
@ -8,7 +8,7 @@ import inspect
|
||||
import json
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import requests # type: ignore
|
||||
from langchain.chains.base import Chain
|
||||
@ -37,6 +37,7 @@ from langchain_community.chains.pebblo_retrieval.utilities import (
|
||||
CLASSIFIER_URL,
|
||||
PEBBLO_CLOUD_URL,
|
||||
PLUGIN_VERSION,
|
||||
PROMPT_GOV_URL,
|
||||
PROMPT_URL,
|
||||
get_runtime,
|
||||
)
|
||||
@ -79,6 +80,8 @@ class PebbloRetrievalQA(Chain):
|
||||
"""Flag to check if discover payload has been sent."""
|
||||
_prompt_sent: bool = False #: :meta private:
|
||||
"""Flag to check if prompt payload has been sent."""
|
||||
enable_prompt_gov: bool = True #: :meta private:
|
||||
"""Flag to check if prompt governance is enabled or not"""
|
||||
|
||||
def _call(
|
||||
self,
|
||||
@ -102,6 +105,8 @@ class PebbloRetrievalQA(Chain):
|
||||
question = inputs[self.input_key]
|
||||
auth_context = inputs.get(self.auth_context_key, {})
|
||||
semantic_context = inputs.get(self.semantic_context_key, {})
|
||||
_, prompt_entities = self._check_prompt_validity(question)
|
||||
|
||||
accepts_run_manager = (
|
||||
"run_manager" in inspect.signature(self._get_docs).parameters
|
||||
)
|
||||
@ -133,7 +138,12 @@ class PebbloRetrievalQA(Chain):
|
||||
for doc in docs
|
||||
if isinstance(doc, Document)
|
||||
],
|
||||
"prompt": {"data": question},
|
||||
"prompt": {
|
||||
"data": question,
|
||||
"entities": prompt_entities.get("entities", {}),
|
||||
"entityCount": prompt_entities.get("entityCount", 0),
|
||||
"prompt_gov_enabled": self.enable_prompt_gov,
|
||||
},
|
||||
"response": {
|
||||
"data": answer,
|
||||
},
|
||||
@ -144,6 +154,7 @@ class PebbloRetrievalQA(Chain):
|
||||
else [],
|
||||
"classifier_location": self.classifier_location,
|
||||
}
|
||||
|
||||
qa_payload = Qa(**qa)
|
||||
self._send_prompt(qa_payload)
|
||||
|
||||
@ -175,6 +186,9 @@ class PebbloRetrievalQA(Chain):
|
||||
accepts_run_manager = (
|
||||
"run_manager" in inspect.signature(self._aget_docs).parameters
|
||||
)
|
||||
|
||||
_, prompt_entities = self._check_prompt_validity(question)
|
||||
|
||||
if accepts_run_manager:
|
||||
docs = await self._aget_docs(
|
||||
question, auth_context, semantic_context, run_manager=_run_manager
|
||||
@ -513,6 +527,66 @@ class PebbloRetrievalQA(Chain):
|
||||
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
|
||||
raise NameError("API key is missing for sending prompt to Pebblo cloud.")
|
||||
|
||||
def _check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
|
||||
"""
|
||||
Check the validity of the given prompt using a remote classification service.
|
||||
|
||||
This method sends a prompt to a remote classifier service and return entities
|
||||
present in prompt or not.
|
||||
|
||||
Args:
|
||||
question (str): The prompt question to be validated.
|
||||
|
||||
Returns:
|
||||
bool: True if the prompt is valid (does not contain deny list entities),
|
||||
False otherwise.
|
||||
dict: The entities present in the prompt
|
||||
"""
|
||||
|
||||
headers = {
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
prompt_payload = {"prompt": question}
|
||||
is_valid_prompt: bool = True
|
||||
prompt_gov_api_url = f"{self.classifier_url}{PROMPT_GOV_URL}"
|
||||
pebblo_resp = None
|
||||
prompt_entities: dict = {"entities": {}, "entityCount": 0}
|
||||
if self.classifier_location == "local":
|
||||
try:
|
||||
pebblo_resp = requests.post(
|
||||
prompt_gov_api_url,
|
||||
headers=headers,
|
||||
json=prompt_payload,
|
||||
timeout=20,
|
||||
)
|
||||
|
||||
logger.debug("prompt-payload: %s", prompt_payload)
|
||||
logger.debug(
|
||||
"send_prompt[local]: request url %s, body %s len %s\
|
||||
response status %s body %s",
|
||||
pebblo_resp.request.url,
|
||||
str(pebblo_resp.request.body),
|
||||
str(
|
||||
len(
|
||||
pebblo_resp.request.body if pebblo_resp.request.body else []
|
||||
)
|
||||
),
|
||||
str(pebblo_resp.status_code),
|
||||
pebblo_resp.json(),
|
||||
)
|
||||
logger.debug(f"pebblo_resp.json() {pebblo_resp.json()}")
|
||||
prompt_entities["entities"] = pebblo_resp.json().get("entities", {})
|
||||
prompt_entities["entityCount"] = pebblo_resp.json().get(
|
||||
"entityCount", 0
|
||||
)
|
||||
|
||||
except requests.exceptions.RequestException:
|
||||
logger.warning("Unable to reach pebblo server.")
|
||||
except Exception as e:
|
||||
logger.warning("An Exception caught in _send_discover: local %s", e)
|
||||
return is_valid_prompt, prompt_entities
|
||||
|
||||
@classmethod
|
||||
def get_chain_details(cls, llm: BaseLanguageModel, **kwargs): # type: ignore
|
||||
llm_dict = llm.__dict__
|
||||
|
@ -133,7 +133,10 @@ class Context(BaseModel):
|
||||
|
||||
|
||||
class Prompt(BaseModel):
|
||||
data: str
|
||||
data: Optional[Union[list, str]]
|
||||
entityCount: Optional[int]
|
||||
entities: Optional[dict]
|
||||
prompt_gov_enabled: Optional[bool]
|
||||
|
||||
|
||||
class Qa(BaseModel):
|
||||
|
@ -15,6 +15,7 @@ CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000")
|
||||
PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")
|
||||
|
||||
PROMPT_URL = "/v1/prompt"
|
||||
PROMPT_GOV_URL = "/v1/prompt/governance"
|
||||
APP_DISCOVER_URL = "/v1/app/discover"
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user