community: Refactor PebbloRetrievalQA (#25583)

**Refactor PebbloRetrievalQA**
  - Created `APIWrapper` and moved API logic into it.
  - Created smaller functions/methods for better readability.
  - Properly read environment variables.
  - Removed unused code.
  - Updated models.

**Issue:** NA
**Dependencies:** NA
**Tests:** NA
This commit is contained in:
Rajendra Kadam 2024-08-22 21:21:21 +05:30 committed by GitHub
parent 1f1679e960
commit 4ff2f4499e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 391 additions and 317 deletions

View File

@ -5,12 +5,9 @@ against a vector database.
import datetime import datetime
import inspect import inspect
import json
import logging import logging
from http import HTTPStatus from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
import requests # type: ignore
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain_core.callbacks import ( from langchain_core.callbacks import (
@ -29,16 +26,14 @@ from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
from langchain_community.chains.pebblo_retrieval.models import ( from langchain_community.chains.pebblo_retrieval.models import (
App, App,
AuthContext, AuthContext,
Qa, ChainInfo,
Model,
SemanticContext, SemanticContext,
VectorDB,
) )
from langchain_community.chains.pebblo_retrieval.utilities import ( from langchain_community.chains.pebblo_retrieval.utilities import (
APP_DISCOVER_URL,
CLASSIFIER_URL,
PEBBLO_CLOUD_URL,
PLUGIN_VERSION, PLUGIN_VERSION,
PROMPT_GOV_URL, PebbloRetrievalAPIWrapper,
PROMPT_URL,
get_runtime, get_runtime,
) )
@ -72,16 +67,18 @@ class PebbloRetrievalQA(Chain):
"""Description of app.""" """Description of app."""
api_key: Optional[str] = None #: :meta private: api_key: Optional[str] = None #: :meta private:
"""Pebblo cloud API key for app.""" """Pebblo cloud API key for app."""
classifier_url: str = CLASSIFIER_URL #: :meta private: classifier_url: Optional[str] = None #: :meta private:
"""Classifier endpoint.""" """Classifier endpoint."""
classifier_location: str = "local" #: :meta private: classifier_location: str = "local" #: :meta private:
"""Classifier location. It could be either of 'local' or 'pebblo-cloud'.""" """Classifier location. It could be either of 'local' or 'pebblo-cloud'."""
_discover_sent: bool = False #: :meta private: _discover_sent: bool = False #: :meta private:
"""Flag to check if discover payload has been sent.""" """Flag to check if discover payload has been sent."""
_prompt_sent: bool = False #: :meta private:
"""Flag to check if prompt payload has been sent."""
enable_prompt_gov: bool = True #: :meta private: enable_prompt_gov: bool = True #: :meta private:
"""Flag to check if prompt governance is enabled or not""" """Flag to check if prompt governance is enabled or not"""
pb_client: PebbloRetrievalAPIWrapper = Field(
default_factory=PebbloRetrievalAPIWrapper
)
"""Pebblo Retrieval API client"""
def _call( def _call(
self, self,
@ -100,12 +97,11 @@ class PebbloRetrievalQA(Chain):
answer, docs = res['result'], res['source_documents'] answer, docs = res['result'], res['source_documents']
""" """
prompt_time = datetime.datetime.now().isoformat() prompt_time = datetime.datetime.now().isoformat()
PebbloRetrievalQA.set_prompt_sent(value=False)
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key] question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key, {}) auth_context = inputs.get(self.auth_context_key)
semantic_context = inputs.get(self.semantic_context_key, {}) semantic_context = inputs.get(self.semantic_context_key)
_, prompt_entities = self._check_prompt_validity(question) _, prompt_entities = self.pb_client.check_prompt_validity(question)
accepts_run_manager = ( accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters "run_manager" in inspect.signature(self._get_docs).parameters
@ -120,43 +116,17 @@ class PebbloRetrievalQA(Chain):
input_documents=docs, question=question, callbacks=_run_manager.get_child() input_documents=docs, question=question, callbacks=_run_manager.get_child()
) )
qa = { self.pb_client.send_prompt(
"name": self.app_name, self.app_name,
"context": [ self.retriever,
{ question,
"retrieved_from": doc.metadata.get( answer,
"full_path", doc.metadata.get("source") auth_context,
), docs,
"doc": doc.page_content, prompt_entities,
"vector_db": self.retriever.vectorstore.__class__.__name__, prompt_time,
**( self.enable_prompt_gov,
{"pb_checksum": doc.metadata.get("pb_checksum")} )
if doc.metadata.get("pb_checksum")
else {}
),
}
for doc in docs
if isinstance(doc, Document)
],
"prompt": {
"data": question,
"entities": prompt_entities.get("entities", {}),
"entityCount": prompt_entities.get("entityCount", 0),
"prompt_gov_enabled": self.enable_prompt_gov,
},
"response": {
"data": answer,
},
"prompt_time": prompt_time,
"user": auth_context.user_id if auth_context else "unknown",
"user_identities": auth_context.user_auth
if auth_context and hasattr(auth_context, "user_auth")
else [],
"classifier_location": self.classifier_location,
}
qa_payload = Qa(**qa)
self._send_prompt(qa_payload)
if self.return_source_documents: if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs} return {self.output_key: answer, "source_documents": docs}
@ -187,7 +157,7 @@ class PebbloRetrievalQA(Chain):
"run_manager" in inspect.signature(self._aget_docs).parameters "run_manager" in inspect.signature(self._aget_docs).parameters
) )
_, prompt_entities = self._check_prompt_validity(question) _, prompt_entities = self.pb_client.check_prompt_validity(question)
if accepts_run_manager: if accepts_run_manager:
docs = await self._aget_docs( docs = await self._aget_docs(
@ -243,7 +213,7 @@ class PebbloRetrievalQA(Chain):
chain_type: str = "stuff", chain_type: str = "stuff",
chain_type_kwargs: Optional[dict] = None, chain_type_kwargs: Optional[dict] = None,
api_key: Optional[str] = None, api_key: Optional[str] = None,
classifier_url: str = CLASSIFIER_URL, classifier_url: Optional[str] = None,
classifier_location: str = "local", classifier_location: str = "local",
**kwargs: Any, **kwargs: Any,
) -> "PebbloRetrievalQA": ) -> "PebbloRetrievalQA":
@ -263,14 +233,14 @@ class PebbloRetrievalQA(Chain):
llm=llm, llm=llm,
**kwargs, **kwargs,
) )
# initialize Pebblo API client
PebbloRetrievalQA._send_discover( pb_client = PebbloRetrievalAPIWrapper(
app,
api_key=api_key, api_key=api_key,
classifier_url=classifier_url,
classifier_location=classifier_location, classifier_location=classifier_location,
classifier_url=classifier_url,
) )
# send app discovery request
pb_client.send_app_discover(app)
return cls( return cls(
combine_documents_chain=combine_documents_chain, combine_documents_chain=combine_documents_chain,
app_name=app_name, app_name=app_name,
@ -279,6 +249,7 @@ class PebbloRetrievalQA(Chain):
api_key=api_key, api_key=api_key,
classifier_url=classifier_url, classifier_url=classifier_url,
classifier_location=classifier_location, classifier_location=classifier_location,
pb_client=pb_client,
**kwargs, **kwargs,
) )
@ -346,259 +317,36 @@ class PebbloRetrievalQA(Chain):
) )
return app return app
@staticmethod
def _send_discover(
app: App,
api_key: Optional[str],
classifier_url: str,
classifier_location: str,
) -> None: # type: ignore
"""Send app discovery payload to pebblo-server. Internal method."""
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
payload = app.dict(exclude_unset=True)
if classifier_location == "local":
app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=payload, timeout=20
)
logger.debug("discover-payload: %s", payload)
logger.debug(
"send_discover[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_discover_sent()
else:
logger.warning(
"Received unexpected HTTP response code:"
+ f"{pebblo_resp.status_code}"
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
if api_key:
try:
headers.update({"x-api-key": api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{APP_DISCOVER_URL}"
pebblo_cloud_response = requests.post(
pebblo_cloud_url, headers=headers, json=payload, timeout=20
)
logger.debug(
"send_discover[cloud]: request url %s, body %s len %s\
response status %s body %s",
pebblo_cloud_response.request.url,
str(pebblo_cloud_response.request.body),
str(
len(
pebblo_cloud_response.request.body
if pebblo_cloud_response.request.body
else []
)
),
str(pebblo_cloud_response.status_code),
pebblo_cloud_response.json(),
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach Pebblo cloud server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: cloud %s", e)
@classmethod @classmethod
def set_discover_sent(cls) -> None: def set_discover_sent(cls) -> None:
cls._discover_sent = True cls._discover_sent = True
@classmethod @classmethod
def set_prompt_sent(cls, value: bool = True) -> None: def get_chain_details(
cls._prompt_sent = value cls, llm: BaseLanguageModel, **kwargs: Any
) -> List[ChainInfo]:
def _send_prompt(self, qa_payload: Qa) -> None:
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
app_discover_url = f"{self.classifier_url}{PROMPT_URL}"
pebblo_resp = None
payload = qa_payload.dict(exclude_unset=True)
if self.classifier_location == "local":
try:
pebblo_resp = requests.post(
app_discover_url,
headers=headers,
json=payload,
timeout=20,
)
logger.debug("prompt-payload: %s", payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_prompt_sent()
else:
logger.warning(
"Received unexpected HTTP response code:"
+ f"{pebblo_resp.status_code}"
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
# If classifier location is local, then response, context and prompt
# should be fetched from pebblo_resp and replaced in payload.
if self.api_key:
if self.classifier_location == "local":
if pebblo_resp:
resp = json.loads(pebblo_resp.text)
if resp:
payload["response"].update(
resp.get("retrieval_data", {}).get("response", {})
)
payload["response"].pop("data")
payload["prompt"].update(
resp.get("retrieval_data", {}).get("prompt", {})
)
payload["prompt"].pop("data")
context = payload["context"]
for context_data in context:
context_data.pop("doc")
payload["context"] = context
else:
payload["response"] = {}
payload["prompt"] = {}
payload["context"] = []
headers.update({"x-api-key": self.api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
try:
pebblo_cloud_response = requests.post(
pebblo_cloud_url,
headers=headers,
json=payload,
timeout=20,
)
logger.debug(
"send_prompt[cloud]: request url %s, body %s len %s\
response status %s body %s",
pebblo_cloud_response.request.url,
str(pebblo_cloud_response.request.body),
str(
len(
pebblo_cloud_response.request.body
if pebblo_cloud_response.request.body
else []
)
),
str(pebblo_cloud_response.status_code),
pebblo_cloud_response.json(),
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach Pebblo cloud server.")
except Exception as e:
logger.warning("An Exception caught in _send_prompt: cloud %s", e)
elif self.classifier_location == "pebblo-cloud":
logger.warning("API key is missing for sending prompt to Pebblo cloud.")
raise NameError("API key is missing for sending prompt to Pebblo cloud.")
def _check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
""" """
Check the validity of the given prompt using a remote classification service. Get chain details.
This method sends a prompt to a remote classifier service and return entities
present in prompt or not.
Args: Args:
question (str): The prompt question to be validated. llm (BaseLanguageModel): Language model instance.
**kwargs: Additional keyword arguments.
Returns: Returns:
bool: True if the prompt is valid (does not contain deny list entities), List[ChainInfo]: Chain details.
False otherwise.
dict: The entities present in the prompt
""" """
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
prompt_payload = {"prompt": question}
is_valid_prompt: bool = True
prompt_gov_api_url = f"{self.classifier_url}{PROMPT_GOV_URL}"
pebblo_resp = None
prompt_entities: dict = {"entities": {}, "entityCount": 0}
if self.classifier_location == "local":
try:
pebblo_resp = requests.post(
prompt_gov_api_url,
headers=headers,
json=prompt_payload,
timeout=20,
)
logger.debug("prompt-payload: %s", prompt_payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(
len(
pebblo_resp.request.body if pebblo_resp.request.body else []
)
),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
logger.debug(f"pebblo_resp.json() {pebblo_resp.json()}")
prompt_entities["entities"] = pebblo_resp.json().get("entities", {})
prompt_entities["entityCount"] = pebblo_resp.json().get(
"entityCount", 0
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
return is_valid_prompt, prompt_entities
@classmethod
def get_chain_details(cls, llm: BaseLanguageModel, **kwargs): # type: ignore
llm_dict = llm.__dict__ llm_dict = llm.__dict__
chain = [ chains = [
{ ChainInfo(
"name": cls.__name__, name=cls.__name__,
"model": { model=Model(
"name": llm_dict.get("model_name", llm_dict.get("model")), name=llm_dict.get("model_name", llm_dict.get("model")),
"vendor": llm.__class__.__name__, vendor=llm.__class__.__name__,
}, ),
"vector_dbs": [ vector_dbs=[
{ VectorDB(
"name": kwargs["retriever"].vectorstore.__class__.__name__, name=kwargs["retriever"].vectorstore.__class__.__name__,
"embedding_model": str( embedding_model=str(
kwargs["retriever"].vectorstore._embeddings.model kwargs["retriever"].vectorstore._embeddings.model
) )
if hasattr(kwargs["retriever"].vectorstore, "_embeddings") if hasattr(kwargs["retriever"].vectorstore, "_embeddings")
@ -607,8 +355,8 @@ class PebbloRetrievalQA(Chain):
if hasattr(kwargs["retriever"].vectorstore, "_embedding") if hasattr(kwargs["retriever"].vectorstore, "_embedding")
else None else None
), ),
} )
], ],
}, ),
] ]
return chain return chains

View File

@ -109,7 +109,7 @@ class VectorDB(BaseModel):
embedding_model: Optional[str] = None embedding_model: Optional[str] = None
class Chains(BaseModel): class ChainInfo(BaseModel):
name: str name: str
model: Optional[Model] model: Optional[Model]
vector_dbs: Optional[List[VectorDB]] vector_dbs: Optional[List[VectorDB]]
@ -121,7 +121,7 @@ class App(BaseModel):
description: Optional[str] description: Optional[str]
runtime: Runtime runtime: Runtime
framework: Framework framework: Framework
chains: List[Chains] chains: List[ChainInfo]
plugin_version: str plugin_version: str
@ -134,9 +134,9 @@ class Context(BaseModel):
class Prompt(BaseModel): class Prompt(BaseModel):
data: Optional[Union[list, str]] data: Optional[Union[list, str]]
entityCount: Optional[int] entityCount: Optional[int] = None
entities: Optional[dict] entities: Optional[dict] = None
prompt_gov_enabled: Optional[bool] prompt_gov_enabled: Optional[bool] = None
class Qa(BaseModel): class Qa(BaseModel):

View File

@ -1,22 +1,43 @@
import json
import logging import logging
import os import os
import platform import platform
from typing import Tuple from enum import Enum
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple
from langchain_core.documents import Document
from langchain_core.env import get_runtime_environment from langchain_core.env import get_runtime_environment
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils import get_from_dict_or_env
from langchain_core.vectorstores import VectorStoreRetriever
from requests import Response, request
from requests.exceptions import RequestException
from langchain_community.chains.pebblo_retrieval.models import Framework, Runtime from langchain_community.chains.pebblo_retrieval.models import (
App,
AuthContext,
Context,
Framework,
Prompt,
Qa,
Runtime,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
PLUGIN_VERSION = "0.1.1" PLUGIN_VERSION = "0.1.1"
CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000") _DEFAULT_CLASSIFIER_URL = "http://localhost:8000"
PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai") _DEFAULT_PEBBLO_CLOUD_URL = "https://api.daxa.ai"
PROMPT_URL = "/v1/prompt"
PROMPT_GOV_URL = "/v1/prompt/governance" class Routes(str, Enum):
APP_DISCOVER_URL = "/v1/app/discover" """Routes available for the Pebblo API as enumerator."""
retrieval_app_discover = "/v1/app/discover"
prompt = "/v1/prompt"
prompt_governance = "/v1/prompt/governance"
def get_runtime() -> Tuple[Framework, Runtime]: def get_runtime() -> Tuple[Framework, Runtime]:
@ -64,3 +85,308 @@ def get_ip() -> str:
except Exception: except Exception:
public_ip = socket.gethostbyname("localhost") public_ip = socket.gethostbyname("localhost")
return public_ip return public_ip
class PebbloRetrievalAPIWrapper(BaseModel):
    """Wrapper for Pebblo Retrieval API.

    Sends app-discovery, prompt-governance and prompt payloads to a local
    Pebblo classifier (when ``classifier_location == "local"``) and to Pebblo
    cloud (when an API key is configured).
    """

    api_key: Optional[str]  # Use SecretStr
    """API key for Pebblo Cloud"""
    classifier_location: str = "local"
    """Location of the classifier, local or cloud. Defaults to 'local'"""
    classifier_url: Optional[str]
    """URL of the Pebblo Classifier"""
    cloud_url: Optional[str]
    """URL of the Pebblo Cloud"""

    def __init__(self, **kwargs: Any):
        """Resolve the API key and service URLs before model validation.

        Each value is looked up via ``get_from_dict_or_env``: the keyword
        argument first, then the matching ``PEBBLO_*`` environment variable,
        then the default passed here.
        """
        kwargs["api_key"] = get_from_dict_or_env(
            kwargs, "api_key", "PEBBLO_API_KEY", ""
        )
        kwargs["classifier_url"] = get_from_dict_or_env(
            kwargs, "classifier_url", "PEBBLO_CLASSIFIER_URL", _DEFAULT_CLASSIFIER_URL
        )
        kwargs["cloud_url"] = get_from_dict_or_env(
            kwargs, "cloud_url", "PEBBLO_CLOUD_URL", _DEFAULT_PEBBLO_CLOUD_URL
        )
        super().__init__(**kwargs)

    def send_app_discover(self, app: App) -> None:
        """
        Send app discovery request to Pebblo server & cloud.

        Args:
            app (App): App instance to be discovered.
        """
        pebblo_resp = None
        payload = app.dict(exclude_unset=True)

        if self.classifier_location == "local":
            # Send app details to local classifier
            headers = self._make_headers()
            # `.value` is required: formatting a str-mixin Enum member directly
            # yields "Routes.<name>" on newer Python versions, not the path.
            app_discover_url = (
                f"{self.classifier_url}{Routes.retrieval_app_discover.value}"
            )
            pebblo_resp = self.make_request("POST", app_discover_url, headers, payload)

        if self.api_key:
            # Send app details to Pebblo cloud if api_key is present
            headers = self._make_headers(cloud_request=True)
            if pebblo_resp:
                # Best effort: an undecodable body must not abort discovery.
                resp_body = self._parse_json(pebblo_resp)
                if resp_body is not None:
                    payload.update(
                        {
                            "pebblo_server_version": resp_body.get(
                                "pebblo_server_version"
                            )
                        }
                    )
            payload.update({"pebblo_client_version": PLUGIN_VERSION})
            pebblo_cloud_url = (
                f"{self.cloud_url}{Routes.retrieval_app_discover.value}"
            )
            _ = self.make_request("POST", pebblo_cloud_url, headers, payload)

    def send_prompt(
        self,
        app_name: str,
        retriever: VectorStoreRetriever,
        question: str,
        answer: str,
        auth_context: Optional[AuthContext],
        docs: List[Document],
        prompt_entities: Dict[str, Any],
        prompt_time: str,
        prompt_gov_enabled: bool = False,
    ) -> None:
        """
        Send prompt to Pebblo server for classification.
        Then send prompt to Daxa cloud(If api_key is present).

        Args:
            app_name (str): Name of the app.
            retriever (VectorStoreRetriever): Retriever instance.
            question (str): Question asked in the prompt.
            answer (str): Answer generated by the model.
            auth_context (Optional[AuthContext]): Authentication context.
            docs (List[Document]): List of documents retrieved.
            prompt_entities (Dict[str, Any]): Entities present in the prompt.
            prompt_time (str): Time when the prompt was generated.
            prompt_gov_enabled (bool): Whether prompt governance is enabled.

        Raises:
            NameError: If ``classifier_location`` is "pebblo-cloud" but no API
                key is configured.
        """
        pebblo_resp = None
        payload = self.build_prompt_qa_payload(
            app_name,
            retriever,
            question,
            answer,
            auth_context,
            docs,
            prompt_entities,
            prompt_time,
            prompt_gov_enabled,
        )

        if self.classifier_location == "local":
            # Send prompt to local classifier
            headers = self._make_headers()
            prompt_url = f"{self.classifier_url}{Routes.prompt.value}"
            pebblo_resp = self.make_request("POST", prompt_url, headers, payload)

        if self.api_key:
            # Send prompt to Pebblo cloud if api_key is present
            if self.classifier_location == "local":
                # If classifier location is local, then response, context and prompt
                # should be fetched from pebblo_resp and replaced in payload.
                # When the local response is missing or undecodable, the raw
                # prompt/response/context are cleared instead of being sent.
                classified = self._parse_json(pebblo_resp) if pebblo_resp else None
                self.update_cloud_payload(payload, classified)

            headers = self._make_headers(cloud_request=True)
            pebblo_cloud_prompt_url = f"{self.cloud_url}{Routes.prompt.value}"
            _ = self.make_request("POST", pebblo_cloud_prompt_url, headers, payload)
        elif self.classifier_location == "pebblo-cloud":
            logger.warning("API key is missing for sending prompt to Pebblo cloud.")
            raise NameError("API key is missing for sending prompt to Pebblo cloud.")

    def check_prompt_validity(self, question: str) -> Tuple[bool, Dict[str, Any]]:
        """
        Check the validity of the given prompt using a remote classification service.

        This method sends a prompt to a remote classifier service and return entities
        present in prompt or not.

        Args:
            question (str): The prompt question to be validated.

        Returns:
            bool: Currently always True; the classifier response is only used
                to collect entities.
            dict: The entities present in the prompt
        """
        prompt_payload = {"prompt": question}
        prompt_entities: dict = {"entities": {}, "entityCount": 0}
        is_valid_prompt: bool = True

        if self.classifier_location == "local":
            headers = self._make_headers()
            prompt_gov_api_url = (
                f"{self.classifier_url}{Routes.prompt_governance.value}"
            )
            pebblo_resp = self.make_request(
                "POST", prompt_gov_api_url, headers, prompt_payload
            )
            if pebblo_resp:
                # Decode the body once; fall back to the empty defaults when
                # it cannot be decoded.
                resp_body = self._parse_json(pebblo_resp)
                if resp_body is not None:
                    logger.debug("pebblo_resp.json() %s", resp_body)
                    prompt_entities["entities"] = resp_body.get("entities", {})
                    prompt_entities["entityCount"] = resp_body.get("entityCount", 0)
        return is_valid_prompt, prompt_entities

    def _make_headers(self, cloud_request: bool = False) -> dict:
        """
        Generate headers for the request.

        args:
            cloud_request (bool): flag indicating whether the request is for Pebblo
            cloud.

        returns:
            dict: Headers for the request.
        """
        headers = {
            "Accept": "application/json",
            "Content-Type": "application/json",
        }
        if cloud_request:
            # Add API key for Pebblo cloud request
            if self.api_key:
                headers.update({"x-api-key": self.api_key})
            else:
                logger.warning("API key is missing for Pebblo cloud request.")
        return headers

    @staticmethod
    def _parse_json(response: Response) -> Optional[dict]:
        """Decode a response body as a JSON object, or return None.

        Args:
            response (Response): Response whose body should be decoded.

        Returns:
            Optional[dict]: The decoded object, or None when the body is not
                valid JSON or is not a JSON object.
        """
        try:
            body = response.json()
        except ValueError:
            # JSON decode errors are ValueError subclasses.
            logger.warning("Pebblo returned a response that is not valid JSON.")
            return None
        if not isinstance(body, dict):
            logger.warning("Pebblo returned a JSON body that is not an object.")
            return None
        return body

    @staticmethod
    def make_request(
        method: str,
        url: str,
        headers: dict,
        payload: Optional[dict] = None,
        timeout: int = 20,
    ) -> Optional[Response]:
        """
        Make a request to the Pebblo server/cloud API.

        Args:
            method (str): HTTP method (GET, POST, PUT, DELETE, etc.).
            url (str): URL for the request.
            headers (dict): Headers for the request.
            payload (Optional[dict]): Payload for the request (for POST, PUT, etc.).
            timeout (int): Timeout for the request in seconds.

        Returns:
            Optional[Response]: The response (whatever its status code), or
                None when the request raised an exception.
        """
        try:
            response = request(
                method=method, url=url, headers=headers, json=payload, timeout=timeout
            )
            logger.debug(
                "Request: method %s, url %s, len %s response status %s",
                method,
                response.request.url,
                str(len(response.request.body if response.request.body else [])),
                str(response.status_code),
            )

            if response.status_code >= HTTPStatus.INTERNAL_SERVER_ERROR:
                logger.warning("Pebblo Server: Error %s", response.status_code)
            elif response.status_code >= HTTPStatus.BAD_REQUEST:
                logger.warning(
                    "Pebblo received an invalid payload: %s", response.text
                )
            elif response.status_code != HTTPStatus.OK:
                logger.warning(
                    "Pebblo returned an unexpected response code: %s",
                    response.status_code,
                )

            return response
        except RequestException:
            logger.warning("Unable to reach server %s", url)
        except Exception as e:
            # Deliberate best effort: reporting must not break the caller.
            logger.warning("An Exception caught in make_request: %s", e)
        return None

    @staticmethod
    def update_cloud_payload(payload: dict, pebblo_resp: Optional[dict]) -> None:
        """
        Update the payload with response, prompt and context from Pebblo response.

        The payload is modified in place. With a Pebblo response, the raw
        ``data`` / ``doc`` fields are dropped after merging its
        ``retrieval_data``; without one, response, prompt and context are
        emptied.

        Args:
            payload (dict): Payload to be updated.
            pebblo_resp (Optional[dict]): Response from Pebblo server.
        """
        if pebblo_resp:
            # Update response, prompt and context from pebblo response
            retrieval_data = pebblo_resp.get("retrieval_data", {})

            response = payload.get("response", {})
            response.update(retrieval_data.get("response", {}))
            response.pop("data", None)

            prompt = payload.get("prompt", {})
            prompt.update(retrieval_data.get("prompt", {}))
            prompt.pop("data", None)

            context = payload.get("context", [])
            for context_data in context:
                context_data.pop("doc", None)
        else:
            payload["response"] = {}
            payload["prompt"] = {}
            payload["context"] = []

    def build_prompt_qa_payload(
        self,
        app_name: str,
        retriever: VectorStoreRetriever,
        question: str,
        answer: str,
        auth_context: Optional[AuthContext],
        docs: List[Document],
        prompt_entities: Dict[str, Any],
        prompt_time: str,
        prompt_gov_enabled: bool = False,
    ) -> dict:
        """
        Build the QA payload for the prompt.

        Args:
            app_name (str): Name of the app.
            retriever (VectorStoreRetriever): Retriever instance.
            question (str): Question asked in the prompt.
            answer (str): Answer generated by the model.
            auth_context (Optional[AuthContext]): Authentication context.
            docs (List[Document]): List of documents retrieved.
            prompt_entities (Dict[str, Any]): Entities present in the prompt.
            prompt_time (str): Time when the prompt was generated.
            prompt_gov_enabled (bool): Whether prompt governance is enabled.

        Returns:
            dict: The QA payload for the prompt.
        """
        qa = Qa(
            name=app_name,
            context=[
                Context(
                    retrieved_from=doc.metadata.get(
                        "full_path", doc.metadata.get("source")
                    ),
                    doc=doc.page_content,
                    vector_db=retriever.vectorstore.__class__.__name__,
                    pb_checksum=doc.metadata.get("pb_checksum"),
                )
                for doc in docs
                # Entries that are not Document instances are skipped.
                if isinstance(doc, Document)
            ],
            prompt=Prompt(
                data=question,
                entities=prompt_entities.get("entities", {}),
                entityCount=prompt_entities.get("entityCount", 0),
                prompt_gov_enabled=prompt_gov_enabled,
            ),
            response=Prompt(data=answer),
            prompt_time=prompt_time,
            user=auth_context.user_id if auth_context else "unknown",
            user_identities=auth_context.user_auth
            if auth_context and hasattr(auth_context, "user_auth")
            else [],
            classifier_location=self.classifier_location,
        )
        return qa.dict(exclude_unset=True)