mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-09 21:08:59 +00:00
community[minor]: Enable retrieval api calls in PebbloRetrievalQA (#21958)
Description: Enable app discovery and Prompt/Response apis in PebbloSafeRetrieval Documentation: NA Unit test: N/A --------- Signed-off-by: Rahul Tripathi <rauhl.psit.ec@gmail.com> Co-authored-by: Rahul Tripathi <rauhl.psit.ec@gmail.com>
This commit is contained in:
parent
8fd231086e
commit
77ad857934
@ -3,9 +3,13 @@ Pebblo Retrieval Chain with Identity & Semantic Enforcement for question-answeri
|
|||||||
against a vector database.
|
against a vector database.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import datetime
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
|
from http import HTTPStatus
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import requests # type: ignore
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||||
from langchain_core.callbacks import (
|
from langchain_core.callbacks import (
|
||||||
@ -22,9 +26,21 @@ from langchain_community.chains.pebblo_retrieval.enforcement_filters import (
|
|||||||
set_enforcement_filters,
|
set_enforcement_filters,
|
||||||
)
|
)
|
||||||
from langchain_community.chains.pebblo_retrieval.models import (
|
from langchain_community.chains.pebblo_retrieval.models import (
|
||||||
|
App,
|
||||||
AuthContext,
|
AuthContext,
|
||||||
|
Qa,
|
||||||
SemanticContext,
|
SemanticContext,
|
||||||
)
|
)
|
||||||
|
from langchain_community.chains.pebblo_retrieval.utilities import (
|
||||||
|
APP_DISCOVER_URL,
|
||||||
|
CLASSIFIER_URL,
|
||||||
|
PEBBLO_CLOUD_URL,
|
||||||
|
PLUGIN_VERSION,
|
||||||
|
PROMPT_URL,
|
||||||
|
get_runtime,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class PebbloRetrievalQA(Chain):
|
class PebbloRetrievalQA(Chain):
|
||||||
@ -46,6 +62,20 @@ class PebbloRetrievalQA(Chain):
|
|||||||
"""Authentication context for identity enforcement."""
|
"""Authentication context for identity enforcement."""
|
||||||
semantic_context_key: str = "semantic_context" #: :meta private:
|
semantic_context_key: str = "semantic_context" #: :meta private:
|
||||||
"""Semantic context for semantic enforcement."""
|
"""Semantic context for semantic enforcement."""
|
||||||
|
app_name: str #: :meta private:
|
||||||
|
"""App name."""
|
||||||
|
owner: str #: :meta private:
|
||||||
|
"""Owner of app."""
|
||||||
|
description: str #: :meta private:
|
||||||
|
"""Description of app."""
|
||||||
|
api_key: Optional[str] = None #: :meta private:
|
||||||
|
"""Pebblo cloud API key for app."""
|
||||||
|
classifier_url: str = CLASSIFIER_URL #: :meta private:
|
||||||
|
"""Classifier endpoint."""
|
||||||
|
_discover_sent = False #: :meta private:
|
||||||
|
"""Flag to check if discover payload has been sent."""
|
||||||
|
_prompt_sent = False #: :meta private:
|
||||||
|
"""Flag to check if prompt payload has been sent."""
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
@ -63,10 +93,11 @@ class PebbloRetrievalQA(Chain):
|
|||||||
res = indexqa({'query': 'This is my query'})
|
res = indexqa({'query': 'This is my query'})
|
||||||
answer, docs = res['result'], res['source_documents']
|
answer, docs = res['result'], res['source_documents']
|
||||||
"""
|
"""
|
||||||
|
prompt_time = datetime.datetime.now().isoformat()
|
||||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||||
question = inputs[self.input_key]
|
question = inputs[self.input_key]
|
||||||
auth_context = inputs.get(self.auth_context_key)
|
auth_context = inputs.get(self.auth_context_key, {})
|
||||||
semantic_context = inputs.get(self.semantic_context_key)
|
semantic_context = inputs.get(self.semantic_context_key, {})
|
||||||
accepts_run_manager = (
|
accepts_run_manager = (
|
||||||
"run_manager" in inspect.signature(self._get_docs).parameters
|
"run_manager" in inspect.signature(self._get_docs).parameters
|
||||||
)
|
)
|
||||||
@ -80,6 +111,32 @@ class PebbloRetrievalQA(Chain):
|
|||||||
input_documents=docs, question=question, callbacks=_run_manager.get_child()
|
input_documents=docs, question=question, callbacks=_run_manager.get_child()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
qa = {
|
||||||
|
"name": self.app_name,
|
||||||
|
"context": [
|
||||||
|
{
|
||||||
|
"retrieved_from": doc.metadata.get("source"),
|
||||||
|
"doc": doc.page_content,
|
||||||
|
"vector_db": self.retriever.vectorstore.__class__.__name__,
|
||||||
|
}
|
||||||
|
for doc in docs
|
||||||
|
if isinstance(doc, Document)
|
||||||
|
],
|
||||||
|
"prompt": {"data": question},
|
||||||
|
"response": {
|
||||||
|
"data": answer,
|
||||||
|
},
|
||||||
|
"prompt_time": prompt_time,
|
||||||
|
"user": auth_context.user_id if auth_context else "unknown",
|
||||||
|
"user_identities": auth_context.user_auth
|
||||||
|
if "user_auth" in dict(auth_context)
|
||||||
|
else []
|
||||||
|
if auth_context
|
||||||
|
else [],
|
||||||
|
}
|
||||||
|
qa_payload = Qa(**qa)
|
||||||
|
self._send_prompt(qa_payload)
|
||||||
|
|
||||||
if self.return_source_documents:
|
if self.return_source_documents:
|
||||||
return {self.output_key: answer, "source_documents": docs}
|
return {self.output_key: answer, "source_documents": docs}
|
||||||
else:
|
else:
|
||||||
@ -158,8 +215,13 @@ class PebbloRetrievalQA(Chain):
|
|||||||
def from_chain_type(
|
def from_chain_type(
|
||||||
cls,
|
cls,
|
||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
|
app_name: str,
|
||||||
|
description: str,
|
||||||
|
owner: str,
|
||||||
chain_type: str = "stuff",
|
chain_type: str = "stuff",
|
||||||
chain_type_kwargs: Optional[dict] = None,
|
chain_type_kwargs: Optional[dict] = None,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
classifier_url: str = CLASSIFIER_URL,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> "PebbloRetrievalQA":
|
) -> "PebbloRetrievalQA":
|
||||||
"""Load chain from chain type."""
|
"""Load chain from chain type."""
|
||||||
@ -169,7 +231,29 @@ class PebbloRetrievalQA(Chain):
|
|||||||
combine_documents_chain = load_qa_chain(
|
combine_documents_chain = load_qa_chain(
|
||||||
llm, chain_type=chain_type, **_chain_type_kwargs
|
llm, chain_type=chain_type, **_chain_type_kwargs
|
||||||
)
|
)
|
||||||
return cls(combine_documents_chain=combine_documents_chain, **kwargs)
|
|
||||||
|
# generate app
|
||||||
|
app = PebbloRetrievalQA._get_app_details(
|
||||||
|
app_name=app_name,
|
||||||
|
description=description,
|
||||||
|
owner=owner,
|
||||||
|
llm=llm,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
PebbloRetrievalQA._send_discover(
|
||||||
|
app, api_key=api_key, classifier_url=classifier_url
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
combine_documents_chain=combine_documents_chain,
|
||||||
|
app_name=app_name,
|
||||||
|
owner=owner,
|
||||||
|
description=description,
|
||||||
|
api_key=api_key,
|
||||||
|
classifier_url=classifier_url,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
@validator("retriever", pre=True, always=True)
|
@validator("retriever", pre=True, always=True)
|
||||||
def validate_vectorstore(
|
def validate_vectorstore(
|
||||||
@ -216,3 +300,182 @@ class PebbloRetrievalQA(Chain):
|
|||||||
return await self.retriever.aget_relevant_documents(
|
return await self.retriever.aget_relevant_documents(
|
||||||
question, callbacks=run_manager.get_child()
|
question, callbacks=run_manager.get_child()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_app_details(app_name, owner, description, llm, **kwargs) -> App: # type: ignore
|
||||||
|
"""Fetch app details. Internal method.
|
||||||
|
Returns:
|
||||||
|
App: App details.
|
||||||
|
"""
|
||||||
|
framework, runtime = get_runtime()
|
||||||
|
chains = PebbloRetrievalQA.get_chain_details(llm, **kwargs)
|
||||||
|
app = App(
|
||||||
|
name=app_name,
|
||||||
|
owner=owner,
|
||||||
|
description=description,
|
||||||
|
runtime=runtime,
|
||||||
|
framework=framework,
|
||||||
|
chains=chains,
|
||||||
|
plugin_version=PLUGIN_VERSION,
|
||||||
|
)
|
||||||
|
return app
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _send_discover(app, api_key, classifier_url) -> None: # type: ignore
|
||||||
|
"""Send app discovery payload to pebblo-server. Internal method."""
|
||||||
|
headers = {
|
||||||
|
"Accept": "application/json",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
payload = app.dict(exclude_unset=True)
|
||||||
|
app_discover_url = f"{classifier_url}{APP_DISCOVER_URL}"
|
||||||
|
try:
|
||||||
|
pebblo_resp = requests.post(
|
||||||
|
app_discover_url, headers=headers, json=payload, timeout=20
|
||||||
|
)
|
||||||
|
logger.debug("discover-payload: %s", payload)
|
||||||
|
logger.debug(
|
||||||
|
"send_discover[local]: request url %s, body %s len %s\
|
||||||
|
response status %s body %s",
|
||||||
|
pebblo_resp.request.url,
|
||||||
|
str(pebblo_resp.request.body),
|
||||||
|
str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])),
|
||||||
|
str(pebblo_resp.status_code),
|
||||||
|
pebblo_resp.json(),
|
||||||
|
)
|
||||||
|
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
|
||||||
|
PebbloRetrievalQA.set_discover_sent()
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Received unexpected HTTP response code: {pebblo_resp.status_code}"
|
||||||
|
)
|
||||||
|
except requests.exceptions.RequestException:
|
||||||
|
logger.warning("Unable to reach pebblo server.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("An Exception caught in _send_discover: local %s", e)
|
||||||
|
|
||||||
|
if api_key:
|
||||||
|
try:
|
||||||
|
headers.update({"x-api-key": api_key})
|
||||||
|
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{APP_DISCOVER_URL}"
|
||||||
|
pebblo_cloud_response = requests.post(
|
||||||
|
pebblo_cloud_url, headers=headers, json=payload, timeout=20
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"send_discover[cloud]: request url %s, body %s len %s\
|
||||||
|
response status %s body %s",
|
||||||
|
pebblo_cloud_response.request.url,
|
||||||
|
str(pebblo_cloud_response.request.body),
|
||||||
|
str(
|
||||||
|
len(
|
||||||
|
pebblo_cloud_response.request.body
|
||||||
|
if pebblo_cloud_response.request.body
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
),
|
||||||
|
str(pebblo_cloud_response.status_code),
|
||||||
|
pebblo_cloud_response.json(),
|
||||||
|
)
|
||||||
|
except requests.exceptions.RequestException:
|
||||||
|
logger.warning("Unable to reach Pebblo cloud server.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("An Exception caught in _send_discover: cloud %s", e)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_discover_sent(cls) -> None:
|
||||||
|
cls._discover_sent = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_prompt_sent(cls) -> None:
|
||||||
|
cls._prompt_sent = True
|
||||||
|
|
||||||
|
def _send_prompt(self, qa_payload: Qa) -> None:
|
||||||
|
headers = {
|
||||||
|
"Accept": "application/json",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
app_discover_url = f"{self.classifier_url}{PROMPT_URL}"
|
||||||
|
try:
|
||||||
|
pebblo_resp = requests.post(
|
||||||
|
app_discover_url, headers=headers, json=qa_payload.dict(), timeout=20
|
||||||
|
)
|
||||||
|
logger.debug("prompt-payload: %s", qa_payload)
|
||||||
|
logger.debug(
|
||||||
|
"send_prompt[local]: request url %s, body %s len %s\
|
||||||
|
response status %s body %s",
|
||||||
|
pebblo_resp.request.url,
|
||||||
|
str(pebblo_resp.request.body),
|
||||||
|
str(len(pebblo_resp.request.body if pebblo_resp.request.body else [])),
|
||||||
|
str(pebblo_resp.status_code),
|
||||||
|
pebblo_resp.json(),
|
||||||
|
)
|
||||||
|
if pebblo_resp.status_code in [HTTPStatus.OK, HTTPStatus.BAD_GATEWAY]:
|
||||||
|
PebbloRetrievalQA.set_prompt_sent()
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"Received unexpected HTTP response code: {pebblo_resp.status_code}"
|
||||||
|
)
|
||||||
|
except requests.exceptions.RequestException:
|
||||||
|
logger.warning("Unable to reach pebblo server.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("An Exception caught in _send_discover: local %s", e)
|
||||||
|
|
||||||
|
if self.api_key:
|
||||||
|
try:
|
||||||
|
headers.update({"x-api-key": self.api_key})
|
||||||
|
pebblo_cloud_url = f"{PEBBLO_CLOUD_URL}{PROMPT_URL}"
|
||||||
|
pebblo_cloud_response = requests.post(
|
||||||
|
pebblo_cloud_url,
|
||||||
|
headers=headers,
|
||||||
|
json=qa_payload.dict(),
|
||||||
|
timeout=20,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"send_prompt[cloud]: request url %s, body %s len %s\
|
||||||
|
response status %s body %s",
|
||||||
|
pebblo_cloud_response.request.url,
|
||||||
|
str(pebblo_cloud_response.request.body),
|
||||||
|
str(
|
||||||
|
len(
|
||||||
|
pebblo_cloud_response.request.body
|
||||||
|
if pebblo_cloud_response.request.body
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
),
|
||||||
|
str(pebblo_cloud_response.status_code),
|
||||||
|
pebblo_cloud_response.json(),
|
||||||
|
)
|
||||||
|
except requests.exceptions.RequestException:
|
||||||
|
logger.warning("Unable to reach Pebblo cloud server.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("An Exception caught in _send_prompt: cloud %s", e)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_chain_details(cls, llm, **kwargs): # type: ignore
|
||||||
|
llm_dict = llm.__dict__
|
||||||
|
chain = [
|
||||||
|
{
|
||||||
|
"name": cls.__name__,
|
||||||
|
"model": {
|
||||||
|
"name": llm_dict.get("model_name", llm_dict.get("model")),
|
||||||
|
"vendor": llm.__class__.__name__,
|
||||||
|
},
|
||||||
|
"vector_dbs": [
|
||||||
|
{
|
||||||
|
"name": kwargs["retriever"].vectorstore.__class__.__name__,
|
||||||
|
"embedding_model": str(
|
||||||
|
kwargs["retriever"].vectorstore._embeddings.model
|
||||||
|
)
|
||||||
|
if hasattr(kwargs["retriever"].vectorstore, "_embeddings")
|
||||||
|
else (
|
||||||
|
str(kwargs["retriever"].vectorstore._embedding.model)
|
||||||
|
if hasattr(kwargs["retriever"].vectorstore, "_embedding")
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
return chain
|
||||||
|
@ -60,3 +60,86 @@ class ChainInput(BaseModel):
|
|||||||
base_dict["auth_context"] = self.auth_context
|
base_dict["auth_context"] = self.auth_context
|
||||||
base_dict["semantic_context"] = self.semantic_context
|
base_dict["semantic_context"] = self.semantic_context
|
||||||
return base_dict
|
return base_dict
|
||||||
|
|
||||||
|
|
||||||
|
class Runtime(BaseModel):
|
||||||
|
"""
|
||||||
|
OS, language details
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: Optional[str] = ""
|
||||||
|
host: str
|
||||||
|
path: str
|
||||||
|
ip: Optional[str] = ""
|
||||||
|
platform: str
|
||||||
|
os: str
|
||||||
|
os_version: str
|
||||||
|
language: str
|
||||||
|
language_version: str
|
||||||
|
runtime: Optional[str] = ""
|
||||||
|
|
||||||
|
|
||||||
|
class Framework(BaseModel):
|
||||||
|
"""
|
||||||
|
Langchain framework details
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
version: str
|
||||||
|
|
||||||
|
|
||||||
|
class Model(BaseModel):
|
||||||
|
vendor: Optional[str]
|
||||||
|
name: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class PkgInfo(BaseModel):
|
||||||
|
project_home_page: Optional[str]
|
||||||
|
documentation_url: Optional[str]
|
||||||
|
pypi_url: Optional[str]
|
||||||
|
liscence_type: Optional[str]
|
||||||
|
installed_via: Optional[str]
|
||||||
|
location: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class VectorDB(BaseModel):
|
||||||
|
name: Optional[str] = None
|
||||||
|
version: Optional[str] = None
|
||||||
|
location: Optional[str] = None
|
||||||
|
embedding_model: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Chains(BaseModel):
|
||||||
|
name: str
|
||||||
|
model: Optional[Model]
|
||||||
|
vector_dbs: Optional[List[VectorDB]]
|
||||||
|
|
||||||
|
|
||||||
|
class App(BaseModel):
|
||||||
|
name: str
|
||||||
|
owner: str
|
||||||
|
description: Optional[str]
|
||||||
|
runtime: Runtime
|
||||||
|
framework: Framework
|
||||||
|
chains: List[Chains]
|
||||||
|
plugin_version: str
|
||||||
|
|
||||||
|
|
||||||
|
class Context(BaseModel):
|
||||||
|
retrieved_from: Optional[str]
|
||||||
|
doc: Optional[str]
|
||||||
|
vector_db: str
|
||||||
|
|
||||||
|
|
||||||
|
class Prompt(BaseModel):
|
||||||
|
data: str
|
||||||
|
|
||||||
|
|
||||||
|
class Qa(BaseModel):
|
||||||
|
name: str
|
||||||
|
context: List[Optional[Context]]
|
||||||
|
prompt: Prompt
|
||||||
|
response: Prompt
|
||||||
|
prompt_time: str
|
||||||
|
user: str
|
||||||
|
user_identities: Optional[List[str]]
|
||||||
|
@ -0,0 +1,65 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from langchain_core.env import get_runtime_environment
|
||||||
|
|
||||||
|
from langchain_community.chains.pebblo_retrieval.models import Framework, Runtime
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
PLUGIN_VERSION = "0.1.1"
|
||||||
|
|
||||||
|
CLASSIFIER_URL = os.getenv("PEBBLO_CLASSIFIER_URL", "http://localhost:8000")
|
||||||
|
PEBBLO_CLOUD_URL = os.getenv("PEBBLO_CLOUD_URL", "https://api.daxa.ai")
|
||||||
|
|
||||||
|
PROMPT_URL = "/v1/prompt"
|
||||||
|
APP_DISCOVER_URL = "/v1/app/discover"
|
||||||
|
|
||||||
|
|
||||||
|
def get_runtime() -> Tuple[Framework, Runtime]:
|
||||||
|
"""Fetch the current Framework and Runtime details.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[Framework, Runtime]: Framework and Runtime for the current app instance.
|
||||||
|
"""
|
||||||
|
runtime_env = get_runtime_environment()
|
||||||
|
framework = Framework(
|
||||||
|
name="langchain", version=runtime_env.get("library_version", None)
|
||||||
|
)
|
||||||
|
uname = platform.uname()
|
||||||
|
runtime = Runtime(
|
||||||
|
host=uname.node,
|
||||||
|
path=os.environ["PWD"],
|
||||||
|
platform=runtime_env.get("platform", "unknown"),
|
||||||
|
os=uname.system,
|
||||||
|
os_version=uname.version,
|
||||||
|
ip=get_ip(),
|
||||||
|
language=runtime_env.get("runtime", "unknown"),
|
||||||
|
language_version=runtime_env.get("runtime_version", "unknown"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if "Darwin" in runtime.os:
|
||||||
|
runtime.type = "desktop"
|
||||||
|
runtime.runtime = "Mac OSX"
|
||||||
|
|
||||||
|
logger.debug(f"framework {framework}")
|
||||||
|
logger.debug(f"runtime {runtime}")
|
||||||
|
return framework, runtime
|
||||||
|
|
||||||
|
|
||||||
|
def get_ip() -> str:
|
||||||
|
"""Fetch local runtime ip address.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: IP address
|
||||||
|
"""
|
||||||
|
import socket # lazy imports
|
||||||
|
|
||||||
|
host = socket.gethostname()
|
||||||
|
try:
|
||||||
|
public_ip = socket.gethostbyname(host)
|
||||||
|
except Exception:
|
||||||
|
public_ip = socket.gethostbyname("localhost")
|
||||||
|
return public_ip
|
@ -72,7 +72,12 @@ def pebblo_retrieval_qa(retriever: FakeRetriever) -> PebbloRetrievalQA:
|
|||||||
Create a PebbloRetrievalQA instance
|
Create a PebbloRetrievalQA instance
|
||||||
"""
|
"""
|
||||||
pebblo_retrieval_qa = PebbloRetrievalQA.from_chain_type(
|
pebblo_retrieval_qa = PebbloRetrievalQA.from_chain_type(
|
||||||
llm=FakeLLM(), chain_type="stuff", retriever=retriever
|
llm=FakeLLM(),
|
||||||
|
chain_type="stuff",
|
||||||
|
retriever=retriever,
|
||||||
|
owner="owner",
|
||||||
|
description="description",
|
||||||
|
app_name="app_name",
|
||||||
)
|
)
|
||||||
|
|
||||||
return pebblo_retrieval_qa
|
return pebblo_retrieval_qa
|
||||||
@ -114,6 +119,9 @@ def test_validate_vectorstore(
|
|||||||
llm=FakeLLM(),
|
llm=FakeLLM(),
|
||||||
chain_type="stuff",
|
chain_type="stuff",
|
||||||
retriever=retriever,
|
retriever=retriever,
|
||||||
|
owner="owner",
|
||||||
|
description="description",
|
||||||
|
app_name="app_name",
|
||||||
)
|
)
|
||||||
|
|
||||||
# validate_vectorstore method should raise a ValueError for unsupported vectorstores
|
# validate_vectorstore method should raise a ValueError for unsupported vectorstores
|
||||||
@ -122,6 +130,9 @@ def test_validate_vectorstore(
|
|||||||
llm=FakeLLM(),
|
llm=FakeLLM(),
|
||||||
chain_type="stuff",
|
chain_type="stuff",
|
||||||
retriever=unsupported_retriever,
|
retriever=unsupported_retriever,
|
||||||
|
owner="owner",
|
||||||
|
description="description",
|
||||||
|
app_name="app_name",
|
||||||
)
|
)
|
||||||
assert (
|
assert (
|
||||||
"Vectorstore must be an instance of one of the supported vectorstores"
|
"Vectorstore must be an instance of one of the supported vectorstores"
|
||||||
|
Loading…
Reference in New Issue
Block a user