mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 00:49:25 +00:00
community[minor]: Enable retrieval api calls in PebbloRetrievalQA (#21958)
Description: Enable app discovery and Prompt/Response APIs in PebbloSafeRetrieval. Documentation: N/A. Unit test: N/A. --------- Signed-off-by: Rahul Tripathi <rauhl.psit.ec@gmail.com> Co-authored-by: Rahul Tripathi <rauhl.psit.ec@gmail.com>
This commit is contained in:
parent
8fd231086e
commit
77ad857934
@ -3,9 +3,13 @@ Pebblo Retrieval Chain with Identity & Semantic Enforcement for question-answeri
|
||||
against a vector database.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import inspect
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import requests # type: ignore
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
from langchain_core.callbacks import (
|
||||
@ -22,9 +26,21 @@ from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
|
||||
set_enforcement_filters,
|
||||
)
|
||||
from langchain_community.chains.pebblo_retrieval.models import (
|
||||
App,
|
||||
AuthContext,
|
||||
Qa,
|
||||
SemanticContext,
|
||||
)
|
||||
from langchain_community.chains.pebblo_retrieval.utilities import (
|
||||
APP_DISCOVER_URL,
|
||||
CLASSIFIER_URL,
|
||||
PEBBLO_CLOUD_URL,
|
||||
PLUGIN_VERSION,
|
||||
PROMPT_URL,
|
||||
get_runtime,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PebbloRetrievalQA(Chain):
|
||||
@ -46,6 +62,20 @@ class PebbloRetrievalQA(Chain):
|
||||
"""Authentication context for identity enforcement."""
|
||||
semantic_context_key: str = "semantic_context" #: :meta private:
|
||||
"""Semantic context for semantic enforcement."""
|
||||
app_name: str #: :meta private:
|
||||
"""App name."""
|
||||
owner: str #: :meta private:
|
||||
"""Owner of app."""
|
||||
description: str #: :meta private:
|
||||
"""Description of app."""
|
||||
api_key: Optional[str] = None #: :meta private:
|
||||
"""Pebblo cloud API key for app."""
|
||||
classifier_url: str = CLASSIFIER_URL #: :meta private:
|
||||
"""Classifier endpoint."""
|
||||
_discover_sent = False #: :meta private:
|
||||
"""Flag to check if discover payload has been sent."""
|
||||
_prompt_sent = False #: :meta private:
|
||||
"""Flag to check if prompt payload has been sent."""
|
||||
|
||||
def _call(
|
||||
self,
|
||||
@ -63,10 +93,11 @@ class PebbloRetrievalQA(Chain):
|
||||
res = indexqa({'query': 'This is my query'})
|
||||
answer, docs = res['result'], res['source_documents']
|
||||
"""
|
||||
prompt_time = datetime.datetime.now().isoformat()
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
question = inputs[self.input_key]
|
||||
auth_context = inputs.get(self.auth_context_key)
|
||||
semantic_context = inputs.get(self.semantic_context_key)
|
||||
auth_context = inputs.get(self.auth_context_key, {})
|
||||
semantic_context = inputs.get(self.semantic_context_key, {})
|
||||
accepts_run_manager = (
|
||||
"run_manager" in inspect.signature(self._get_docs).parameters
|
||||
)
|
||||
@ -80,6 +111,32 @@ class PebbloRetrievalQA(Chain):
|
||||
input_documents=docs, question=question, callbacks=_run_manager.get_child()
|
||||
)
|
||||
|
||||
qa = {
|
||||
"name": self.app_name,
|
||||
"context": [
|
||||
{
|
||||
"retrieved_from": doc.metadata.get("source"),
|
||||
"doc": doc.page_content,
|
||||
"vector_db": self.retriever.vectorstore.__class__.__name__,
|
||||
}
|
||||
for doc in docs
|
||||
if isinstance(doc, Document)
|
||||
],
|
||||
"prompt": {"data": question},
|
||||
"response": {
|
||||
"data": answer,
|
||||
},
|
||||
"prompt_time": prompt_time,
|
||||
"user": auth_context.user_id if auth_context else "unknown",
|
||||
"user_identities": auth_context.user_auth
|
||||
if "user_auth" in dict(auth_context)
|
||||
else []
|
||||
if auth_context
|
||||
else [],
|
||||
}
|
||||
qa_payload = Qa(**qa)
|
||||
self._send_prompt(qa_payload)
|
||||
|
||||
if self.return_source_documents:
|
||||
return {self.output_key: answer, "source_documents": docs}
|
||||
else:
|
||||
@ -158,8 +215,13 @@ class PebbloRetrievalQA(Chain):
|
||||
def from_chain_type(
|
||||
cls,
|
||||
llm: BaseLanguageModel,
|
||||
app_name: str,
|
||||
description: str,
|
||||
owner: str,
|
||||
chain_type: str = "stuff",
|
||||
chain_type_kwargs: Optional[dict] = None,
|
||||
api_key: Optional[str] = None,
|
||||
classifier_url: str = CLASSIFIER_URL,
|
||||
**kwargs: Any,
|
||||
) -> "PebbloRetrievalQA":
|
||||
"""Load chain from chain type."""
|
||||
@ -169,7 +231,29 @@ class PebbloRetrievalQA(Chain):
|
||||
combine_documents_chain = load_qa_chain(
|
||||
llm, chain_type=chain_type, **_chain_type_kwargs
|
||||
)
|
||||
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
|
||||
|
||||
# generate app
|
||||
app = PebbloRetrievalQA._get_app_details(
|
||||
app_name=app_name,
|
||||
description=description,
|
||||
owner=owner,
|
||||
llm=llm,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
PebbloRetrievalQA._send_discover(
|
||||
app, api_key=api_key, classifier_url=classifier_url
|
||||
)
|
||||
|
||||
return cls(
|
||||
combine_documents_chain=combine_documents_chain,
|
||||
app_name=app_name,
|
||||
owner=owner,
|
||||
description=description,
|
||||
api_key=api_key,
|
||||
classifier_url=classifier_url,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@validator("retriever", pre=True, always=True)
|
||||
def validate_vectorstore(
|
||||
@ -216,3 +300,182 @@ class PebbloRetrievalQA(Chain):
|
||||
return await self.retriever.aget_relevant_documents(
|
||||
question, callbacks=run_manager.get_child()
|
||||
)
|
||||
|
||||
@staticmethod
def _get_app_details(app_name, owner, description, llm, **kwargs) -> App:  # type: ignore
    """Assemble the ``App`` discovery record for this chain. Internal method.

    Args:
        app_name: Name reported for the app.
        owner: Owner reported for the app.
        description: Description reported for the app.
        llm: Language model forwarded to ``get_chain_details``.
        **kwargs: Forwarded unchanged to ``get_chain_details``.

    Returns:
        App: App details, including runtime, framework, chain description
        and ``PLUGIN_VERSION``.
    """
    # Runtime is collected before the chain details, as in the original flow.
    framework_info, runtime_info = get_runtime()
    chain_details = PebbloRetrievalQA.get_chain_details(llm, **kwargs)
    return App(
        name=app_name,
        owner=owner,
        description=description,
        runtime=runtime_info,
        framework=framework_info,
        chains=chain_details,
        plugin_version=PLUGIN_VERSION,
    )
|
||||
|
||||
@staticmethod
def _send_discover(app, api_key, classifier_url) -> None:  # type: ignore
    """Send app discovery payload to pebblo-server. Internal method.

    The payload is posted to ``classifier_url`` + ``APP_DISCOVER_URL`` and,
    when ``api_key`` is truthy, also to ``PEBBLO_CLOUD_URL`` +
    ``APP_DISCOVER_URL`` with an ``x-api-key`` header. Delivery is
    best-effort: every failure is logged as a warning and never raised.

    Args:
        app: App model; serialized with ``dict(exclude_unset=True)``.
        api_key: Pebblo cloud API key; a falsy value skips the cloud call.
        classifier_url: Base URL of the local classifier endpoint.
    """

    def _log_exchange(target, resp):  # type: ignore
        """Debug-log one request/response pair without failing on non-JSON."""
        # A non-JSON body (e.g. an HTML 502 page) must not abort the caller,
        # so fall back to the raw text instead of letting json() raise.
        try:
            resp_body = resp.json()
        except ValueError:
            resp_body = resp.text
        req_body = resp.request.body
        logger.debug(
            "send_discover[%s]: request url %s, body %s len %s "
            "response status %s body %s",
            target,
            resp.request.url,
            str(req_body),
            str(len(req_body if req_body else [])),
            str(resp.status_code),
            resp_body,
        )

    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    payload = app.dict(exclude_unset=True)
    app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
    try:
        pebblo_resp = requests.post(
            app_discover_url, headers=headers, json=payload, timeout=20
        )
        logger.debug("discover-payload: %s", payload)
        # Record the outcome before any debug logging so that a logging
        # problem can never prevent the "sent" flag from being set.
        if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
            PebbloRetrievalQA.set_discover_sent()
        else:
            logger.warning(
                "Received unexpected HTTP response code: %s",
                pebblo_resp.status_code,
            )
        _log_exchange("local", pebblo_resp)
    except requests.exceptions.RequestException:
        logger.warning("Unable to reach pebblo server.")
    except Exception as e:
        logger.warning("An Exception caught in _send_discover: local %s", e)

    if api_key:
        try:
            headers["x-api-key"] = api_key
            pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{APP_DISCOVER_URL}"
            pebblo_cloud_response = requests.post(
                pebblo_cloud_url, headers=headers, json=payload, timeout=20
            )
            _log_exchange("cloud", pebblo_cloud_response)
        except requests.exceptions.RequestException:
            logger.warning("Unable to reach Pebblo cloud server.")
        except Exception as e:
            logger.warning("An Exception caught in _send_discover: cloud %s", e)
|
||||
|
||||
@classmethod
def set_discover_sent(cls) -> None:
    """Mark the discover payload as sent by setting the class-level flag."""
    cls._discover_sent = True
|
||||
|
||||
@classmethod
def set_prompt_sent(cls) -> None:
    """Mark the prompt payload as sent by setting the class-level flag."""
    cls._prompt_sent = True
|
||||
|
||||
def _send_prompt(self, qa_payload: Qa) -> None:
    """Send the prompt/response payload to pebblo-server. Internal method.

    The payload is posted to ``self.classifier_url`` + ``PROMPT_URL`` and,
    when ``self.api_key`` is truthy, also to ``PEBBLO_CLOUD_URL`` +
    ``PROMPT_URL`` with an ``x-api-key`` header. Delivery is best-effort:
    every failure is logged as a warning and never raised.

    Args:
        qa_payload: Prompt, response and retrieved context to report.
    """

    def _log_exchange(target, resp):  # type: ignore
        """Debug-log one request/response pair without failing on non-JSON."""
        # A non-JSON body (e.g. an HTML 502 page) must not abort the caller,
        # so fall back to the raw text instead of letting json() raise.
        try:
            resp_body = resp.json()
        except ValueError:
            resp_body = resp.text
        req_body = resp.request.body
        logger.debug(
            "send_prompt[%s]: request url %s, body %s len %s "
            "response status %s body %s",
            target,
            resp.request.url,
            str(req_body),
            str(len(req_body if req_body else [])),
            str(resp.status_code),
            resp_body,
        )

    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    prompt_url = f"{self.classifier_url}{PROMPT_URL}"
    try:
        pebblo_resp = requests.post(
            prompt_url, headers=headers, json=qa_payload.dict(), timeout=20
        )
        logger.debug("prompt-payload: %s", qa_payload)
        # Record the outcome before any debug logging so that a logging
        # problem can never prevent the "sent" flag from being set.
        if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
            PebbloRetrievalQA.set_prompt_sent()
        else:
            logger.warning(
                "Received unexpected HTTP response code: %s",
                pebblo_resp.status_code,
            )
        _log_exchange("local", pebblo_resp)
    except requests.exceptions.RequestException:
        logger.warning("Unable to reach pebblo server.")
    except Exception as e:
        # The message previously named _send_discover, which misdirected
        # anyone reading the logs.
        logger.warning("An Exception caught in _send_prompt: local %s", e)

    if self.api_key:
        try:
            headers["x-api-key"] = self.api_key
            pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
            pebblo_cloud_response = requests.post(
                pebblo_cloud_url,
                headers=headers,
                json=qa_payload.dict(),
                timeout=20,
            )
            _log_exchange("cloud", pebblo_cloud_response)
        except requests.exceptions.RequestException:
            logger.warning("Unable to reach Pebblo cloud server.")
        except Exception as e:
            logger.warning("An Exception caught in _send_prompt: cloud %s", e)
|
||||
|
||||
@classmethod
|
||||
def get_chain_details(cls, llm, **kwargs): # type: ignore
|
||||
llm_dict = llm.__dict__
|
||||
chain = [
|
||||
{
|
||||
"name": cls.__name__,
|
||||
"model": {
|
||||
"name": llm_dict.get("model_name", llm_dict.get("model")),
|
||||
"vendor": llm.__class__.__name__,
|
||||
},
|
||||
"vector_dbs": [
|
||||
{
|
||||
"name": kwargs["retriever"].vectorstore.__class__.__name__,
|
||||
"embedding_model": str(
|
||||
kwargs["retriever"].vectorstore._embeddings.model
|
||||
)
|
||||
if hasattr(kwargs["retriever"].vectorstore, "_embeddings")
|
||||
else (
|
||||
str(kwargs["retriever"].vectorstore._embedding.model)
|
||||
if hasattr(kwargs["retriever"].vectorstore, "_embedding")
|
||||
else None
|
||||
),
|
||||
}
|
||||
],
|
||||
}
|
||||
]
|
||||
return chain
|
||||
|
@ -1,13 +1,13 @@
|
||||
"""
|
||||
Identity & Semantic Enforcement filters for PebbloRetrievalQA chain:
|
||||
|
||||
This module contains methods for applying Identity and Semantic Enforcement filters
|
||||
in the PebbloRetrievalQA chain.
|
||||
These filters are used to control the retrieval of documents based on authorization and
|
||||
semantic context.
|
||||
The Identity Enforcement filter ensures that only authorized identities can access
|
||||
certain documents, while the Semantic Enforcement filter controls document retrieval
|
||||
based on semantic context.
|
||||
This module contains methods for applying Identity and Semantic Enforcement filters
|
||||
in the PebbloRetrievalQA chain.
|
||||
These filters are used to control the retrieval of documents based on authorization and
|
||||
semantic context.
|
||||
The Identity Enforcement filter ensures that only authorized identities can access
|
||||
certain documents, while the Semantic Enforcement filter controls document retrieval
|
||||
based on semantic context.
|
||||
|
||||
The methods in this module are designed to work with different types of vector stores.
|
||||
"""
|
||||
|
@ -60,3 +60,86 @@ class ChainInput(BaseModel):
|
||||
base_dict["auth_context"] = self.auth_context
|
||||
base_dict["semantic_context"] = self.semantic_context
|
||||
return base_dict
|
||||
|
||||
|
||||
class Runtime(BaseModel):
    """
    OS, language details
    """

    # Deployment type; get_runtime sets "desktop" on Darwin, otherwise "".
    type: Optional[str] = ""
    # Host name (get_runtime passes platform.uname().node).
    host: str
    # Working directory of the app process.
    path: str
    # Local IP address (get_runtime passes get_ip()).
    ip: Optional[str] = ""
    # Platform string taken from the langchain runtime environment.
    platform: str
    # Operating system name (platform.uname().system).
    os: str
    # Operating system version (platform.uname().version).
    os_version: str
    # Language runtime name from the langchain runtime environment.
    language: str
    # Language runtime version from the langchain runtime environment.
    language_version: str
    # Runtime label; get_runtime sets "Mac OSX" on Darwin, otherwise "".
    runtime: Optional[str] = ""
|
||||
|
||||
|
||||
class Framework(BaseModel):
    """
    Langchain framework details
    """

    # Framework name; get_runtime passes "langchain".
    name: str
    # Framework version; get_runtime passes the environment's library_version.
    version: str
|
||||
|
||||
|
||||
class Model(BaseModel):
    """LLM description reported inside a chain's discovery details."""

    # Explicit None defaults keep these fields optional regardless of how the
    # pydantic version in use treats an un-defaulted Optional annotation.
    # Vendor of the model (the chain reports the LLM class name).
    vendor: Optional[str] = None
    # Model name (the chain reports the LLM's model_name / model attribute).
    name: Optional[str] = None
|
||||
|
||||
|
||||
class PkgInfo(BaseModel):
    """Package metadata; every field is optional."""

    # Explicit None defaults keep these fields optional regardless of how the
    # pydantic version in use treats an un-defaulted Optional annotation.
    project_home_page: Optional[str] = None
    documentation_url: Optional[str] = None
    pypi_url: Optional[str] = None
    # NOTE(review): "liscence" is misspelled, but the field name is also the
    # serialized key, so renaming it would change the payload - confirm with
    # the server schema before correcting.
    liscence_type: Optional[str] = None
    installed_via: Optional[str] = None
    location: Optional[str] = None
|
||||
|
||||
|
||||
class VectorDB(BaseModel):
    """Vector store description reported inside a chain's discovery details."""

    # Vector store name (the chain reports the vectorstore class name).
    name: Optional[str] = None
    version: Optional[str] = None
    location: Optional[str] = None
    # String form of the embeddings model, when the chain can determine it.
    embedding_model: Optional[str] = None
|
||||
|
||||
|
||||
class Chains(BaseModel):
    """One chain entry of the app discovery payload."""

    # Chain name (the chain reports its own class name).
    name: str
    # Explicit None defaults keep these fields optional regardless of how the
    # pydantic version in use treats an un-defaulted Optional annotation.
    model: Optional[Model] = None
    vector_dbs: Optional[List[VectorDB]] = None
|
||||
|
||||
|
||||
class App(BaseModel):
    """App discovery payload sent to the Pebblo server."""

    name: str
    owner: str
    # Explicit None default keeps the field optional regardless of how the
    # pydantic version in use treats an un-defaulted Optional annotation.
    description: Optional[str] = None
    runtime: Runtime
    framework: Framework
    chains: List[Chains]
    plugin_version: str
|
||||
|
||||
|
||||
class Context(BaseModel):
    """One retrieved document reported alongside a prompt."""

    # Explicit None defaults keep these fields optional regardless of how the
    # pydantic version in use treats an un-defaulted Optional annotation.
    # Source of the document (the chain passes the "source" metadata value).
    retrieved_from: Optional[str] = None
    # Document text (the chain passes the page content).
    doc: Optional[str] = None
    # Vector store name (the chain passes the vectorstore class name).
    vector_db: str
|
||||
|
||||
|
||||
class Prompt(BaseModel):
    # Text wrapper; Qa uses this model for both its prompt and its response.
    data: str
|
||||
|
||||
|
||||
class Qa(BaseModel):
    """Prompt/response record sent to the Pebblo prompt endpoint."""

    # App name the record belongs to.
    name: str
    # Documents that were retrieved for the prompt.
    context: List[Optional[Context]]
    prompt: Prompt
    response: Prompt
    # Timestamp string (the chain passes an ISO-format datetime).
    prompt_time: str
    user: str
    # Explicit None default keeps the field optional regardless of how the
    # pydantic version in use treats an un-defaulted Optional annotation.
    user_identities: Optional[List[str]] = None
|
||||
|
@ -0,0 +1,65 @@
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
from typing import Tuple
|
||||
|
||||
from langchain_core.env import get_runtime_environment
|
||||
|
||||
from langchain_community.chains.pebblo_retrieval.models import Framework, Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PLUGIN_VERSION = "0.1.1"
|
||||
|
||||
CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000")
|
||||
PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")
|
||||
|
||||
PROMPT_URL = "/v1/prompt"
|
||||
APP_DISCOVER_URL = "/v1/app/discover"
|
||||
|
||||
|
||||
def get_runtime() -> Tuple[Framework, Runtime]:
    """Fetch the current Framework and Runtime details.

    Returns:
        Tuple[Framework, Runtime]: Framework and Runtime for the current app instance.
    """
    runtime_env = get_runtime_environment()
    framework = Framework(
        name="langchain", version=runtime_env.get("library_version", None)
    )
    uname = platform.uname()
    runtime = Runtime(
        host=uname.node,
        # PWD is a shell convention and is not exported everywhere (e.g. on
        # Windows); fall back to the process working directory instead of
        # raising KeyError.
        path=os.environ.get("PWD") or os.getcwd(),
        platform=runtime_env.get("platform", "unknown"),
        os=uname.system,
        os_version=uname.version,
        ip=get_ip(),
        language=runtime_env.get("runtime", "unknown"),
        language_version=runtime_env.get("runtime_version", "unknown"),
    )

    if "Darwin" in runtime.os:
        runtime.type = "desktop"
        runtime.runtime = "Mac OSX"

    # Lazy %-formatting: the models are only rendered when DEBUG is enabled.
    logger.debug("framework %s", framework)
    logger.debug("runtime %s", runtime)
    return framework, runtime
|
||||
|
||||
|
||||
def get_ip() -> str:
    """Resolve an IP address for the local host.

    The machine's host name is resolved first; if that lookup fails for any
    reason, the address of ``localhost`` is returned instead.

    Returns:
        str: IP address
    """
    import socket  # lazy imports

    hostname = socket.gethostname()
    try:
        return socket.gethostbyname(hostname)
    except Exception:
        return socket.gethostbyname("localhost")
|
@ -72,7 +72,12 @@ def pebblo_retrieval_qa(retriever: FakeRetriever) -> PebbloRetrievalQA:
|
||||
Create a PebbloRetrievalQA instance
|
||||
"""
|
||||
pebblo_retrieval_qa = PebbloRetrievalQA.from_chain_type(
|
||||
llm=FakeLLM(), chain_type="stuff", retriever=retriever
|
||||
llm=FakeLLM(),
|
||||
chain_type="stuff",
|
||||
retriever=retriever,
|
||||
owner="owner",
|
||||
description="description",
|
||||
app_name="app_name",
|
||||
)
|
||||
|
||||
return pebblo_retrieval_qa
|
||||
@ -114,6 +119,9 @@ def test_validate_vectorstore(
|
||||
llm=FakeLLM(),
|
||||
chain_type="stuff",
|
||||
retriever=retriever,
|
||||
owner="owner",
|
||||
description="description",
|
||||
app_name="app_name",
|
||||
)
|
||||
|
||||
# validate_vectorstore method should raise a ValueError for unsupported vectorstores
|
||||
@ -122,6 +130,9 @@ def test_validate_vectorstore(
|
||||
llm=FakeLLM(),
|
||||
chain_type="stuff",
|
||||
retriever=unsupported_retriever,
|
||||
owner="owner",
|
||||
description="description",
|
||||
app_name="app_name",
|
||||
)
|
||||
assert (
|
||||
"Vectorstore must be an instance of one of the supported vectorstores"
|
||||
|
Loading…
Reference in New Issue
Block a user