diff --git a/libs/community/langchain_community/chains/pebblo_retrieval/base.py b/libs/community/langchain_community/chains/pebblo_retrieval/base.py index a87db20281d..e5ecb0f5254 100644 --- a/libs/community/langchain_community/chains/pebblo_retrieval/base.py +++ b/libs/community/langchain_community/chains/pebblo_retrieval/base.py @@ -3,9 +3,13 @@ Pebblo Retrieval Chain with Identity & Semantic Enforcement for question-answeri against a vector database. """ +import datetime import inspect +import logging +from http import HTTPStatus from typing import Any, Dict, List, Optional +import requests # type: ignore from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain_core.callbacks import ( @@ -22,9 +26,21 @@ from langchain_community.chains.pebblo_retrieval.enforcement_filters import ( set_enforcement_filters, ) from langchain_community.chains.pebblo_retrieval.models import ( + App, AuthContext, + Qa, SemanticContext, ) +from langchain_community.chains.pebblo_retrieval.utilities import ( + APP_DISCOVER_URL, + CLASSIFIER_URL, + PEBBLO_CLOUD_URL, + PLUGIN_VERSION, + PROMPT_URL, + get_runtime, +) + +logger = logging.getLogger(__name__) class PebbloRetrievalQA(Chain): @@ -46,6 +62,20 @@ class PebbloRetrievalQA(Chain): """Authentication context for identity enforcement.""" semantic_context_key: str = "semantic_context" #: :meta private: """Semantic context for semantic enforcement.""" + app_name: str #: :meta private: + """App name.""" + owner: str #: :meta private: + """Owner of app.""" + description: str #: :meta private: + """Description of app.""" + api_key: Optional[str] = None #: :meta private: + """Pebblo cloud API key for app.""" + classifier_url: str = CLASSIFIER_URL #: :meta private: + """Classifier endpoint.""" + _discover_sent = False #: :meta private: + """Flag to check if discover payload has been sent.""" + _prompt_sent = False #: :meta private: + """Flag to check if prompt payload 
has been sent.""" def _call( self, @@ -63,10 +93,11 @@ class PebbloRetrievalQA(Chain): res = indexqa({'query': 'This is my query'}) answer, docs = res['result'], res['source_documents'] """ + prompt_time = datetime.datetime.now().isoformat() _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() question = inputs[self.input_key] - auth_context = inputs.get(self.auth_context_key) - semantic_context = inputs.get(self.semantic_context_key) + auth_context = inputs.get(self.auth_context_key, {}) + semantic_context = inputs.get(self.semantic_context_key, {}) accepts_run_manager = ( "run_manager" in inspect.signature(self._get_docs).parameters ) @@ -80,6 +111,32 @@ class PebbloRetrievalQA(Chain): input_documents=docs, question=question, callbacks=_run_manager.get_child() ) + qa = { + "name": self.app_name, + "context": [ + { + "retrieved_from": doc.metadata.get("source"), + "doc": doc.page_content, + "vector_db": self.retriever.vectorstore.__class__.__name__, + } + for doc in docs + if isinstance(doc, Document) + ], + "prompt": {"data": question}, + "response": { + "data": answer, + }, + "prompt_time": prompt_time, + "user": auth_context.user_id if auth_context else "unknown", + "user_identities": auth_context.user_auth + if "user_auth" in dict(auth_context) + else [] + if auth_context + else [], + } + qa_payload = Qa(**qa) + self._send_prompt(qa_payload) + if self.return_source_documents: return {self.output_key: answer, "source_documents": docs} else: @@ -158,8 +215,13 @@ class PebbloRetrievalQA(Chain): def from_chain_type( cls, llm: BaseLanguageModel, + app_name: str, + description: str, + owner: str, chain_type: str = "stuff", chain_type_kwargs: Optional[dict] = None, + api_key: Optional[str] = None, + classifier_url: str = CLASSIFIER_URL, **kwargs: Any, ) -> "PebbloRetrievalQA": """Load chain from chain type.""" @@ -169,7 +231,29 @@ class PebbloRetrievalQA(Chain): combine_documents_chain = load_qa_chain( llm, chain_type=chain_type, 
# Telemetry helpers of PebbloRetrievalQA: app discovery, prompt reporting and
# chain introspection.


@staticmethod
def _get_app_details(
    app_name: str, owner: str, description: str, llm: Any, **kwargs: Any
) -> App:
    """Build the ``App`` discovery record for this chain. Internal method.

    Args:
        app_name: Name of the application.
        owner: Owner of the application.
        description: Description of the application.
        llm: Language model the chain is built with.
        **kwargs: Remaining chain arguments, forwarded to
            ``get_chain_details`` (which reads ``kwargs["retriever"]``).

    Returns:
        App: App details.
    """
    framework, runtime = get_runtime()
    chains = PebbloRetrievalQA.get_chain_details(llm, **kwargs)
    return App(
        name=app_name,
        owner=owner,
        description=description,
        runtime=runtime,
        framework=framework,
        chains=chains,
        plugin_version=PLUGIN_VERSION,
    )


@staticmethod
def _log_exchange(tag: str, response: Any) -> None:
    """Debug-log one HTTP request/response pair without raising.

    The raw response text is logged rather than ``response.json()``: decoding
    a non-JSON body raises, and when that happened inside the logging call the
    caller never reached its status-code check.

    Args:
        tag: Prefix identifying the call site, e.g. ``"send_prompt[local]"``.
        response: Response object returned by ``requests.post``.
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return
    body = response.request.body
    logger.debug(
        "%s: request url %s, body %s len %s response status %s body %s",
        tag,
        response.request.url,
        body,
        len(body) if body else 0,
        response.status_code,
        response.text,
    )


@staticmethod
def _send_discover(app: App, api_key: Optional[str], classifier_url: str) -> None:
    """Send app discovery payload to pebblo-server. Internal method.

    The payload is posted to ``classifier_url`` and, when ``api_key`` is set,
    also to the Pebblo cloud endpoint. Failures are logged as warnings and
    never propagated.

    Args:
        app: Discovery record to send.
        api_key: Pebblo cloud API key; the cloud call is skipped when falsy.
        classifier_url: Base URL of the classifier server.
    """
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    payload = app.dict(exclude_unset=True)
    app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
    try:
        pebblo_resp = requests.post(
            app_discover_url, headers=headers, json=payload, timeout=20
        )
        logger.debug("discover-payload: %s", payload)
        PebbloRetrievalQA._log_exchange("send_discover[local]", pebblo_resp)
        # NOTE(review): 502 is deliberately counted as "sent" here, as in the
        # original code - confirm this matches the server contract.
        if pebblo_resp.status_code in (HTTPStatus.OK, HTTPStatus.BAD_GATEWAY):
            PebbloRetrievalQA.set_discover_sent()
        else:
            logger.warning(
                "Received unexpected HTTP response code: %s",
                pebblo_resp.status_code,
            )
    except requests.exceptions.RequestException:
        logger.warning("Unable to reach pebblo server.")
    except Exception as e:
        logger.warning("An Exception caught in _send_discover: local %s", e)

    if api_key:
        try:
            # Build a separate dict so the API key never leaks into `headers`.
            cloud_headers = {**headers, "x-api-key": api_key}
            pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{APP_DISCOVER_URL}"
            pebblo_cloud_response = requests.post(
                pebblo_cloud_url, headers=cloud_headers, json=payload, timeout=20
            )
            PebbloRetrievalQA._log_exchange(
                "send_discover[cloud]", pebblo_cloud_response
            )
        except requests.exceptions.RequestException:
            logger.warning("Unable to reach Pebblo cloud server.")
        except Exception as e:
            logger.warning("An Exception caught in _send_discover: cloud %s", e)


@classmethod
def set_discover_sent(cls) -> None:
    """Record on the class that the discover payload has been sent."""
    cls._discover_sent = True


@classmethod
def set_prompt_sent(cls) -> None:
    """Record on the class that a prompt payload has been sent."""
    cls._prompt_sent = True


def _send_prompt(self, qa_payload: Qa) -> None:
    """Send one prompt/response record to pebblo-server. Internal method.

    The record is posted to ``self.classifier_url`` and, when ``self.api_key``
    is set, also to the Pebblo cloud endpoint. Failures are logged as warnings
    and never propagated.

    Args:
        qa_payload: Prompt, response and retrieved context to report.
    """
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
    }
    payload = qa_payload.dict()
    prompt_url = f"{self.classifier_url}{PROMPT_URL}"
    try:
        pebblo_resp = requests.post(
            prompt_url, headers=headers, json=payload, timeout=20
        )
        logger.debug("prompt-payload: %s", qa_payload)
        PebbloRetrievalQA._log_exchange("send_prompt[local]", pebblo_resp)
        if pebblo_resp.status_code in (HTTPStatus.OK, HTTPStatus.BAD_GATEWAY):
            PebbloRetrievalQA.set_prompt_sent()
        else:
            logger.warning(
                "Received unexpected HTTP response code: %s",
                pebblo_resp.status_code,
            )
    except requests.exceptions.RequestException:
        logger.warning("Unable to reach pebblo server.")
    except Exception as e:
        # Previously this message named _send_discover, pointing readers of
        # the log at the wrong method.
        logger.warning("An Exception caught in _send_prompt: local %s", e)

    if self.api_key:
        try:
            cloud_headers = {**headers, "x-api-key": self.api_key}
            pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
            pebblo_cloud_response = requests.post(
                pebblo_cloud_url,
                headers=cloud_headers,
                json=payload,
                timeout=20,
            )
            PebbloRetrievalQA._log_exchange(
                "send_prompt[cloud]", pebblo_cloud_response
            )
        except requests.exceptions.RequestException:
            logger.warning("Unable to reach Pebblo cloud server.")
        except Exception as e:
            logger.warning("An Exception caught in _send_prompt: cloud %s", e)


@classmethod
def get_chain_details(cls, llm: Any, **kwargs: Any) -> List[Dict[str, Any]]:
    """Describe this chain, its model and its vector store.

    Args:
        llm: Language model; its name is read from ``model_name`` (or
            ``model``) in ``llm.__dict__`` and its vendor from its class name.
        **kwargs: Must contain ``retriever``, an object exposing
            ``vectorstore``.

    Returns:
        A one-element list with the chain name, model details and vector
        store details. ``embedding_model`` is ``None`` when the vector store
        has neither ``_embeddings`` nor ``_embedding``, or when that object
        has no ``model`` attribute.
    """
    llm_dict = llm.__dict__
    vectorstore = kwargs["retriever"].vectorstore

    # A sentinel keeps "attribute absent" distinct from a `model` that is None.
    missing = object()
    model = missing
    for attr in ("_embeddings", "_embedding"):
        if hasattr(vectorstore, attr):
            model = getattr(getattr(vectorstore, attr), "model", missing)
            break
    embedding_model = None if model is missing else str(model)

    return [
        {
            "name": cls.__name__,
            "model": {
                "name": llm_dict.get("model_name", llm_dict.get("model")),
                "vendor": llm.__class__.__name__,
            },
            "vector_dbs": [
                {
                    "name": vectorstore.__class__.__name__,
                    "embedding_model": embedding_model,
                }
            ],
        }
    ]
class Runtime(BaseModel):
    """
    OS, language details
    """

    type: Optional[str] = ""
    host: str
    path: str
    ip: Optional[str] = ""
    platform: str
    os: str
    os_version: str
    language: str
    language_version: str
    runtime: Optional[str] = ""


class Framework(BaseModel):
    """
    Langchain framework details
    """

    name: str
    version: str


class Model(BaseModel):
    """Language model used by a chain: vendor and model name."""

    vendor: Optional[str] = None
    name: Optional[str] = None


class PkgInfo(BaseModel):
    """Package metadata fields; all optional."""

    project_home_page: Optional[str] = None
    documentation_url: Optional[str] = None
    pypi_url: Optional[str] = None
    # NOTE(review): "liscence" is misspelled but is the serialized key;
    # presumably it must match the receiving schema - confirm before renaming.
    liscence_type: Optional[str] = None
    installed_via: Optional[str] = None
    location: Optional[str] = None


class VectorDB(BaseModel):
    """Vector store used by a chain."""

    name: Optional[str] = None
    version: Optional[str] = None
    location: Optional[str] = None
    embedding_model: Optional[str] = None


class Chains(BaseModel):
    """One chain of the app, with its model and vector stores."""

    name: str
    model: Optional[Model] = None
    vector_dbs: Optional[List[VectorDB]] = None


class App(BaseModel):
    """App discovery record: identity, runtime, framework and chains."""

    name: str
    owner: str
    description: Optional[str] = None
    runtime: Runtime
    framework: Framework
    chains: List[Chains]
    plugin_version: str


class Context(BaseModel):
    """One retrieved document attached to a prompt record."""

    retrieved_from: Optional[str] = None
    doc: Optional[str] = None
    vector_db: str


class Prompt(BaseModel):
    """Text payload; used for both the prompt and the response of a ``Qa``."""

    data: str


class Qa(BaseModel):
    """Prompt record: question, answer, retrieved context and user details."""

    name: str
    context: List[Optional[Context]]
    prompt: Prompt
    response: Prompt
    prompt_time: str
    user: str
    user_identities: Optional[List[str]] = None
import logging
import os
import platform
from typing import Tuple

from langchain_core.env import get_runtime_environment

from langchain_community.chains.pebblo_retrieval.models import Framework, Runtime

logger = logging.getLogger(__name__)

PLUGIN_VERSION = "0.1.1"

# Both endpoints can be overridden through the environment.
CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000")
PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")

PROMPT_URL = "/v1/prompt"
APP_DISCOVER_URL = "/v1/app/discover"


def get_runtime() -> Tuple[Framework, Runtime]:
    """Fetch the current Framework and Runtime details.

    Returns:
        Tuple[Framework, Runtime]: Framework and Runtime for the current app instance.
    """
    runtime_env = get_runtime_environment()
    framework = Framework(
        # `version` is a required str, so fall back to a string like the
        # other fields instead of None.
        name="langchain",
        version=runtime_env.get("library_version", "unknown"),
    )
    uname = platform.uname()
    runtime = Runtime(
        host=uname.node,
        # PWD is not set in every environment (e.g. Windows); indexing
        # os.environ directly raised KeyError there.
        path=os.environ.get("PWD") or os.getcwd(),
        platform=runtime_env.get("platform", "unknown"),
        os=uname.system,
        os_version=uname.version,
        ip=get_ip(),
        language=runtime_env.get("runtime", "unknown"),
        language_version=runtime_env.get("runtime_version", "unknown"),
    )

    if "Darwin" in runtime.os:
        runtime.type = "desktop"
        runtime.runtime = "Mac OSX"

    logger.debug("framework %s", framework)
    logger.debug("runtime %s", runtime)
    return framework, runtime


def get_ip() -> str:
    """Fetch local runtime ip address.

    Resolution is best-effort: the host name is tried first, then
    ``localhost``, and finally the loopback address is returned as a literal.

    Returns:
        str: IP address
    """
    import socket  # lazy imports

    try:
        return socket.gethostbyname(socket.gethostname())
    except Exception:
        # Best-effort: any resolution failure falls back to localhost.
        logger.debug("Could not resolve the host name; falling back to localhost")
    try:
        return socket.gethostbyname("localhost")
    except OSError:
        # Even "localhost" may be unresolvable on a misconfigured host.
        return "127.0.0.1"