community[minor]: add support for Guardrails for Amazon Bedrock (#15099)
Added support for optionally supplying "Guardrails for Amazon Bedrock" on both types of model invocations (batch/regular and streaming) and for all models supported by the Amazon Bedrock service.

@baskaryan @hwchase17

```python
from typing import Any

from langchain_core.callbacks import AsyncCallbackHandler
from langchain_community.llms.bedrock import Bedrock


class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
    """Async callback handler that can be used to handle callbacks from langchain."""

    async def on_llm_error(self, error: BaseException, **kwargs: Any) -> Any:
        reason = kwargs.get("reason")
        if reason == "GUARDRAIL_INTERVENED":
            # kwargs contains additional trace information sent by the
            # 'Guardrails for Amazon Bedrock' service.
            print(f"Guardrails: {kwargs}")


llm = Bedrock(
    model_id="<model_id>",
    client=bedrock,  # a boto3 "bedrock-runtime" client
    model_kwargs={},
    guardrails={"id": "<guardrail_id>", "version": "<guardrail_version>", "trace": True},
    callbacks=[BedrockAsyncCallbackHandler()],
)

# streaming
llm = Bedrock(
    model_id="<model_id>",
    client=bedrock,
    model_kwargs={},
    streaming=True,
    guardrails={"id": "<guardrail_id>", "version": "<guardrail_version>"},
)
```

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent 04651f0248
commit a91181fe6d
```diff
@@ -106,6 +106,45 @@
     "\n",
     "conversation.predict(input=\"Hi there!\")"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Guardrails for Amazon Bedrock example \n",
+    "\n",
+    "In this section, we are going to set up a Bedrock language model with specific guardrails that include tracing capabilities. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import Any\n",
+    "\n",
+    "from langchain_core.callbacks import AsyncCallbackHandler\n",
+    "\n",
+    "\n",
+    "class BedrockAsyncCallbackHandler(AsyncCallbackHandler):\n",
+    "    # Async callback handler that can be used to handle callbacks from langchain.\n",
+    "\n",
+    "    async def on_llm_error(self, error: BaseException, **kwargs: Any) -> Any:\n",
+    "        reason = kwargs.get(\"reason\")\n",
+    "        if reason == \"GUARDRAIL_INTERVENED\":\n",
+    "            print(f\"Guardrails: {kwargs}\")\n",
+    "\n",
+    "\n",
+    "# guardrails for Amazon Bedrock with trace\n",
+    "llm = Bedrock(\n",
+    "    credentials_profile_name=\"bedrock-admin\",\n",
+    "    model_id=\"<Model_ID>\",\n",
+    "    model_kwargs={},\n",
+    "    guardrails={\"id\": \"<Guardrail_ID>\", \"version\": \"<Version>\", \"trace\": True},\n",
+    "    callbacks=[BedrockAsyncCallbackHandler()],\n",
+    ")"
+   ]
   }
  ],
  "metadata": {
```
```diff
@@ -34,6 +34,8 @@ from langchain_community.utilities.anthropic import (
 if TYPE_CHECKING:
     from botocore.config import Config
 
+AMAZON_BEDROCK_TRACE_KEY = "amazon-bedrock-trace"
+GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAssessment"
 HUMAN_PROMPT = "\n\nHuman:"
 ASSISTANT_PROMPT = "\n\nAssistant:"
 ALTERNATION_ERROR = (
```
```diff
@@ -117,21 +119,26 @@ class LLMInputOutputAdapter:
         return input_body
 
     @classmethod
-    def prepare_output(cls, provider: str, response: Any) -> str:
+    def prepare_output(cls, provider: str, response: Any) -> dict:
         if provider == "anthropic":
             response_body = json.loads(response.get("body").read().decode())
-            return response_body.get("completion")
+            text = response_body.get("completion")
         else:
             response_body = json.loads(response.get("body").read())
 
         if provider == "ai21":
-            return response_body.get("completions")[0].get("data").get("text")
+            text = response_body.get("completions")[0].get("data").get("text")
         elif provider == "cohere":
-            return response_body.get("generations")[0].get("text")
+            text = response_body.get("generations")[0].get("text")
         elif provider == "meta":
-            return response_body.get("generation")
+            text = response_body.get("generation")
         else:
-            return response_body.get("results")[0].get("outputText")
+            text = response_body.get("results")[0].get("outputText")
 
+        return {
+            "text": text,
+            "body": response_body,
+        }
+
     @classmethod
     def prepare_output_stream(
```
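A note on the changed return type above: `prepare_output` now returns a dict rather than a bare string, and the invocation path unpacks it with `.values()` (see the `_prepare_input_and_invoke` hunk below). A small sketch of the two equivalent consumption styles; the positional unpacking is safe only because dicts preserve insertion order ("text" before "body") on Python 3.7+:

```python
# Sketch: consuming the new prepare_output() return shape (assumes
# `provider` and `response` are already in scope, as in the caller).
output = LLMInputOutputAdapter.prepare_output(provider, response)

text, body = output.values()                 # relies on dict insertion order
text, body = output["text"], output["body"]  # order-independent equivalent
```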
```diff
@@ -160,8 +167,15 @@ class LLMInputOutputAdapter:
                     chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
                 ):
                     return
 
-                yield GenerationChunk(text=chunk_obj[output_key])
+                # chunk obj format varies with provider
+                yield GenerationChunk(
+                    text=chunk_obj[output_key],
+                    generation_info={
+                        GUARDRAILS_BODY_KEY: chunk_obj.get(GUARDRAILS_BODY_KEY)
+                        if GUARDRAILS_BODY_KEY in chunk_obj
+                        else None,
+                    },
+                )
 
     @classmethod
     async def aprepare_output_stream(
```
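With this change, each streamed chunk carries the guardrail assessment (when the service returns one) in `generation_info` under `GUARDRAILS_BODY_KEY`. A hedged sketch of inspecting it, using the private `_stream` (which yields `GenerationChunk`; the public `stream` of an LLM yields plain strings) and assuming `llm` is a streaming `Bedrock` instance configured with guardrails:

```python
# Sketch, assuming `llm` is a streaming Bedrock instance with guardrails set.
# GUARDRAILS_BODY_KEY == "amazon-bedrock-guardrailAssessment" (defined above).
for chunk in llm._stream("I want to talk about politics."):
    info = chunk.generation_info or {}
    if info.get("amazon-bedrock-guardrailAssessment") == "GUARDRAIL_INTERVENED":
        print("guardrail intervened on this chunk")
```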
```diff
@@ -235,6 +249,53 @@ class BedrockBase(BaseModel, ABC):
         "cohere": "stop_sequences",
     }
 
+    guardrails: Optional[Mapping[str, Any]] = {
+        "id": None,
+        "version": None,
+        "trace": False,
+    }
+    """
+    An optional dictionary to configure guardrails for Bedrock.
+
+    This field 'guardrails' consists of two keys: 'id' and 'version',
+    which should be strings, but are initialized to None. It's used to
+    determine if specific guardrails are enabled and properly set.
+
+    Type:
+        Optional[Mapping[str, str]]: A mapping with 'id' and 'version' keys.
+
+    Example:
+    llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
+                  model_kwargs={},
+                  guardrails={
+                        "id": "<guardrail_id>",
+                        "version": "<guardrail_version>"})
+
+    To enable tracing for guardrails, set the 'trace' key to True and pass a callback handler to the
+    'run_manager' parameter of the 'generate', '_call' methods.
+
+    Example:
+    llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
+                  model_kwargs={},
+                  guardrails={
+                        "id": "<guardrail_id>",
+                        "version": "<guardrail_version>",
+                        "trace": True},
+                  callbacks=[BedrockAsyncCallbackHandler()])
+
+    [https://python.langchain.com/docs/modules/callbacks/] for more information on callback handlers.
+
+    class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
+        async def on_llm_error(
+            self,
+            error: BaseException,
+            **kwargs: Any,
+        ) -> Any:
+            reason = kwargs.get("reason")
+            if reason == "GUARDRAIL_INTERVENED":
+                ...Logic to handle guardrail intervention...
+    """  # noqa: E501
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that AWS credentials to and python package exists in environment."""
```
```diff
@@ -298,6 +359,47 @@ class BedrockBase(BaseModel, ABC):
     def _model_is_anthropic(self) -> bool:
         return self._get_provider() == "anthropic"
 
+    @property
+    def _guardrails_enabled(self) -> bool:
+        """
+        Determines if guardrails are enabled and correctly configured.
+        Checks if 'guardrails' is a dictionary with non-empty 'id' and 'version' keys.
+        Checks if 'guardrails.trace' is true.
+
+        Returns:
+            bool: True if guardrails are correctly configured, False otherwise.
+        Raises:
+            TypeError: If 'guardrails' lacks 'id' or 'version' keys.
+        """
+        try:
+            return (
+                isinstance(self.guardrails, dict)
+                and bool(self.guardrails["id"])
+                and bool(self.guardrails["version"])
+            )
+
+        except KeyError as e:
+            raise TypeError(
+                "Guardrails must be a dictionary with 'id' and 'version' keys."
+            ) from e
+
+    def _get_guardrails_canonical(self) -> Dict[str, Any]:
+        """
+        The canonical way to pass in guardrails to the bedrock service
+        adheres to the following format:
+
+        "amazon-bedrock-guardrailDetails": {
+            "guardrailId": "string",
+            "guardrailVersion": "string"
+        }
+        """
+        return {
+            "amazon-bedrock-guardrailDetails": {
+                "guardrailId": self.guardrails.get("id"),
+                "guardrailVersion": self.guardrails.get("version"),
+            }
+        }
+
     def _prepare_input_and_invoke(
         self,
         prompt: str,
```
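To make the request shape concrete, here is a sketch of `params` just before `LLMInputOutputAdapter.prepare_input` builds the body, once `_get_guardrails_canonical()` has been merged in; `max_tokens_to_sample` is an illustrative Anthropic-style model kwarg, not part of this diff:

```python
# Sketch: params after _get_guardrails_canonical() is merged in.
params = {
    "max_tokens_to_sample": 256,  # illustrative model kwarg
    "amazon-bedrock-guardrailDetails": {
        "guardrailId": "<guardrail_id>",
        "guardrailVersion": "<guardrail_version>",
    },
}
```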
```diff
@@ -309,29 +411,81 @@ class BedrockBase(BaseModel, ABC):
 
         provider = self._get_provider()
         params = {**_model_kwargs, **kwargs}
+        if self._guardrails_enabled:
+            params.update(self._get_guardrails_canonical())
         input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
         body = json.dumps(input_body)
         accept = "application/json"
         contentType = "application/json"
 
+        request_options = {
+            "body": body,
+            "modelId": self.model_id,
+            "accept": accept,
+            "contentType": contentType,
+        }
+
+        if self._guardrails_enabled:
+            request_options["guardrail"] = "ENABLED"
+            if self.guardrails.get("trace"):
+                request_options["trace"] = "ENABLED"
+
         try:
-            response = self.client.invoke_model(
-                body=body,
-                modelId=self.model_id,
-                accept=accept,
-                contentType=contentType,
-            )
-            text = LLMInputOutputAdapter.prepare_output(provider, response)
+            response = self.client.invoke_model(**request_options)
+
+            text, body = LLMInputOutputAdapter.prepare_output(
+                provider, response
+            ).values()
+
         except Exception as e:
-            raise ValueError(f"Error raised by bedrock service: {e}").with_traceback(
-                e.__traceback__
-            )
+            raise ValueError(f"Error raised by bedrock service: {e}")
 
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
 
+        # Verify and raise a callback error if any intervention occurs or a signal is
+        # sent from a Bedrock service,
+        # such as when guardrails are triggered.
+        services_trace = self._get_bedrock_services_signal(body)
+
+        if services_trace.get("signal") and run_manager is not None:
+            run_manager.on_llm_error(
+                Exception(
+                    f"Error raised by bedrock service: {services_trace.get('reason')}"
+                ),
+                **services_trace,
+            )
+
         return text
 
+    def _get_bedrock_services_signal(self, body: dict) -> dict:
+        """
+        This function checks the response body for an interrupt flag or message that indicates
+        whether any of the Bedrock services have intervened in the processing flow. It is
+        primarily used to identify modifications or interruptions imposed by these services
+        during the request-response cycle with a Large Language Model (LLM).
+        """  # noqa: E501
+
+        if (
+            self._guardrails_enabled
+            and self.guardrails.get("trace")
+            and self._is_guardrails_intervention(body)
+        ):
+            return {
+                "signal": True,
+                "reason": "GUARDRAIL_INTERVENED",
+                "trace": body.get(AMAZON_BEDROCK_TRACE_KEY),
+            }
+
+        return {
+            "signal": False,
+            "reason": None,
+            "trace": None,
+        }
+
+    def _is_guardrails_intervention(self, body: dict) -> bool:
+        return body.get(GUARDRAILS_BODY_KEY) == "GUARDRAIL_INTERVENED"
+
     def _prepare_input_and_invoke_stream(
         self,
         prompt: str,
```
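The keyword arguments forwarded to `run_manager.on_llm_error` above via `**services_trace` are exactly the keys built by `_get_bedrock_services_signal`: `signal`, `reason`, and `trace`. A minimal handler sketch that reads them (the handler name and print format are illustrative, not part of the diff):

```python
from typing import Any

from langchain_core.callbacks import AsyncCallbackHandler


class GuardrailTraceLogger(AsyncCallbackHandler):
    """Sketch: read the guardrail signal delivered alongside the error."""

    async def on_llm_error(self, error: BaseException, **kwargs: Any) -> Any:
        if kwargs.get("reason") == "GUARDRAIL_INTERVENED":
            # "trace" carries the response body's "amazon-bedrock-trace" field
            # when the guardrail was configured with trace enabled.
            print(f"guardrail intervened, trace: {kwargs.get('trace')}")
```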
```diff
@@ -356,16 +510,28 @@ class BedrockBase(BaseModel, ABC):
             _model_kwargs["stream"] = True
 
         params = {**_model_kwargs, **kwargs}
 
+        if self._guardrails_enabled:
+            params.update(self._get_guardrails_canonical())
+
         input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
         body = json.dumps(input_body)
 
+        request_options = {
+            "body": body,
+            "modelId": self.model_id,
+            "accept": "application/json",
+            "contentType": "application/json",
+        }
+
+        if self._guardrails_enabled:
+            request_options["guardrail"] = "ENABLED"
+            if self.guardrails.get("trace"):
+                request_options["trace"] = "ENABLED"
+
         try:
-            response = self.client.invoke_model_with_response_stream(
-                body=body,
-                modelId=self.model_id,
-                accept="application/json",
-                contentType="application/json",
-            )
+            response = self.client.invoke_model_with_response_stream(**request_options)
         except Exception as e:
             raise ValueError(f"Error raised by bedrock service: {e}")
 
```
```diff
@@ -373,6 +539,9 @@ class BedrockBase(BaseModel, ABC):
             provider, response, stop
         ):
             yield chunk
+            # verify and raise callback error if any middleware intervened
+            self._get_bedrock_services_signal(chunk.generation_info)
+
             if run_manager is not None:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
 
```
```diff
@@ -536,7 +705,9 @@ class Bedrock(LLM, BedrockBase):
                 completion += chunk.text
             return completion
 
-        return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
+        return self._prepare_input_and_invoke(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        )
 
     async def _astream(
         self,
```
libs/community/tests/integration_tests/llms/test_bedrock.py (new file, 136 lines)

```diff
@@ -0,0 +1,136 @@
+"""
+Test Amazon Bedrock API wrapper and services i.e 'Guardrails for Amazon Bedrock'.
+You can get a list of models from the bedrock client by running 'bedrock_models()'
+
+"""
+
+import os
+from typing import Any
+
+import pytest
+from langchain_core.callbacks import AsyncCallbackHandler
+
+from langchain_community.llms.bedrock import Bedrock
+
+# this is the guardrails id for the model you want to test
+GUARDRAILS_ID = os.environ.get("GUARDRAILS_ID", "7jarelix77")
+# this is the guardrails version for the model you want to test
+GUARDRAILS_VERSION = os.environ.get("GUARDRAILS_VERSION", "1")
+# this should trigger the guardrails - you can change this to any text you want which
+# will trigger the guardrails
+GUARDRAILS_TRIGGER = os.environ.get(
+    "GUARDRAILS_TRIGGERING_QUERY", "I want to talk about politics."
+)
+
+
+class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
+    """Async callback handler that can be used to handle callbacks from langchain."""
+
+    guardrails_intervened = False
+
+    async def on_llm_error(
+        self,
+        error: BaseException,
+        **kwargs: Any,
+    ) -> Any:
+        reason = kwargs.get("reason")
+        if reason == "GUARDRAIL_INTERVENED":
+            self.guardrails_intervened = True
+
+    def get_response(self):
+        return self.guardrails_intervened
+
+
+@pytest.fixture(autouse=True)
+def bedrock_runtime_client():
+    import boto3
+
+    try:
+        client = boto3.client(
+            "bedrock-runtime",
+            region_name=os.environ.get("AWS_REGION", "us-east-1"),
+        )
+        return client
+    except Exception as e:
+        pytest.fail(f"can not connect to bedrock-runtime client: {e}", pytrace=False)
+
+
+@pytest.fixture(autouse=True)
+def bedrock_client():
+    import boto3
+
+    try:
+        client = boto3.client(
+            "bedrock",
+            region_name=os.environ.get("AWS_REGION", "us-east-1"),
+        )
+        return client
+    except Exception as e:
+        pytest.fail(f"can not connect to bedrock client: {e}", pytrace=False)
+
+
+@pytest.fixture
+def bedrock_models(bedrock_client):
+    """List bedrock models."""
+    response = bedrock_client.list_foundation_models().get("modelSummaries")
+    models = {}
+    for model in response:
+        models[model.get("modelId")] = model.get("modelName")
+    return models
+
+
+def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):
+    try:
+        llm = Bedrock(
+            model_id="anthropic.claude-instant-v1",
+            client=bedrock_runtime_client,
+            model_kwargs={},
+        )
+        output = llm("Say something positive:")
+        assert isinstance(output, str)
+    except Exception as e:
+        pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
+
+
+def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(
+    bedrock_runtime_client, bedrock_models
+):
+    try:
+        llm = Bedrock(
+            model_id="anthropic.claude-instant-v1",
+            client=bedrock_runtime_client,
+            model_kwargs={},
+            guardrails={
+                "id": GUARDRAILS_ID,
+                "version": GUARDRAILS_VERSION,
+                "trace": False,
+            },
+        )
+        output = llm("Say something positive:")
+        assert isinstance(output, str)
+    except Exception as e:
+        pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
+
+
+def test_amazon_bedrock_guardrails_intervention_for_invalid_query(
+    bedrock_runtime_client, bedrock_models
+):
+    try:
+        handler = BedrockAsyncCallbackHandler()
+        llm = Bedrock(
+            model_id="anthropic.claude-instant-v1",
+            client=bedrock_runtime_client,
+            model_kwargs={},
+            guardrails={
+                "id": GUARDRAILS_ID,
+                "version": GUARDRAILS_VERSION,
+                "trace": True,
+            },
+            callbacks=[handler],
+        )
+    except Exception as e:
+        pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
+    else:
+        llm(GUARDRAILS_TRIGGER)
+        guardrails_intervened = handler.get_response()
+        assert guardrails_intervened is True
```
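To run these integration tests against your own account, override the module-level defaults through the environment; the default guardrail id ("7jarelix77") and version ("1") belong to the author's account and will not resolve elsewhere. A sketch (values are placeholders):

```python
# Sketch: set before pytest collects the module (values are placeholders).
import os

os.environ["GUARDRAILS_ID"] = "<your_guardrail_id>"
os.environ["GUARDRAILS_VERSION"] = "<your_guardrail_version>"
os.environ["GUARDRAILS_TRIGGERING_QUERY"] = "I want to talk about politics."
os.environ["AWS_REGION"] = "us-east-1"
```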