community[minor]: Enable retrieval api calls in PebbloRetrievalQA (#21958)

Description: Enable app discovery and Prompt/Response APIs in
PebbloSafeRetrieval
Documentation: NA
Unit test: N/A

---------

Signed-off-by: Rahul Tripathi <rauhl.psit.ec@gmail.com>
Co-authored-by: Rahul Tripathi <rauhl.psit.ec@gmail.com>
This commit is contained in:
Rahul Triptahi 2024-06-04 22:48:50 +05:30 committed by GitHub
parent 8fd231086e
commit 77ad857934
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 433 additions and 11 deletions

View File

@ -3,9 +3,13 @@ Pebblo Retrieval Chain with Identity & Semantic Enforcement for question-answeri
against a vector database.
"""
import datetime
import inspect
import logging
from http import HTTPStatus
from typing import Any, Dict, List, Optional
import requests # type: ignore
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain_core.callbacks import (
@ -22,9 +26,21 @@ from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
set_enforcement_filters,
)
from langchain_community.chains.pebblo_retrieval.models import (
App,
AuthContext,
Qa,
SemanticContext,
)
from langchain_community.chains.pebblo_retrieval.utilities import (
APP_DISCOVER_URL,
CLASSIFIER_URL,
PEBBLO_CLOUD_URL,
PLUGIN_VERSION,
PROMPT_URL,
get_runtime,
)
logger = logging.getLogger(__name__)
class PebbloRetrievalQA(Chain):
@ -46,6 +62,20 @@ class PebbloRetrievalQA(Chain):
"""Authentication context for identity enforcement."""
semantic_context_key: str = "semantic_context" #: :meta private:
"""Semantic context for semantic enforcement."""
app_name: str #: :meta private:
"""App name."""
owner: str #: :meta private:
"""Owner of app."""
description: str #: :meta private:
"""Description of app."""
api_key: Optional[str] = None #: :meta private:
"""Pebblo cloud API key for app."""
classifier_url: str = CLASSIFIER_URL #: :meta private:
"""Classifier endpoint."""
_discover_sent = False #: :meta private:
"""Flag to check if discover payload has been sent."""
_prompt_sent = False #: :meta private:
"""Flag to check if prompt payload has been sent."""
def _call(
self,
@ -63,10 +93,11 @@ class PebbloRetrievalQA(Chain):
res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
prompt_time = datetime.datetime.now().isoformat()
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key)
semantic_context = inputs.get(self.semantic_context_key)
auth_context = inputs.get(self.auth_context_key, {})
semantic_context = inputs.get(self.semantic_context_key, {})
accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
)
@ -80,6 +111,32 @@ class PebbloRetrievalQA(Chain):
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
qa = {
"name": self.app_name,
"context": [
{
"retrieved_from": doc.metadata.get("source"),
"doc": doc.page_content,
"vector_db": self.retriever.vectorstore.__class__.__name__,
}
for doc in docs
if isinstance(doc, Document)
],
"prompt": {"data": question},
"response": {
"data": answer,
},
"prompt_time": prompt_time,
"user": auth_context.user_id if auth_context else "unknown",
"user_identities": auth_context.user_auth
if "user_auth" in dict(auth_context)
else []
if auth_context
else [],
}
qa_payload = Qa(**qa)
self._send_prompt(qa_payload)
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
@ -158,8 +215,13 @@ class PebbloRetrievalQA(Chain):
def from_chain_type(
cls,
llm: BaseLanguageModel,
app_name: str,
description: str,
owner: str,
chain_type: str = "stuff",
chain_type_kwargs: Optional[dict] = None,
api_key: Optional[str] = None,
classifier_url: str = CLASSIFIER_URL,
**kwargs: Any,
) -> "PebbloRetrievalQA":
"""Load chain from chain type."""
@ -169,7 +231,29 @@ class PebbloRetrievalQA(Chain):
combine_documents_chain = load_qa_chain(
llm, chain_type=chain_type, **_chain_type_kwargs
)
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
# generate app
app = PebbloRetrievalQA._get_app_details(
app_name=app_name,
description=description,
owner=owner,
llm=llm,
**kwargs,
)
PebbloRetrievalQA._send_discover(
app, api_key=api_key, classifier_url=classifier_url
)
return cls(
combine_documents_chain=combine_documents_chain,
app_name=app_name,
owner=owner,
description=description,
api_key=api_key,
classifier_url=classifier_url,
**kwargs,
)
@validator("retriever", pre=True, always=True)
def validate_vectorstore(
@ -216,3 +300,182 @@ class PebbloRetrievalQA(Chain):
return await self.retriever.aget_relevant_documents(
question, callbacks=run_manager.get_child()
)
@staticmethod
def _get_app_details(app_name, owner, description, llm, **kwargs) -> App:  # type: ignore
    """Assemble the ``App`` record describing this chain. Internal method.

    Args:
        app_name: Name of the app.
        owner: Owner of the app.
        description: Description of the app.
        llm: Language model; forwarded to ``get_chain_details``.
        **kwargs: Forwarded unchanged to ``get_chain_details``.

    Returns:
        App: App details, including runtime, framework, chain details and
        the plugin version.
    """
    fw_info, rt_info = get_runtime()
    return App(
        name=app_name,
        owner=owner,
        description=description,
        runtime=rt_info,
        framework=fw_info,
        chains=PebbloRetrievalQA.get_chain_details(llm, **kwargs),
        plugin_version=PLUGIN_VERSION,
    )
@staticmethod
def _send_discover(app, api_key, classifier_url) -> None:  # type: ignore
    """Send app discovery payload to pebblo-server. Internal method.

    The payload is posted to the classifier server and, when ``api_key`` is
    truthy, also to Pebblo cloud. Delivery is best-effort: request failures
    are logged as warnings and never raised to the caller.

    Args:
        app: App model; serialized with ``dict(exclude_unset=True)``.
        api_key: Pebblo cloud API key. The cloud call is skipped when falsy.
        classifier_url: Base URL of the classifier server.
    """

    def _log_exchange(target, resp) -> None:  # type: ignore
        # Log the raw response text instead of ``resp.json()``: a non-JSON
        # body would raise inside the logging call, which skipped the status
        # check below and was reported as a connectivity failure.
        body = resp.request.body
        logger.debug(
            "send_discover[%s]: request url %s, body %s len %s "
            "response status %s body %s",
            target,
            resp.request.url,
            str(body),
            len(body) if body else 0,
            resp.status_code,
            resp.text,
        )

    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    payload = app.dict(exclude_unset=True)
    app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
    try:
        pebblo_resp = requests.post(
            app_discover_url, headers=headers, json=payload, timeout=20
        )
        logger.debug("discover-payload: %s", payload)
        _log_exchange("local", pebblo_resp)
        # NOTE(review): 502 is treated the same as 200 here (kept from the
        # original) — confirm this is intended.
        if pebblo_resp.status_code in (HTTPStatus.OK, HTTPStatus.BAD_GATEWAY):
            PebbloRetrievalQA.set_discover_sent()
        else:
            logger.warning(
                "Received unexpected HTTP response code: %s",
                pebblo_resp.status_code,
            )
    except requests.exceptions.RequestException:
        logger.warning("Unable to reach pebblo server.")
    except Exception as e:
        logger.warning("An Exception caught in _send_discover: local %s", e)

    if api_key:
        try:
            # Build a separate header dict so the API key is only ever
            # attached to the cloud request.
            cloud_headers = {**headers, "x-api-key": api_key}
            pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{APP_DISCOVER_URL}"
            pebblo_cloud_response = requests.post(
                pebblo_cloud_url, headers=cloud_headers, json=payload, timeout=20
            )
            _log_exchange("cloud", pebblo_cloud_response)
        except requests.exceptions.RequestException:
            logger.warning("Unable to reach Pebblo cloud server.")
        except Exception as e:
            logger.warning("An Exception caught in _send_discover: cloud %s", e)
@classmethod
def set_discover_sent(cls) -> None:
    """Mark, on the class, that the discover payload has been sent."""
    cls._discover_sent = True
@classmethod
def set_prompt_sent(cls) -> None:
    """Mark, on the class, that the prompt payload has been sent."""
    cls._prompt_sent = True
def _send_prompt(self, qa_payload: Qa) -> None:
    """Send the prompt/response payload to the classifier server.

    The payload is posted to ``self.classifier_url`` and, when
    ``self.api_key`` is truthy, also to Pebblo cloud. Delivery is
    best-effort: failures are logged as warnings and never raised.

    Args:
        qa_payload: Prompt, response and retrieved-context record to send.
    """

    def _log_exchange(target, resp) -> None:  # type: ignore
        # Log the raw response text instead of ``resp.json()``: a non-JSON
        # body would raise inside the logging call, which skipped the status
        # check below and was reported as a connectivity failure.
        body = resp.request.body
        logger.debug(
            "send_prompt[%s]: request url %s, body %s len %s "
            "response status %s body %s",
            target,
            resp.request.url,
            str(body),
            len(body) if body else 0,
            resp.status_code,
            resp.text,
        )

    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    prompt_url = f"{self.classifier_url}{PROMPT_URL}"
    try:
        pebblo_resp = requests.post(
            prompt_url, headers=headers, json=qa_payload.dict(), timeout=20
        )
        logger.debug("prompt-payload: %s", qa_payload)
        _log_exchange("local", pebblo_resp)
        # NOTE(review): 502 is treated the same as 200 here (kept from the
        # original) — confirm this is intended.
        if pebblo_resp.status_code in (HTTPStatus.OK, HTTPStatus.BAD_GATEWAY):
            PebbloRetrievalQA.set_prompt_sent()
        else:
            logger.warning(
                "Received unexpected HTTP response code: %s",
                pebblo_resp.status_code,
            )
    except requests.exceptions.RequestException:
        logger.warning("Unable to reach pebblo server.")
    except Exception as e:
        # The message previously named _send_discover (copy-paste slip).
        logger.warning("An Exception caught in _send_prompt: local %s", e)

    if self.api_key:
        try:
            # Build a separate header dict so the API key is only ever
            # attached to the cloud request.
            cloud_headers = {**headers, "x-api-key": self.api_key}
            pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
            pebblo_cloud_response = requests.post(
                pebblo_cloud_url,
                headers=cloud_headers,
                json=qa_payload.dict(),
                timeout=20,
            )
            _log_exchange("cloud", pebblo_cloud_response)
        except requests.exceptions.RequestException:
            logger.warning("Unable to reach Pebblo cloud server.")
        except Exception as e:
            logger.warning("An Exception caught in _send_prompt: cloud %s", e)
@classmethod
def get_chain_details(cls, llm, **kwargs):  # type: ignore
    """Describe this chain, its model and its vector store.

    Args:
        llm: Language model; its ``model_name`` (or, failing that, ``model``)
            instance attribute and its class name are reported.
        **kwargs: Must contain ``retriever``, whose ``vectorstore`` is
            inspected.

    Returns:
        A one-element list holding a dict with ``name``, ``model`` and
        ``vector_dbs`` entries.
    """
    llm_attrs = llm.__dict__
    model_info = {
        "name": llm_attrs.get("model_name", llm_attrs.get("model")),
        "vendor": llm.__class__.__name__,
    }

    store = kwargs["retriever"].vectorstore
    # The embedding object is looked up under either private attribute name;
    # ``_embeddings`` takes precedence when both exist.
    if hasattr(store, "_embeddings"):
        embedding_model = str(store._embeddings.model)
    elif hasattr(store, "_embedding"):
        embedding_model = str(store._embedding.model)
    else:
        embedding_model = None

    vector_db_info = {
        "name": store.__class__.__name__,
        "embedding_model": embedding_model,
    }
    return [
        {
            "name": cls.__name__,
            "model": model_info,
            "vector_dbs": [vector_db_info],
        }
    ]

View File

@ -1,13 +1,13 @@
"""
Identity & Semantic Enforcement filters for PebbloRetrievalQA chain:
This module contains methods for applying Identity and Semantic Enforcement filters
in the PebbloRetrievalQA chain.
These filters are used to control the retrieval of documents based on authorization and
semantic context.
The Identity Enforcement filter ensures that only authorized identities can access
certain documents, while the Semantic Enforcement filter controls document retrieval
based on semantic context.
This module contains methods for applying Identity and Semantic Enforcement filters
in the PebbloRetrievalQA chain.
These filters are used to control the retrieval of documents based on authorization and
semantic context.
The Identity Enforcement filter ensures that only authorized identities can access
certain documents, while the Semantic Enforcement filter controls document retrieval
based on semantic context.
The methods in this module are designed to work with different types of vector stores.
"""

View File

@ -60,3 +60,86 @@ class ChainInput(BaseModel):
base_dict["auth_context"] = self.auth_context
base_dict["semantic_context"] = self.semantic_context
return base_dict
class Runtime(BaseModel):
    """
    OS, language details

    Instances are built by ``get_runtime()`` in the utilities module from
    ``platform.uname()`` and the LangChain runtime environment.
    """

    # Runtime category; get_runtime() sets "desktop" when the OS is Darwin.
    type: Optional[str] = ""
    # Node name reported by platform.uname().
    host: str
    # Path of the running app (get_runtime() supplies the working directory).
    path: str
    # IP address resolved for the local host (see get_ip()).
    ip: Optional[str] = ""
    # "platform" entry of the runtime environment, or "unknown".
    platform: str
    # System name reported by platform.uname().
    os: str
    # Version string reported by platform.uname().
    os_version: str
    # "runtime" entry of the runtime environment, or "unknown".
    language: str
    # "runtime_version" entry of the runtime environment, or "unknown".
    language_version: str
    # Human-readable runtime label; get_runtime() sets "Mac OSX" on Darwin.
    runtime: Optional[str] = ""
class Framework(BaseModel):
    """
    Langchain framework details

    ``get_runtime()`` fills ``name`` with "langchain" and ``version`` with the
    ``library_version`` entry of the runtime environment.
    """

    # Framework name.
    name: str
    # Framework version string.
    version: str
class Model(BaseModel):
    """Language model used by a chain."""

    # Vendor label; PebbloRetrievalQA.get_chain_details passes the LLM's
    # class name here.
    vendor: Optional[str]
    # Model name; get_chain_details passes the LLM's ``model_name`` (or
    # ``model``) attribute.
    name: Optional[str]
class PkgInfo(BaseModel):
    """Package metadata fields (not populated anywhere in the visible code)."""

    project_home_page: Optional[str]
    documentation_url: Optional[str]
    pypi_url: Optional[str]
    # NOTE(review): spelled "liscence"; the spelling is part of the field name
    # and therefore of the serialized key — confirm with consumers before
    # renaming.
    liscence_type: Optional[str]
    installed_via: Optional[str]
    location: Optional[str]
class VectorDB(BaseModel):
    """Vector store used by a chain; every field is optional."""

    # Vector store label; get_chain_details passes the vectorstore class name.
    name: Optional[str] = None
    version: Optional[str] = None
    location: Optional[str] = None
    # String form of the vectorstore's embedding model, when one is found.
    embedding_model: Optional[str] = None
class Chains(BaseModel):
    """One chain of an app: its name, model and vector stores."""

    # Chain label; get_chain_details passes the chain class name.
    name: str
    # Language model the chain uses.
    model: Optional[Model]
    # Vector stores the chain retrieves from.
    vector_dbs: Optional[List[VectorDB]]
class App(BaseModel):
    """App discovery record; built by ``PebbloRetrievalQA._get_app_details``
    and posted by ``PebbloRetrievalQA._send_discover``."""

    # App name.
    name: str
    # Owner of the app.
    owner: str
    # Description of the app.
    description: Optional[str]
    # OS and language details of the running process.
    runtime: Runtime
    # Framework name and version.
    framework: Framework
    # Chains that make up the app.
    chains: List[Chains]
    # Plugin version string (PLUGIN_VERSION in the utilities module).
    plugin_version: str
class Context(BaseModel):
    """One retrieved document attached to a prompt/response record."""

    # Origin of the document; the chain passes the document's "source"
    # metadata entry.
    retrieved_from: Optional[str]
    # Document text (the chain passes ``page_content``).
    doc: Optional[str]
    # Vector store label; the chain passes the vectorstore class name.
    vector_db: str
class Prompt(BaseModel):
    """Text wrapper; ``Qa`` uses it for both the prompt and the response."""

    # The text itself.
    data: str
class Qa(BaseModel):
    """Prompt/response record posted by ``PebbloRetrievalQA._send_prompt``."""

    # App name.
    name: str
    # Documents retrieved for the prompt.
    context: List[Optional[Context]]
    # The question that was asked.
    prompt: Prompt
    # The answer that was produced (same wrapper type as the prompt).
    response: Prompt
    # Time the prompt was received; the chain passes an ISO-format string.
    prompt_time: str
    # User identifier; the chain passes "unknown" without an auth context.
    user: str
    # Identities associated with the user.
    user_identities: Optional[List[str]]

View File

@ -0,0 +1,65 @@
import logging
import os
import platform
from typing import Tuple
from langchain_core.env import get_runtime_environment
from langchain_community.chains.pebblo_retrieval.models import Framework, Runtime
logger = logging.getLogger(__name__)
# Version string reported as ``plugin_version`` in the app discovery payload.
PLUGIN_VERSION = "0.1.1"
# Base URL of the classifier server; override with PEBBLO_CLASSIFIER_URL.
CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000")
# Base URL of Pebblo cloud; override with PEBBLO_CLOUD_URL.
PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")
# Endpoint paths, appended to one of the base URLs above.
PROMPT_URL = "/v1/prompt"
APP_DISCOVER_URL = "/v1/app/discover"
def get_runtime() -> Tuple[Framework, Runtime]:
    """Fetch the current Framework and Runtime details.

    Returns:
        Tuple[Framework, Runtime]: Framework and Runtime for the current app
        instance.
    """
    runtime_env = get_runtime_environment()
    framework = Framework(
        name="langchain", version=runtime_env.get("library_version", None)
    )
    uname = platform.uname()
    runtime = Runtime(
        host=uname.node,
        # PWD is maintained by the shell and may be missing from the
        # environment; fall back to the process working directory instead of
        # raising KeyError.
        path=os.environ.get("PWD") or os.getcwd(),
        platform=runtime_env.get("platform", "unknown"),
        os=uname.system,
        os_version=uname.version,
        ip=get_ip(),
        language=runtime_env.get("runtime", "unknown"),
        language_version=runtime_env.get("runtime_version", "unknown"),
    )

    if "Darwin" in runtime.os:
        runtime.type = "desktop"
        runtime.runtime = "Mac OSX"

    # Lazy %-style arguments: the models are only rendered when DEBUG is on.
    logger.debug("framework %s", framework)
    logger.debug("runtime %s", runtime)
    return framework, runtime
def get_ip() -> str:
    """Resolve an IP address for the local machine.

    The machine's own hostname is resolved first; if that lookup fails for
    any reason, the address of "localhost" is returned instead.

    Returns:
        str: IP address
    """
    import socket  # imported lazily, only when an address is requested

    hostname = socket.gethostname()
    try:
        return socket.gethostbyname(hostname)
    except Exception:
        return socket.gethostbyname("localhost")

View File

@ -72,7 +72,12 @@ def pebblo_retrieval_qa(retriever: FakeRetriever) -> PebbloRetrievalQA:
Create a PebbloRetrievalQA instance
"""
pebblo_retrieval_qa = PebbloRetrievalQA.from_chain_type(
llm=FakeLLM(), chain_type="stuff", retriever=retriever
llm=FakeLLM(),
chain_type="stuff",
retriever=retriever,
owner="owner",
description="description",
app_name="app_name",
)
return pebblo_retrieval_qa
@ -114,6 +119,9 @@ def test_validate_vectorstore(
llm=FakeLLM(),
chain_type="stuff",
retriever=retriever,
owner="owner",
description="description",
app_name="app_name",
)
# validate_vectorstore method should raise a ValueError for unsupported vectorstores
@ -122,6 +130,9 @@ def test_validate_vectorstore(
llm=FakeLLM(),
chain_type="stuff",
retriever=unsupported_retriever,
owner="owner",
description="description",
app_name="app_name",
)
assert (
"Vectorstore must be an instance of one of the supported vectorstores"