community[minor]: add support for Guardrails for Amazon Bedrock (#15099)
Added support for optionally supplying 'Guardrails for Amazon Bedrock' on both types of model invocations (batch/regular and streaming) and for all models supported by the Amazon Bedrock service. @baskaryan @hwchase17

```python
from typing import Any

from langchain_core.callbacks import AsyncCallbackHandler
from langchain_community.llms.bedrock import Bedrock


class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
    """Async callback handler that can be used to handle callbacks from langchain."""

    async def on_llm_error(
        self,
        error: BaseException,
        **kwargs: Any,
    ) -> Any:
        reason = kwargs.get("reason")
        if reason == "GUARDRAIL_INTERVENED":
            # kwargs contains additional trace information sent by the
            # 'Guardrails for Bedrock' service.
            print(f"Guardrails: {kwargs}")


llm = Bedrock(
    model_id="<model_id>",
    client=bedrock,
    model_kwargs={},
    guardrails={"id": "<guardrail_id>", "version": "<guardrail_version>", "trace": True},
    callbacks=[BedrockAsyncCallbackHandler()],
)

# streaming
llm = Bedrock(
    model_id="<model_id>",
    client=bedrock,
    model_kwargs={},
    streaming=True,
    guardrails={"id": "<guardrail_id>", "version": "<guardrail_version>"},
)
```

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
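For reference, a minimal sketch of a handler that surfaces only the guardrail trace; the `reason` and `trace` kwargs it reads match what `_get_bedrock_services_signal` (added below) passes to `on_llm_error`, and the class name here is hypothetical:

```python
from typing import Any

from langchain_core.callbacks import AsyncCallbackHandler


class GuardrailTraceHandler(AsyncCallbackHandler):  # hypothetical name
    """Sketch: print the guardrail assessment trace on intervention."""

    async def on_llm_error(self, error: BaseException, **kwargs: Any) -> Any:
        if kwargs.get("reason") == "GUARDRAIL_INTERVENED":
            # Populated from the "amazon-bedrock-trace" response key when the
            # guardrail was configured with trace=True; None otherwise.
            print(kwargs.get("trace"))
```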
```diff
@@ -106,6 +106,45 @@
     "\n",
     "conversation.predict(input=\"Hi there!\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Guardrails for Amazon Bedrock example \n",
+    "\n",
+    "In this section, we are going to set up a Bedrock language model with specific guardrails that include tracing capabilities. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import Any\n",
+    "\n",
+    "from langchain_core.callbacks import AsyncCallbackHandler\n",
+    "\n",
+    "\n",
+    "class BedrockAsyncCallbackHandler(AsyncCallbackHandler):\n",
+    "    # Async callback handler that can be used to handle callbacks from langchain.\n",
+    "\n",
+    "    async def on_llm_error(self, error: BaseException, **kwargs: Any) -> Any:\n",
+    "        reason = kwargs.get(\"reason\")\n",
+    "        if reason == \"GUARDRAIL_INTERVENED\":\n",
+    "            print(f\"Guardrails: {kwargs}\")\n",
+    "\n",
+    "\n",
+    "# guardrails for Amazon Bedrock with trace\n",
+    "llm = Bedrock(\n",
+    "    credentials_profile_name=\"bedrock-admin\",\n",
+    "    model_id=\"<Model_ID>\",\n",
+    "    model_kwargs={},\n",
+    "    guardrails={\"id\": \"<Guardrail_ID>\", \"version\": \"<Version>\", \"trace\": True},\n",
+    "    callbacks=[BedrockAsyncCallbackHandler()],\n",
+    ")"
+   ]
+  }
  ],
  "metadata": {
```
```diff
@@ -34,6 +34,8 @@ from langchain_community.utilities.anthropic import (
 if TYPE_CHECKING:
     from botocore.config import Config
 
+AMAZON_BEDROCK_TRACE_KEY = "amazon-bedrock-trace"
+GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAssessment"
 HUMAN_PROMPT = "\n\nHuman:"
 ASSISTANT_PROMPT = "\n\nAssistant:"
 ALTERNATION_ERROR = (
```
```diff
@@ -117,21 +119,26 @@ class LLMInputOutputAdapter:
         return input_body
 
     @classmethod
-    def prepare_output(cls, provider: str, response: Any) -> str:
+    def prepare_output(cls, provider: str, response: Any) -> dict:
         if provider == "anthropic":
             response_body = json.loads(response.get("body").read().decode())
-            return response_body.get("completion")
+            text = response_body.get("completion")
         else:
             response_body = json.loads(response.get("body").read())
 
-        if provider == "ai21":
-            return response_body.get("completions")[0].get("data").get("text")
-        elif provider == "cohere":
-            return response_body.get("generations")[0].get("text")
-        elif provider == "meta":
-            return response_body.get("generation")
-        else:
-            return response_body.get("results")[0].get("outputText")
+            if provider == "ai21":
+                text = response_body.get("completions")[0].get("data").get("text")
+            elif provider == "cohere":
+                text = response_body.get("generations")[0].get("text")
+            elif provider == "meta":
+                text = response_body.get("generation")
+            else:
+                text = response_body.get("results")[0].get("outputText")
+
+        return {
+            "text": text,
+            "body": response_body,
+        }
 
     @classmethod
     def prepare_output_stream(
```
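Since `prepare_output` now returns a dict rather than a bare string, callers get the completion text and the parsed response body together. A minimal sketch of the new shape, assuming a hypothetical Anthropic-style response:

```python
import json
from io import BytesIO

# Hypothetical raw Bedrock response: "body" is a readable stream, as from boto3.
response = {"body": BytesIO(json.dumps({"completion": "Hello!"}).encode())}

response_body = json.loads(response.get("body").read().decode())
output = {"text": response_body.get("completion"), "body": response_body}

# _prepare_input_and_invoke (below) unpacks via .values(); this relies on
# dicts preserving insertion order ("text" first, then "body").
text, body = output.values()
assert text == "Hello!"
assert body == {"completion": "Hello!"}
```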
```diff
@@ -160,8 +167,15 @@ class LLMInputOutputAdapter:
                 chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
             ):
                 return
 
-            yield GenerationChunk(text=chunk_obj[output_key])
+            # chunk obj format varies with provider
+            yield GenerationChunk(
+                text=chunk_obj[output_key],
+                generation_info={
+                    GUARDRAILS_BODY_KEY: chunk_obj.get(GUARDRAILS_BODY_KEY)
+                    if GUARDRAILS_BODY_KEY in chunk_obj
+                    else None,
+                },
+            )
 
     @classmethod
     async def aprepare_output_stream(
```
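A side note on the `generation_info` expression above: since `dict.get` already returns `None` for a missing key, the conditional is equivalent to a plain `.get()`. A quick illustration:

```python
GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAssessment"
chunk_obj = {"outputText": "Hello"}  # hypothetical chunk without an assessment

a = chunk_obj.get(GUARDRAILS_BODY_KEY) if GUARDRAILS_BODY_KEY in chunk_obj else None
b = chunk_obj.get(GUARDRAILS_BODY_KEY)
assert a is None and b is None  # identical either way
```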
```diff
@@ -235,6 +249,53 @@ class BedrockBase(BaseModel, ABC):
         "cohere": "stop_sequences",
     }
 
+    guardrails: Optional[Mapping[str, Any]] = {
+        "id": None,
+        "version": None,
+        "trace": False,
+    }
+    """
+    An optional dictionary to configure guardrails for Bedrock.
+
+    This field 'guardrails' consists of 'id' and 'version' keys (strings,
+    initialized to None) plus an optional boolean 'trace' key. It is used to
+    determine if specific guardrails are enabled and properly set.
+
+    Type:
+        Optional[Mapping[str, str]]: A mapping with 'id' and 'version' keys.
+
+    Example:
+    llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
+                  model_kwargs={},
+                  guardrails={
+                        "id": "<guardrail_id>",
+                        "version": "<guardrail_version>"})
+
+    To enable tracing for guardrails, set the 'trace' key to True and pass a
+    callback handler to the 'run_manager' parameter of the 'generate', '_call'
+    methods.
+
+    Example:
+    llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
+                  model_kwargs={},
+                  guardrails={
+                        "id": "<guardrail_id>",
+                        "version": "<guardrail_version>",
+                        "trace": True},
+                  callbacks=[BedrockAsyncCallbackHandler()])
+
+    See https://python.langchain.com/docs/modules/callbacks/ for more
+    information on callback handlers.
+
+    class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
+        async def on_llm_error(
+            self,
+            error: BaseException,
+            **kwargs: Any,
+        ) -> Any:
+            reason = kwargs.get("reason")
+            if reason == "GUARDRAIL_INTERVENED":
+                ...  # logic to handle guardrail intervention
+    """
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that AWS credentials and python package exist in environment."""
```
```diff
@@ -298,6 +359,47 @@ class BedrockBase(BaseModel, ABC):
     def _model_is_anthropic(self) -> bool:
         return self._get_provider() == "anthropic"
 
+    @property
+    def _guardrails_enabled(self) -> bool:
+        """
+        Determines if guardrails are enabled and correctly configured.
+        Checks if 'guardrails' is a dictionary with non-empty 'id' and 'version'
+        keys; the optional 'guardrails.trace' flag is read separately by callers.
+
+        Returns:
+            bool: True if guardrails are correctly configured, False otherwise.
+        Raises:
+            TypeError: If 'guardrails' lacks 'id' or 'version' keys.
+        """
+        try:
+            return (
+                isinstance(self.guardrails, dict)
+                and bool(self.guardrails["id"])
+                and bool(self.guardrails["version"])
+            )
+
+        except KeyError as e:
+            raise TypeError(
+                "Guardrails must be a dictionary with 'id' and 'version' keys."
+            ) from e
+
+    def _get_guardrails_canonical(self) -> Dict[str, Any]:
+        """
+        The canonical way to pass in guardrails to the bedrock service
+        adheres to the following format:
+
+        "amazon-bedrock-guardrailDetails": {
+            "guardrailId": "string",
+            "guardrailVersion": "string"
+        }
+        """
+        return {
+            "amazon-bedrock-guardrailDetails": {
+                "guardrailId": self.guardrails.get("id"),
+                "guardrailVersion": self.guardrails.get("version"),
+            }
+        }
+
     def _prepare_input_and_invoke(
         self,
         prompt: str,
```
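Taken together, when guardrails are enabled the canonical block is merged into the model params before the request body is serialized. A rough sketch of the resulting body, with placeholder values and Anthropic-style model kwargs assumed for illustration:

```python
import json

guardrails = {"id": "<guardrail_id>", "version": "<guardrail_version>", "trace": True}
params = {"max_tokens_to_sample": 256}  # provider-specific kwargs, illustrative

# Mirrors _get_guardrails_canonical() merged in via params.update(...)
params["amazon-bedrock-guardrailDetails"] = {
    "guardrailId": guardrails.get("id"),
    "guardrailVersion": guardrails.get("version"),
}

body = json.dumps(params)  # what ends up in request_options["body"]
```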
```diff
@@ -309,29 +411,81 @@ class BedrockBase(BaseModel, ABC):
 
         provider = self._get_provider()
         params = {**_model_kwargs, **kwargs}
+        if self._guardrails_enabled:
+            params.update(self._get_guardrails_canonical())
         input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
         body = json.dumps(input_body)
         accept = "application/json"
         contentType = "application/json"
 
+        request_options = {
+            "body": body,
+            "modelId": self.model_id,
+            "accept": accept,
+            "contentType": contentType,
+        }
+
+        if self._guardrails_enabled:
+            request_options["guardrail"] = "ENABLED"
+            if self.guardrails.get("trace"):
+                request_options["trace"] = "ENABLED"
+
         try:
-            response = self.client.invoke_model(
-                body=body,
-                modelId=self.model_id,
-                accept=accept,
-                contentType=contentType,
-            )
-            text = LLMInputOutputAdapter.prepare_output(provider, response)
+            response = self.client.invoke_model(**request_options)
+
+            text, body = LLMInputOutputAdapter.prepare_output(
+                provider, response
+            ).values()
+
         except Exception as e:
-            raise ValueError(f"Error raised by bedrock service: {e}").with_traceback(
-                e.__traceback__
-            )
+            raise ValueError(f"Error raised by bedrock service: {e}")
 
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
 
+        # Verify and raise a callback error if any intervention occurs or a signal
+        # is sent from a Bedrock service, such as when guardrails are triggered.
+        services_trace = self._get_bedrock_services_signal(body)
+
+        if services_trace.get("signal") and run_manager is not None:
+            run_manager.on_llm_error(
+                Exception(
+                    f"Error raised by bedrock service: {services_trace.get('reason')}"
+                ),
+                **services_trace,
+            )
+
         return text
 
+    def _get_bedrock_services_signal(self, body: dict) -> dict:
+        """
+        This function checks the response body for an interrupt flag or message
+        that indicates whether any of the Bedrock services have intervened in
+        the processing flow. It is primarily used to identify modifications or
+        interruptions imposed by these services during the request-response
+        cycle with a Large Language Model (LLM).
+        """
+
+        if (
+            self._guardrails_enabled
+            and self.guardrails.get("trace")
+            and self._is_guardrails_intervention(body)
+        ):
+            return {
+                "signal": True,
+                "reason": "GUARDRAIL_INTERVENED",
+                "trace": body.get(AMAZON_BEDROCK_TRACE_KEY),
+            }
+
+        return {
+            "signal": False,
+            "reason": None,
+            "trace": None,
+        }
+
+    def _is_guardrails_intervention(self, body: dict) -> bool:
+        return body.get(GUARDRAILS_BODY_KEY) == "GUARDRAIL_INTERVENED"
+
     def _prepare_input_and_invoke_stream(
         self,
         prompt: str,
```
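For orientation, a sketch of the signal dict `_get_bedrock_services_signal` produces when tracing is on and a guardrail fires; the trace payload shape is illustrative, since the service defines it:

```python
# Response body flagged by the service (keys per the constants added above).
body = {
    "amazon-bedrock-guardrailAssessment": "GUARDRAIL_INTERVENED",
    "amazon-bedrock-trace": {"assessment": "..."},  # illustrative payload
}

services_trace = {
    "signal": True,
    "reason": "GUARDRAIL_INTERVENED",
    "trace": body.get("amazon-bedrock-trace"),
}
# This dict is spread into run_manager.on_llm_error(..., **services_trace),
# which is how a callback handler sees "reason" and "trace" in its kwargs.
```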
```diff
@@ -356,16 +510,28 @@ class BedrockBase(BaseModel, ABC):
             _model_kwargs["stream"] = True
 
         params = {**_model_kwargs, **kwargs}
+
+        if self._guardrails_enabled:
+            params.update(self._get_guardrails_canonical())
+
         input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
         body = json.dumps(input_body)
 
+        request_options = {
+            "body": body,
+            "modelId": self.model_id,
+            "accept": "application/json",
+            "contentType": "application/json",
+        }
+
+        if self._guardrails_enabled:
+            request_options["guardrail"] = "ENABLED"
+            if self.guardrails.get("trace"):
+                request_options["trace"] = "ENABLED"
+
         try:
-            response = self.client.invoke_model_with_response_stream(
-                body=body,
-                modelId=self.model_id,
-                accept="application/json",
-                contentType="application/json",
-            )
+            response = self.client.invoke_model_with_response_stream(**request_options)
+
         except Exception as e:
             raise ValueError(f"Error raised by bedrock service: {e}")
 
```
```diff
@@ -373,6 +539,9 @@ class BedrockBase(BaseModel, ABC):
             provider, response, stop
         ):
             yield chunk
+            # verify and raise callback error if any middleware intervened
+            self._get_bedrock_services_signal(chunk.generation_info)
+
             if run_manager is not None:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
 
```
```diff
@@ -536,7 +705,9 @@ class Bedrock(LLM, BedrockBase):
                 completion += chunk.text
             return completion
 
-        return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
+        return self._prepare_input_and_invoke(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        )
 
     async def _astream(
         self,
```
libs/community/tests/integration_tests/llms/test_bedrock.py (new file, 136 lines)
```diff
@@ -0,0 +1,136 @@
+"""
+Test Amazon Bedrock API wrapper and services, e.g. 'Guardrails for Amazon Bedrock'.
+You can get a list of models from the bedrock client by running 'bedrock_models()'.
+"""
+
+import os
+from typing import Any
+
+import pytest
+from langchain_core.callbacks import AsyncCallbackHandler
+
+from langchain_community.llms.bedrock import Bedrock
+
+# this is the guardrails id for the model you want to test
+GUARDRAILS_ID = os.environ.get("GUARDRAILS_ID", "7jarelix77")
+# this is the guardrails version for the model you want to test
+GUARDRAILS_VERSION = os.environ.get("GUARDRAILS_VERSION", "1")
+# this should trigger the guardrails - you can change this to any text
+# which will trigger the guardrails
+GUARDRAILS_TRIGGER = os.environ.get(
+    "GUARDRAILS_TRIGGERING_QUERY", "I want to talk about politics."
+)
+
+
+class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
+    """Async callback handler that can be used to handle callbacks from langchain."""
+
+    guardrails_intervened = False
+
+    async def on_llm_error(
+        self,
+        error: BaseException,
+        **kwargs: Any,
+    ) -> Any:
+        reason = kwargs.get("reason")
+        if reason == "GUARDRAIL_INTERVENED":
+            self.guardrails_intervened = True
+
+    def get_response(self):
+        return self.guardrails_intervened
+
+
+@pytest.fixture(autouse=True)
+def bedrock_runtime_client():
+    import boto3
+
+    try:
+        client = boto3.client(
+            "bedrock-runtime",
+            region_name=os.environ.get("AWS_REGION", "us-east-1"),
+        )
+        return client
+    except Exception as e:
+        pytest.fail(f"can not connect to bedrock-runtime client: {e}", pytrace=False)
+
+
+@pytest.fixture(autouse=True)
+def bedrock_client():
+    import boto3
+
+    try:
+        client = boto3.client(
+            "bedrock",
+            region_name=os.environ.get("AWS_REGION", "us-east-1"),
+        )
+        return client
+    except Exception as e:
+        pytest.fail(f"can not connect to bedrock client: {e}", pytrace=False)
+
+
+@pytest.fixture
+def bedrock_models(bedrock_client):
+    """List bedrock models."""
+    response = bedrock_client.list_foundation_models().get("modelSummaries")
+    models = {}
+    for model in response:
+        models[model.get("modelId")] = model.get("modelName")
+    return models
+
+
+def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):
+    try:
+        llm = Bedrock(
+            model_id="anthropic.claude-instant-v1",
+            client=bedrock_runtime_client,
+            model_kwargs={},
+        )
+        output = llm("Say something positive:")
+        assert isinstance(output, str)
+    except Exception as e:
+        pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
+
+
+def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(
+    bedrock_runtime_client, bedrock_models
+):
+    try:
+        llm = Bedrock(
+            model_id="anthropic.claude-instant-v1",
+            client=bedrock_runtime_client,
+            model_kwargs={},
+            guardrails={
+                "id": GUARDRAILS_ID,
+                "version": GUARDRAILS_VERSION,
+                "trace": False,
+            },
+        )
+        output = llm("Say something positive:")
+        assert isinstance(output, str)
+    except Exception as e:
+        pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
+
+
+def test_amazon_bedrock_guardrails_intervention_for_invalid_query(
+    bedrock_runtime_client, bedrock_models
+):
+    try:
+        handler = BedrockAsyncCallbackHandler()
+        llm = Bedrock(
+            model_id="anthropic.claude-instant-v1",
+            client=bedrock_runtime_client,
+            model_kwargs={},
+            guardrails={
+                "id": GUARDRAILS_ID,
+                "version": GUARDRAILS_VERSION,
+                "trace": True,
+            },
+            callbacks=[handler],
+        )
+    except Exception as e:
+        pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
+    else:
+        llm(GUARDRAILS_TRIGGER)
+        guardrails_intervened = handler.get_response()
+        assert guardrails_intervened is True
```
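These integration tests assume live AWS credentials plus a provisioned guardrail. One way to point them at your own resources before invoking pytest, using the environment variables the test module reads (values here are placeholders):

```python
import os

# Hypothetical values; replace with a guardrail provisioned in your account.
os.environ["GUARDRAILS_ID"] = "<your_guardrail_id>"
os.environ["GUARDRAILS_VERSION"] = "1"
os.environ["AWS_REGION"] = "us-east-1"
```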