community[minor]: Enable retrieval api calls in PebbloRetrievalQA (#21958)

Description: Enable app discovery and Prompt/Response apis in
PebbloSafeRetrieval
Documentation: NA
Unit test: N/A

---------

Signed-off-by: Rahul Tripathi <rauhl.psit.ec@gmail.com>
Co-authored-by: Rahul Tripathi <rauhl.psit.ec@gmail.com>
This commit is contained in:
Rahul Triptahi 2024-06-04 22:48:50 +05:30 committed by GitHub
parent 8fd231086e
commit 77ad857934
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 433 additions and 11 deletions

View File

@ -3,9 +3,13 @@ Pebblo Retrieval Chain with Identity & Semantic Enforcement for question-answeri
against a vector database.
"""
import datetime
import inspect
import logging
from http import HTTPStatus
from typing import Any, Dict, List, Optional
import requests # type: ignore
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain_core.callbacks import (
@ -22,9 +26,21 @@ from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
set_enforcement_filters,
)
from langchain_community.chains.pebblo_retrieval.models import (
App,
AuthContext,
Qa,
SemanticContext,
)
from langchain_community.chains.pebblo_retrieval.utilities import (
APP_DISCOVER_URL,
CLASSIFIER_URL,
PEBBLO_CLOUD_URL,
PLUGIN_VERSION,
PROMPT_URL,
get_runtime,
)
logger = logging.getLogger(__name__)
class PebbloRetrievalQA(Chain):
@ -46,6 +62,20 @@ class PebbloRetrievalQA(Chain):
"""Authentication context for identity enforcement."""
semantic_context_key: str = "semantic_context" #: :meta private:
"""Semantic context for semantic enforcement."""
app_name: str #: :meta private:
"""App name."""
owner: str #: :meta private:
"""Owner of app."""
description: str #: :meta private:
"""Description of app."""
api_key: Optional[str] = None #: :meta private:
"""Pebblo cloud API key for app."""
classifier_url: str = CLASSIFIER_URL #: :meta private:
"""Classifier endpoint."""
_discover_sent = False #: :meta private:
"""Flag to check if discover payload has been sent."""
_prompt_sent = False #: :meta private:
"""Flag to check if prompt payload has been sent."""
def _call(
self,
@ -63,10 +93,11 @@ class PebbloRetrievalQA(Chain):
res = indexqa({'query': 'This is my query'})
answer, docs = res['result'], res['source_documents']
"""
prompt_time = datetime.datetime.now().isoformat()
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
question = inputs[self.input_key]
auth_context = inputs.get(self.auth_context_key)
semantic_context = inputs.get(self.semantic_context_key)
auth_context = inputs.get(self.auth_context_key, {})
semantic_context = inputs.get(self.semantic_context_key, {})
accepts_run_manager = (
"run_manager" in inspect.signature(self._get_docs).parameters
)
@ -80,6 +111,32 @@ class PebbloRetrievalQA(Chain):
input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
qa = {
"name": self.app_name,
"context": [
{
"retrieved_from": doc.metadata.get("source"),
"doc": doc.page_content,
"vector_db": self.retriever.vectorstore.__class__.__name__,
}
for doc in docs
if isinstance(doc, Document)
],
"prompt": {"data": question},
"response": {
"data": answer,
},
"prompt_time": prompt_time,
"user": auth_context.user_id if auth_context else "unknown",
"user_identities": auth_context.user_auth
if "user_auth" in dict(auth_context)
else []
if auth_context
else [],
}
qa_payload = Qa(**qa)
self._send_prompt(qa_payload)
if self.return_source_documents:
return {self.output_key: answer, "source_documents": docs}
else:
@ -158,8 +215,13 @@ class PebbloRetrievalQA(Chain):
def from_chain_type(
cls,
llm: BaseLanguageModel,
app_name: str,
description: str,
owner: str,
chain_type: str = "stuff",
chain_type_kwargs: Optional[dict] = None,
api_key: Optional[str] = None,
classifier_url: str = CLASSIFIER_URL,
**kwargs: Any,
) -> "PebbloRetrievalQA":
"""Load chain from chain type."""
@ -169,7 +231,29 @@ class PebbloRetrievalQA(Chain):
combine_documents_chain = load_qa_chain(
llm, chain_type=chain_type, **_chain_type_kwargs
)
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
# generate app
app = PebbloRetrievalQA._get_app_details(
app_name=app_name,
description=description,
owner=owner,
llm=llm,
**kwargs,
)
PebbloRetrievalQA._send_discover(
app, api_key=api_key, classifier_url=classifier_url
)
return cls(
combine_documents_chain=combine_documents_chain,
app_name=app_name,
owner=owner,
description=description,
api_key=api_key,
classifier_url=classifier_url,
**kwargs,
)
@validator("retriever", pre=True, always=True)
def validate_vectorstore(
@ -216,3 +300,182 @@ class PebbloRetrievalQA(Chain):
return await self.retriever.aget_relevant_documents(
question, callbacks=run_manager.get_child()
)
@staticmethod
def _get_app_details(app_name, owner, description, llm, **kwargs) -> App: # type: ignore
"""Fetch app details. Internal method.
Returns:
App: App details.
"""
framework, runtime = get_runtime()
chains = PebbloRetrievalQA.get_chain_details(llm, **kwargs)
app = App(
name=app_name,
owner=owner,
description=description,
runtime=runtime,
framework=framework,
chains=chains,
plugin_version=PLUGIN_VERSION,
)
return app
@staticmethod
def _send_discover(app, api_key, classifier_url) -> None: # type: ignore
"""Send app discovery payload to pebblo-server. Internal method."""
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
payload = app.dict(exclude_unset=True)
app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=payload, timeout=20
)
logger.debug("discover-payload: %s", payload)
logger.debug(
"send_discover[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_discover_sent()
else:
logger.warning(
f"Received unexpected HTTP response code: {pebblo_resp.status_code}"
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
if api_key:
try:
headers.update({"x-api-key": api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{APP_DISCOVER_URL}"
pebblo_cloud_response = requests.post(
pebblo_cloud_url, headers=headers, json=payload, timeout=20
)
logger.debug(
"send_discover[cloud]: request url %s, body %s len %s\
response status %s body %s",
pebblo_cloud_response.request.url,
str(pebblo_cloud_response.request.body),
str(
len(
pebblo_cloud_response.request.body
if pebblo_cloud_response.request.body
else []
)
),
str(pebblo_cloud_response.status_code),
pebblo_cloud_response.json(),
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach Pebblo cloud server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: cloud %s", e)
@classmethod
def set_discover_sent(cls) -> None:
cls._discover_sent = True
@classmethod
def set_prompt_sent(cls) -> None:
cls._prompt_sent = True
def _send_prompt(self, qa_payload: Qa) -> None:
headers = {
"Accept": "application/json",
"Content-Type": "application/json",
}
app_discover_url = f"{self.classifier_url}{PROMPT_URL}"
try:
pebblo_resp = requests.post(
app_discover_url, headers=headers, json=qa_payload.dict(), timeout=20
)
logger.debug("prompt-payload: %s", qa_payload)
logger.debug(
"send_prompt[local]: request url %s, body %s len %s\
response status %s body %s",
pebblo_resp.request.url,
str(pebblo_resp.request.body),
str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])),
str(pebblo_resp.status_code),
pebblo_resp.json(),
)
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
PebbloRetrievalQA.set_prompt_sent()
else:
logger.warning(
f"Received unexpected HTTP response code: {pebblo_resp.status_code}"
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach pebblo server.")
except Exception as e:
logger.warning("An Exception caught in _send_discover: local %s", e)
if self.api_key:
try:
headers.update({"x-api-key": self.api_key})
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
pebblo_cloud_response = requests.post(
pebblo_cloud_url,
headers=headers,
json=qa_payload.dict(),
timeout=20,
)
logger.debug(
"send_prompt[cloud]: request url %s, body %s len %s\
response status %s body %s",
pebblo_cloud_response.request.url,
str(pebblo_cloud_response.request.body),
str(
len(
pebblo_cloud_response.request.body
if pebblo_cloud_response.request.body
else []
)
),
str(pebblo_cloud_response.status_code),
pebblo_cloud_response.json(),
)
except requests.exceptions.RequestException:
logger.warning("Unable to reach Pebblo cloud server.")
except Exception as e:
logger.warning("An Exception caught in _send_prompt: cloud %s", e)
@classmethod
def get_chain_details(cls, llm, **kwargs): # type: ignore
llm_dict = llm.__dict__
chain = [
{
"name": cls.__name__,
"model": {
"name": llm_dict.get("model_name", llm_dict.get("model")),
"vendor": llm.__class__.__name__,
},
"vector_dbs": [
{
"name": kwargs["retriever"].vectorstore.__class__.__name__,
"embedding_model": str(
kwargs["retriever"].vectorstore._embeddings.model
)
if hasattr(kwargs["retriever"].vectorstore, "_embeddings")
else (
str(kwargs["retriever"].vectorstore._embedding.model)
if hasattr(kwargs["retriever"].vectorstore, "_embedding")
else None
),
}
],
}
]
return chain

View File

@ -60,3 +60,86 @@ class ChainInput(BaseModel):
base_dict["auth_context"] = self.auth_context
base_dict["semantic_context"] = self.semantic_context
return base_dict
class Runtime(BaseModel):
"""
OS, language details
"""
type: Optional[str] = ""
host: str
path: str
ip: Optional[str] = ""
platform: str
os: str
os_version: str
language: str
language_version: str
runtime: Optional[str] = ""
class Framework(BaseModel):
"""
Langchain framework details
"""
name: str
version: str
class Model(BaseModel):
vendor: Optional[str]
name: Optional[str]
class PkgInfo(BaseModel):
project_home_page: Optional[str]
documentation_url: Optional[str]
pypi_url: Optional[str]
liscence_type: Optional[str]
installed_via: Optional[str]
location: Optional[str]
class VectorDB(BaseModel):
name: Optional[str] = None
version: Optional[str] = None
location: Optional[str] = None
embedding_model: Optional[str] = None
class Chains(BaseModel):
name: str
model: Optional[Model]
vector_dbs: Optional[List[VectorDB]]
class App(BaseModel):
name: str
owner: str
description: Optional[str]
runtime: Runtime
framework: Framework
chains: List[Chains]
plugin_version: str
class Context(BaseModel):
retrieved_from: Optional[str]
doc: Optional[str]
vector_db: str
class Prompt(BaseModel):
data: str
class Qa(BaseModel):
name: str
context: List[Optional[Context]]
prompt: Prompt
response: Prompt
prompt_time: str
user: str
user_identities: Optional[List[str]]

View File

@ -0,0 +1,65 @@
import logging
import os
import platform
from typing import Tuple
from langchain_core.env import get_runtime_environment
from langchain_community.chains.pebblo_retrieval.models import Framework, Runtime
logger = logging.getLogger(__name__)
PLUGIN_VERSION = "0.1.1"
CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000")
PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")
PROMPT_URL = "/v1/prompt"
APP_DISCOVER_URL = "/v1/app/discover"
def get_runtime() -> Tuple[Framework, Runtime]:
"""Fetch the current Framework and Runtime details.
Returns:
Tuple[Framework, Runtime]: Framework and Runtime for the current app instance.
"""
runtime_env = get_runtime_environment()
framework = Framework(
name="langchain", version=runtime_env.get("library_version", None)
)
uname = platform.uname()
runtime = Runtime(
host=uname.node,
path=os.environ["PWD"],
platform=runtime_env.get("platform", "unknown"),
os=uname.system,
os_version=uname.version,
ip=get_ip(),
language=runtime_env.get("runtime", "unknown"),
language_version=runtime_env.get("runtime_version", "unknown"),
)
if "Darwin" in runtime.os:
runtime.type = "desktop"
runtime.runtime = "Mac OSX"
logger.debug(f"framework {framework}")
logger.debug(f"runtime {runtime}")
return framework, runtime
def get_ip() -> str:
"""Fetch local runtime ip address.
Returns:
str: IP address
"""
import socket # lazy imports
host = socket.gethostname()
try:
public_ip = socket.gethostbyname(host)
except Exception:
public_ip = socket.gethostbyname("localhost")
return public_ip

View File

@ -72,7 +72,12 @@ def pebblo_retrieval_qa(retriever: FakeRetriever) -> PebbloRetrievalQA:
Create a PebbloRetrievalQA instance
"""
pebblo_retrieval_qa = PebbloRetrievalQA.from_chain_type(
llm=FakeLLM(), chain_type="stuff", retriever=retriever
llm=FakeLLM(),
chain_type="stuff",
retriever=retriever,
owner="owner",
description="description",
app_name="app_name",
)
return pebblo_retrieval_qa
@ -114,6 +119,9 @@ def test_validate_vectorstore(
llm=FakeLLM(),
chain_type="stuff",
retriever=retriever,
owner="owner",
description="description",
app_name="app_name",
)
# validate_vectorstore method should raise a ValueError for unsupported vectorstores
@ -122,6 +130,9 @@ def test_validate_vectorstore(
llm=FakeLLM(),
chain_type="stuff",
retriever=unsupported_retriever,
owner="owner",
description="description",
app_name="app_name",
)
assert (
"Vectorstore must be an instance of one of the supported vectorstores"