community[minor]: add support for Guardrails for Amazon Bedrock (#15099)

Added optional support for 'Guardrails for Amazon Bedrock' on both types of
model invocation (regular/batch and streaming) and for all models supported by
the Amazon Bedrock service.

@baskaryan  @hwchase17

```python
from typing import Any

import boto3
from langchain_core.callbacks import AsyncCallbackHandler
from langchain_community.llms.bedrock import Bedrock


class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
    """Async callback handler that can be used to handle callbacks from langchain."""

    async def on_llm_error(self, error: BaseException, **kwargs: Any) -> Any:
        reason = kwargs.get("reason")
        if reason == "GUARDRAIL_INTERVENED":
            # kwargs contains additional trace information sent by the
            # 'Guardrails for Amazon Bedrock' service.
            print(f"Guardrails: {kwargs}")


bedrock = boto3.client("bedrock-runtime")  # Bedrock runtime client

llm = Bedrock(
    model_id="<model_id>",
    client=bedrock,
    model_kwargs={},
    guardrails={"id": "<guardrail_id>", "version": "<guardrail_version>", "trace": True},
    callbacks=[BedrockAsyncCallbackHandler()],
)

# streaming
llm = Bedrock(
    model_id="<model_id>",
    client=bedrock,
    model_kwargs={},
    streaming=True,
    guardrails={"id": "<guardrail_id>", "version": "<guardrail_version>"},
)
```
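
For completeness, here is a minimal sketch of consuming the streaming configuration above. The prompt is a placeholder; `llm.stream()` is the standard LangChain streaming entry point and yields text chunks as they arrive:

```python
# Minimal usage sketch for the streaming `llm` defined above; assumes valid AWS
# credentials and an existing guardrail. The prompt is a placeholder.
for chunk in llm.stream("Tell me about building applications with LLMs."):
    print(chunk, end="", flush=True)
```

The same `guardrails` dictionary is used in both modes; streaming only changes how the response is consumed.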

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Harel Gal 2024-01-25 00:44:19 +02:00, committed by GitHub
parent 04651f0248, commit a91181fe6d
3 changed files with 375 additions and 29 deletions

View File

@@ -106,6 +106,45 @@
     "\n",
     "conversation.predict(input=\"Hi there!\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Guardrails for Amazon Bedrock example \n",
+    "\n",
+    "In this section, we are going to set up a Bedrock language model with specific guardrails that include tracing capabilities. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import Any\n",
+    "\n",
+    "from langchain_core.callbacks import AsyncCallbackHandler\n",
+    "\n",
+    "\n",
+    "class BedrockAsyncCallbackHandler(AsyncCallbackHandler):\n",
+    "    # Async callback handler that can be used to handle callbacks from langchain.\n",
+    "\n",
+    "    async def on_llm_error(self, error: BaseException, **kwargs: Any) -> Any:\n",
+    "        reason = kwargs.get(\"reason\")\n",
+    "        if reason == \"GUARDRAIL_INTERVENED\":\n",
+    "            print(f\"Guardrails: {kwargs}\")\n",
+    "\n",
+    "\n",
+    "# guardrails for Amazon Bedrock with trace\n",
+    "llm = Bedrock(\n",
+    "    credentials_profile_name=\"bedrock-admin\",\n",
+    "    model_id=\"<Model_ID>\",\n",
+    "    model_kwargs={},\n",
+    "    guardrails={\"id\": \"<Guardrail_ID>\", \"version\": \"<Version>\", \"trace\": True},\n",
+    "    callbacks=[BedrockAsyncCallbackHandler()],\n",
+    ")"
+   ]
+  }
  ],
  "metadata": {

View File

@@ -34,6 +34,8 @@ from langchain_community.utilities.anthropic import (
 if TYPE_CHECKING:
     from botocore.config import Config

+AMAZON_BEDROCK_TRACE_KEY = "amazon-bedrock-trace"
+GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAssessment"
 HUMAN_PROMPT = "\n\nHuman:"
 ASSISTANT_PROMPT = "\n\nAssistant:"
 ALTERNATION_ERROR = (
@@ -117,21 +119,26 @@ class LLMInputOutputAdapter:
         return input_body

     @classmethod
-    def prepare_output(cls, provider: str, response: Any) -> str:
+    def prepare_output(cls, provider: str, response: Any) -> dict:
         if provider == "anthropic":
             response_body = json.loads(response.get("body").read().decode())
-            return response_body.get("completion")
+            text = response_body.get("completion")
         else:
             response_body = json.loads(response.get("body").read())

-        if provider == "ai21":
-            return response_body.get("completions")[0].get("data").get("text")
-        elif provider == "cohere":
-            return response_body.get("generations")[0].get("text")
-        elif provider == "meta":
-            return response_body.get("generation")
-        else:
-            return response_body.get("results")[0].get("outputText")
+            if provider == "ai21":
+                text = response_body.get("completions")[0].get("data").get("text")
+            elif provider == "cohere":
+                text = response_body.get("generations")[0].get("text")
+            elif provider == "meta":
+                text = response_body.get("generation")
+            else:
+                text = response_body.get("results")[0].get("outputText")
+
+        return {
+            "text": text,
+            "body": response_body,
+        }

     @classmethod
     def prepare_output_stream(
@@ -160,8 +167,15 @@ class LLMInputOutputAdapter:
                 chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
             ):
                 return

-            yield GenerationChunk(text=chunk_obj[output_key])
+            # chunk obj format varies with provider
+            yield GenerationChunk(
+                text=chunk_obj[output_key],
+                generation_info={
+                    GUARDRAILS_BODY_KEY: chunk_obj.get(GUARDRAILS_BODY_KEY)
+                    if GUARDRAILS_BODY_KEY in chunk_obj
+                    else None,
+                },
+            )

     @classmethod
     async def aprepare_output_stream(
@@ -235,6 +249,53 @@ class BedrockBase(BaseModel, ABC):
         "cohere": "stop_sequences",
     }

+    guardrails: Optional[Mapping[str, Any]] = {
+        "id": None,
+        "version": None,
+        "trace": False,
+    }
+    """
+    An optional dictionary to configure guardrails for Bedrock.
+
+    This field 'guardrails' consists of two keys: 'id' and 'version',
+    which should be strings, but are initialized to None. It's used to
+    determine if specific guardrails are enabled and properly set.
+
+    Type:
+        Optional[Mapping[str, str]]: A mapping with 'id' and 'version' keys.
+
+    Example:
+    llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
+                  model_kwargs={},
+                  guardrails={
+                      "id": "<guardrail_id>",
+                      "version": "<guardrail_version>"})
+
+    To enable tracing for guardrails, set the 'trace' key to True and pass a callback handler to the
+    'run_manager' parameter of the 'generate', '_call' methods.
+
+    Example:
+    llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
+                  model_kwargs={},
+                  guardrails={
+                      "id": "<guardrail_id>",
+                      "version": "<guardrail_version>",
+                      "trace": True},
+                  callbacks=[BedrockAsyncCallbackHandler()])
+
+    [https://python.langchain.com/docs/modules/callbacks/] for more information on callback handlers.
+
+    class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
+        async def on_llm_error(
+            self,
+            error: BaseException,
+            **kwargs: Any,
+        ) -> Any:
+            reason = kwargs.get("reason")
+            if reason == "GUARDRAIL_INTERVENED":
+                ...Logic to handle guardrail intervention...
+    """  # noqa: E501

     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that AWS credentials to and python package exists in environment."""
@@ -298,6 +359,47 @@ class BedrockBase(BaseModel, ABC):
     def _model_is_anthropic(self) -> bool:
         return self._get_provider() == "anthropic"

+    @property
+    def _guardrails_enabled(self) -> bool:
+        """
+        Determines if guardrails are enabled and correctly configured.
+        Checks if 'guardrails' is a dictionary with non-empty 'id' and 'version' keys.
+        Checks if 'guardrails.trace' is true.
+
+        Returns:
+            bool: True if guardrails are correctly configured, False otherwise.
+        Raises:
+            TypeError: If 'guardrails' lacks 'id' or 'version' keys.
+        """
+        try:
+            return (
+                isinstance(self.guardrails, dict)
+                and bool(self.guardrails["id"])
+                and bool(self.guardrails["version"])
+            )
+        except KeyError as e:
+            raise TypeError(
+                "Guardrails must be a dictionary with 'id' and 'version' keys."
+            ) from e
+
+    def _get_guardrails_canonical(self) -> Dict[str, Any]:
+        """
+        The canonical way to pass in guardrails to the bedrock service
+        adheres to the following format:
+
+        "amazon-bedrock-guardrailDetails": {
+            "guardrailId": "string",
+            "guardrailVersion": "string"
+        }
+        """
+        return {
+            "amazon-bedrock-guardrailDetails": {
+                "guardrailId": self.guardrails.get("id"),
+                "guardrailVersion": self.guardrails.get("version"),
+            }
+        }
+
     def _prepare_input_and_invoke(
         self,
         prompt: str,
@@ -309,29 +411,81 @@ class BedrockBase(BaseModel, ABC):
         provider = self._get_provider()
         params = {**_model_kwargs, **kwargs}

+        if self._guardrails_enabled:
+            params.update(self._get_guardrails_canonical())
+
         input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
         body = json.dumps(input_body)
         accept = "application/json"
         contentType = "application/json"

+        request_options = {
+            "body": body,
+            "modelId": self.model_id,
+            "accept": accept,
+            "contentType": contentType,
+        }
+
+        if self._guardrails_enabled:
+            request_options["guardrail"] = "ENABLED"
+            if self.guardrails.get("trace"):
+                request_options["trace"] = "ENABLED"
+
         try:
-            response = self.client.invoke_model(
-                body=body,
-                modelId=self.model_id,
-                accept=accept,
-                contentType=contentType,
-            )
-            text = LLMInputOutputAdapter.prepare_output(provider, response)
+            response = self.client.invoke_model(**request_options)
+
+            text, body = LLMInputOutputAdapter.prepare_output(
+                provider, response
+            ).values()
+
         except Exception as e:
-            raise ValueError(f"Error raised by bedrock service: {e}").with_traceback(
-                e.__traceback__
-            )
+            raise ValueError(f"Error raised by bedrock service: {e}")

         if stop is not None:
             text = enforce_stop_tokens(text, stop)

+        # Verify and raise a callback error if any intervention occurs or a signal is
+        # sent from a Bedrock service,
+        # such as when guardrails are triggered.
+        services_trace = self._get_bedrock_services_signal(body)
+
+        if services_trace.get("signal") and run_manager is not None:
+            run_manager.on_llm_error(
+                Exception(
+                    f"Error raised by bedrock service: {services_trace.get('reason')}"
+                ),
+                **services_trace,
+            )
+
         return text

+    def _get_bedrock_services_signal(self, body: dict) -> dict:
+        """
+        This function checks the response body for an interrupt flag or message that indicates
+        whether any of the Bedrock services have intervened in the processing flow. It is
+        primarily used to identify modifications or interruptions imposed by these services
+        during the request-response cycle with a Large Language Model (LLM).
+        """  # noqa: E501
+        if (
+            self._guardrails_enabled
+            and self.guardrails.get("trace")
+            and self._is_guardrails_intervention(body)
+        ):
+            return {
+                "signal": True,
+                "reason": "GUARDRAIL_INTERVENED",
+                "trace": body.get(AMAZON_BEDROCK_TRACE_KEY),
+            }
+
+        return {
+            "signal": False,
+            "reason": None,
+            "trace": None,
+        }
+
+    def _is_guardrails_intervention(self, body: dict) -> bool:
+        return body.get(GUARDRAILS_BODY_KEY) == "GUARDRAIL_INTERVENED"
+
     def _prepare_input_and_invoke_stream(
         self,
         prompt: str,
@@ -356,16 +510,28 @@ class BedrockBase(BaseModel, ABC):
             _model_kwargs["stream"] = True

         params = {**_model_kwargs, **kwargs}
+
+        if self._guardrails_enabled:
+            params.update(self._get_guardrails_canonical())
+
         input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
         body = json.dumps(input_body)

+        request_options = {
+            "body": body,
+            "modelId": self.model_id,
+            "accept": "application/json",
+            "contentType": "application/json",
+        }
+
+        if self._guardrails_enabled:
+            request_options["guardrail"] = "ENABLED"
+            if self.guardrails.get("trace"):
+                request_options["trace"] = "ENABLED"
+
         try:
-            response = self.client.invoke_model_with_response_stream(
-                body=body,
-                modelId=self.model_id,
-                accept="application/json",
-                contentType="application/json",
-            )
+            response = self.client.invoke_model_with_response_stream(**request_options)
         except Exception as e:
             raise ValueError(f"Error raised by bedrock service: {e}")
@@ -373,6 +539,9 @@ class BedrockBase(BaseModel, ABC):
             provider, response, stop
         ):
             yield chunk
+
+            # verify and raise callback error if any middleware intervened
+            self._get_bedrock_services_signal(chunk.generation_info)

             if run_manager is not None:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
@@ -536,7 +705,9 @@ class Bedrock(LLM, BedrockBase):
                 completion += chunk.text
             return completion

-        return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
+        return self._prepare_input_and_invoke(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        )

     async def _astream(
         self,
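
To make the new signal logic concrete, here is a small standalone sketch (illustrative, not part of the diff) of the intervention check that `_get_bedrock_services_signal` and `_is_guardrails_intervention` perform when tracing is enabled. The key names are the module constants added above; the sample body is hypothetical, and real service payloads may carry additional fields:

```python
# Standalone sketch of the intervention check added in this PR (illustrative only).
GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAssessment"
AMAZON_BEDROCK_TRACE_KEY = "amazon-bedrock-trace"


def classify(body: dict) -> dict:
    # Mirrors _get_bedrock_services_signal for a guardrails config with trace=True.
    if body.get(GUARDRAILS_BODY_KEY) == "GUARDRAIL_INTERVENED":
        return {
            "signal": True,
            "reason": "GUARDRAIL_INTERVENED",
            "trace": body.get(AMAZON_BEDROCK_TRACE_KEY),
        }
    return {"signal": False, "reason": None, "trace": None}


# Hypothetical response body for a blocked request:
body = {GUARDRAILS_BODY_KEY: "GUARDRAIL_INTERVENED", AMAZON_BEDROCK_TRACE_KEY: {"guardrail": "..."}}
print(classify(body))
# {'signal': True, 'reason': 'GUARDRAIL_INTERVENED', 'trace': {'guardrail': '...'}}
```

The returned dictionary is what `_prepare_input_and_invoke` splats into `run_manager.on_llm_error(...)` as keyword arguments, which is why the callback handlers in this PR read `kwargs.get("reason")`.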

View File

@@ -0,0 +1,136 @@
+"""
+Test Amazon Bedrock API wrapper and services i.e 'Guardrails for Amazon Bedrock'.
+You can get a list of models from the bedrock client by running 'bedrock_models()'
+"""
+import os
+from typing import Any
+
+import pytest
+from langchain_core.callbacks import AsyncCallbackHandler
+
+from langchain_community.llms.bedrock import Bedrock
+
+# this is the guardrails id for the model you want to test
+GUARDRAILS_ID = os.environ.get("GUARDRAILS_ID", "7jarelix77")
+
+# this is the guardrails version for the model you want to test
+GUARDRAILS_VERSION = os.environ.get("GUARDRAILS_VERSION", "1")
+
+# this should trigger the guardrails - you can change this to any text you want which
+# will trigger the guardrails
+GUARDRAILS_TRIGGER = os.environ.get(
+    "GUARDRAILS_TRIGGERING_QUERY", "I want to talk about politics."
+)
+
+
+class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
+    """Async callback handler that can be used to handle callbacks from langchain."""
+
+    guardrails_intervened = False
+
+    async def on_llm_error(
+        self,
+        error: BaseException,
+        **kwargs: Any,
+    ) -> Any:
+        reason = kwargs.get("reason")
+        if reason == "GUARDRAIL_INTERVENED":
+            self.guardrails_intervened = True
+
+    def get_response(self):
+        return self.guardrails_intervened
+
+
+@pytest.fixture(autouse=True)
+def bedrock_runtime_client():
+    import boto3
+
+    try:
+        client = boto3.client(
+            "bedrock-runtime",
+            region_name=os.environ.get("AWS_REGION", "us-east-1"),
+        )
+        return client
+    except Exception as e:
+        pytest.fail(f"can not connect to bedrock-runtime client: {e}", pytrace=False)
+
+
+@pytest.fixture(autouse=True)
+def bedrock_client():
+    import boto3
+
+    try:
+        client = boto3.client(
+            "bedrock",
+            region_name=os.environ.get("AWS_REGION", "us-east-1"),
+        )
+        return client
+    except Exception as e:
+        pytest.fail(f"can not connect to bedrock client: {e}", pytrace=False)
+
+
+@pytest.fixture
+def bedrock_models(bedrock_client):
+    """List bedrock models."""
+    response = bedrock_client.list_foundation_models().get("modelSummaries")
+    models = {}
+    for model in response:
+        models[model.get("modelId")] = model.get("modelName")
+    return models
+
+
+def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):
+    try:
+        llm = Bedrock(
+            model_id="anthropic.claude-instant-v1",
+            client=bedrock_runtime_client,
+            model_kwargs={},
+        )
+        output = llm("Say something positive:")
+        assert isinstance(output, str)
+    except Exception as e:
+        pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
+
+
+def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(
+    bedrock_runtime_client, bedrock_models
+):
+    try:
+        llm = Bedrock(
+            model_id="anthropic.claude-instant-v1",
+            client=bedrock_runtime_client,
+            model_kwargs={},
+            guardrails={
+                "id": GUARDRAILS_ID,
+                "version": GUARDRAILS_VERSION,
+                "trace": False,
+            },
+        )
+        output = llm("Say something positive:")
+        assert isinstance(output, str)
+    except Exception as e:
+        pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
+
+
+def test_amazon_bedrock_guardrails_intervention_for_invalid_query(
+    bedrock_runtime_client, bedrock_models
+):
+    try:
+        handler = BedrockAsyncCallbackHandler()
+        llm = Bedrock(
+            model_id="anthropic.claude-instant-v1",
+            client=bedrock_runtime_client,
+            model_kwargs={},
+            guardrails={
+                "id": GUARDRAILS_ID,
+                "version": GUARDRAILS_VERSION,
+                "trace": True,
+            },
+            callbacks=[handler],
+        )
+    except Exception as e:
+        pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
+    else:
+        llm(GUARDRAILS_TRIGGER)
+        guardrails_intervened = handler.get_response()
+        assert guardrails_intervened is True
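
A possible way to run these integration tests locally is sketched below; the test-file path is an assumption, and the guardrail ID/version must point at a real guardrail in an account with configured AWS credentials:

```python
import os

import pytest

# Hypothetical local runner for the integration tests above. The path below is
# an assumption; adjust it to wherever this test file lives in your checkout.
os.environ.setdefault("AWS_REGION", "us-east-1")
os.environ.setdefault("GUARDRAILS_ID", "<your_guardrail_id>")
os.environ.setdefault("GUARDRAILS_VERSION", "1")

pytest.main(["-q", "tests/integration_tests/llms/test_bedrock.py"])
```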