community[minor]: add support for Guardrails for Amazon Bedrock (#15099)

Added support for optionally supplying 'Guardrails for Amazon Bedrock' on
both types of model invocation (regular and streaming) and for all models
supported by the Amazon Bedrock service.

@baskaryan  @hwchase17

```python
from typing import Any

from langchain_community.llms import Bedrock
from langchain_core.callbacks import AsyncCallbackHandler

# `bedrock` is a boto3 Bedrock runtime client, e.g. boto3.client("bedrock-runtime")


class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
    """Async callback handler that can be used to handle callbacks from langchain."""

    async def on_llm_error(
        self,
        error: BaseException,
        **kwargs: Any,
    ) -> Any:
        reason = kwargs.get("reason")
        if reason == "GUARDRAIL_INTERVENED":
            # kwargs contains additional trace information sent by the
            # 'Guardrails for Bedrock' service.
            print(f"Guardrails: {kwargs}")


llm = Bedrock(model_id="<model_id>", client=bedrock,
              model_kwargs={},
              guardrails={"id": "<guardrail_id>",
                          "version": "<guardrail_version>",
                          "trace": True},
              callbacks=[BedrockAsyncCallbackHandler()])

# streaming
llm = Bedrock(model_id="<model_id>", client=bedrock,
              model_kwargs={},
              streaming=True,
              guardrails={"id": "<guardrail_id>",
                          "version": "<guardrail_version>"})
```
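
With tracing enabled, an intervention is reported through the callback handler rather than raised to the caller. A minimal usage sketch, assuming the first configuration above (with `"trace": True` and the async handler registered); the prompt is a placeholder:

```python
import asyncio


async def main() -> None:
    # Placeholder prompt. If the guardrail intervenes, on_llm_error fires with
    # reason="GUARDRAIL_INTERVENED" and the service trace in kwargs, while
    # ainvoke still returns the (possibly modified) completion text.
    text = await llm.ainvoke("<prompt that trips the guardrail>")
    print(text)


asyncio.run(main())
```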

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Harel Gal authored on 2024-01-25 00:44:19 +02:00, committed by GitHub
commit a91181fe6d (parent 04651f0248)
3 changed files with 375 additions and 29 deletions


```diff
@@ -34,6 +34,8 @@ from langchain_community.utilities.anthropic import (
 if TYPE_CHECKING:
     from botocore.config import Config
 
+AMAZON_BEDROCK_TRACE_KEY = "amazon-bedrock-trace"
+GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAssessment"
 HUMAN_PROMPT = "\n\nHuman:"
 ASSISTANT_PROMPT = "\n\nAssistant:"
 ALTERNATION_ERROR = (
```
```diff
@@ -117,21 +119,26 @@ class LLMInputOutputAdapter:
         return input_body
 
     @classmethod
-    def prepare_output(cls, provider: str, response: Any) -> str:
+    def prepare_output(cls, provider: str, response: Any) -> dict:
         if provider == "anthropic":
             response_body = json.loads(response.get("body").read().decode())
-            return response_body.get("completion")
+            text = response_body.get("completion")
         else:
             response_body = json.loads(response.get("body").read())
 
-            if provider == "ai21":
-                return response_body.get("completions")[0].get("data").get("text")
-            elif provider == "cohere":
-                return response_body.get("generations")[0].get("text")
-            elif provider == "meta":
-                return response_body.get("generation")
-            else:
-                return response_body.get("results")[0].get("outputText")
+            if provider == "ai21":
+                text = response_body.get("completions")[0].get("data").get("text")
+            elif provider == "cohere":
+                text = response_body.get("generations")[0].get("text")
+            elif provider == "meta":
+                text = response_body.get("generation")
+            else:
+                text = response_body.get("results")[0].get("outputText")
+
+        return {
+            "text": text,
+            "body": response_body,
+        }
 
     @classmethod
     def prepare_output_stream(
```
```diff
@@ -160,8 +167,15 @@ class LLMInputOutputAdapter:
                 chunk_obj["is_finished"] or chunk_obj[output_key] == "<EOS_TOKEN>"
             ):
                 return
 
-            yield GenerationChunk(text=chunk_obj[output_key])
+            # chunk obj format varies with provider
+            yield GenerationChunk(
+                text=chunk_obj[output_key],
+                generation_info={
+                    GUARDRAILS_BODY_KEY: chunk_obj.get(GUARDRAILS_BODY_KEY)
+                    if GUARDRAILS_BODY_KEY in chunk_obj
+                    else None,
+                },
+            )
 
     @classmethod
     async def aprepare_output_stream(
```
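
To make the new `prepare_output` contract concrete, here is a small self-contained sketch with a mocked boto3-style response (the payload shape below is illustrative, not a real service response). Callers now receive both the completion text and the parsed body, so guardrail metadata in the body is no longer discarded:

```python
import json
from io import BytesIO

from langchain_community.llms.bedrock import LLMInputOutputAdapter

# Mocked invoke_model response: boto3 returns a dict whose "body" exposes a
# .read() method, which BytesIO provides.
fake_response = {"body": BytesIO(json.dumps({"completion": "Hello!"}).encode())}

result = LLMInputOutputAdapter.prepare_output("anthropic", fake_response)
print(result["text"])  # Hello!
print(result["body"])  # {'completion': 'Hello!'} -- would include guardrail keys, if any
```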
```diff
@@ -235,6 +249,53 @@ class BedrockBase(BaseModel, ABC):
         "cohere": "stop_sequences",
     }
 
+    guardrails: Optional[Mapping[str, Any]] = {
+        "id": None,
+        "version": None,
+        "trace": False,
+    }
+    """
+    An optional dictionary to configure guardrails for Bedrock.
+
+    This field 'guardrails' consists of two keys: 'id' and 'version',
+    which should be strings, but are initialized to None. It's used to
+    determine if specific guardrails are enabled and properly set.
+
+    Type:
+        Optional[Mapping[str, str]]: A mapping with 'id' and 'version' keys.
+
+    Example:
+    llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
+                  model_kwargs={},
+                  guardrails={
+                        "id": "<guardrail_id>",
+                        "version": "<guardrail_version>"})
+
+    To enable tracing for guardrails, set the 'trace' key to True and pass a callback
+    handler to the 'run_manager' parameter of the 'generate', '_call' methods.
+
+    Example:
+    llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
+                  model_kwargs={},
+                  guardrails={
+                        "id": "<guardrail_id>",
+                        "version": "<guardrail_version>",
+                        "trace": True},
+                  callbacks=[BedrockAsyncCallbackHandler()])
+
+    See https://python.langchain.com/docs/modules/callbacks/ for more information
+    on callback handlers.
+
+    class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
+        async def on_llm_error(
+            self,
+            error: BaseException,
+            **kwargs: Any,
+        ) -> Any:
+            reason = kwargs.get("reason")
+            if reason == "GUARDRAIL_INTERVENED":
+                ...Logic to handle guardrail intervention...
+    """  # noqa: E501
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that AWS credentials to and python package exists in environment."""
```
```diff
@@ -298,6 +359,47 @@ class BedrockBase(BaseModel, ABC):
     def _model_is_anthropic(self) -> bool:
         return self._get_provider() == "anthropic"
 
+    @property
+    def _guardrails_enabled(self) -> bool:
+        """
+        Determines if guardrails are enabled and correctly configured.
+        Checks if 'guardrails' is a dictionary with non-empty 'id' and 'version' keys.
+        Checks if 'guardrails.trace' is true.
+
+        Returns:
+            bool: True if guardrails are correctly configured, False otherwise.
+        Raises:
+            TypeError: If 'guardrails' lacks 'id' or 'version' keys.
+        """
+        try:
+            return (
+                isinstance(self.guardrails, dict)
+                and bool(self.guardrails["id"])
+                and bool(self.guardrails["version"])
+            )
+        except KeyError as e:
+            raise TypeError(
+                "Guardrails must be a dictionary with 'id' and 'version' keys."
+            ) from e
+
+    def _get_guardrails_canonical(self) -> Dict[str, Any]:
+        """
+        The canonical way to pass in guardrails to the bedrock service
+        adheres to the following format:
+
+        "amazon-bedrock-guardrailDetails": {
+            "guardrailId": "string",
+            "guardrailVersion": "string"
+        }
+        """
+        return {
+            "amazon-bedrock-guardrailDetails": {
+                "guardrailId": self.guardrails.get("id"),
+                "guardrailVersion": self.guardrails.get("version"),
+            }
+        }
+
     def _prepare_input_and_invoke(
         self,
         prompt: str,
```
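
As a standalone restatement (hypothetical helper functions, not part of the diff), the validation and canonical-payload logic above reduces to:

```python
from typing import Any, Dict, Mapping, Optional


def guardrails_enabled(guardrails: Optional[Mapping[str, Any]]) -> bool:
    # Mirrors _guardrails_enabled: a dict with truthy 'id' and 'version'.
    try:
        return (
            isinstance(guardrails, dict)
            and bool(guardrails["id"])
            and bool(guardrails["version"])
        )
    except KeyError as e:
        raise TypeError(
            "Guardrails must be a dictionary with 'id' and 'version' keys."
        ) from e


def guardrails_canonical(guardrails: Mapping[str, Any]) -> Dict[str, Any]:
    # Mirrors _get_guardrails_canonical: the request-body fragment Bedrock expects.
    return {
        "amazon-bedrock-guardrailDetails": {
            "guardrailId": guardrails.get("id"),
            "guardrailVersion": guardrails.get("version"),
        }
    }


assert guardrails_enabled({"id": "gr-abc123", "version": "1"})  # placeholder values
assert not guardrails_enabled({"id": None, "version": None, "trace": False})
print(guardrails_canonical({"id": "gr-abc123", "version": "1"}))
```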
```diff
@@ -309,29 +411,81 @@ class BedrockBase(BaseModel, ABC):
         provider = self._get_provider()
         params = {**_model_kwargs, **kwargs}
+
+        if self._guardrails_enabled:
+            params.update(self._get_guardrails_canonical())
+
         input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
         body = json.dumps(input_body)
         accept = "application/json"
         contentType = "application/json"
 
+        request_options = {
+            "body": body,
+            "modelId": self.model_id,
+            "accept": accept,
+            "contentType": contentType,
+        }
+
+        if self._guardrails_enabled:
+            request_options["guardrail"] = "ENABLED"
+            if self.guardrails.get("trace"):
+                request_options["trace"] = "ENABLED"
+
         try:
-            response = self.client.invoke_model(
-                body=body,
-                modelId=self.model_id,
-                accept=accept,
-                contentType=contentType,
-            )
-            text = LLMInputOutputAdapter.prepare_output(provider, response)
+            response = self.client.invoke_model(**request_options)
+
+            text, body = LLMInputOutputAdapter.prepare_output(
+                provider, response
+            ).values()
+
         except Exception as e:
-            raise ValueError(f"Error raised by bedrock service: {e}").with_traceback(
-                e.__traceback__
-            )
+            raise ValueError(f"Error raised by bedrock service: {e}")
 
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
 
+        # Verify and raise a callback error if any intervention occurs or a signal is
+        # sent from a Bedrock service,
+        # such as when guardrails are triggered.
+        services_trace = self._get_bedrock_services_signal(body)
+
+        if services_trace.get("signal") and run_manager is not None:
+            run_manager.on_llm_error(
+                Exception(
+                    f"Error raised by bedrock service: {services_trace.get('reason')}"
+                ),
+                **services_trace,
+            )
+
         return text
 
+    def _get_bedrock_services_signal(self, body: dict) -> dict:
+        """
+        This function checks the response body for an interrupt flag or message that indicates
+        whether any of the Bedrock services have intervened in the processing flow. It is
+        primarily used to identify modifications or interruptions imposed by these services
+        during the request-response cycle with a Large Language Model (LLM).
+        """  # noqa: E501
+        if (
+            self._guardrails_enabled
+            and self.guardrails.get("trace")
+            and self._is_guardrails_intervention(body)
+        ):
+            return {
+                "signal": True,
+                "reason": "GUARDRAIL_INTERVENED",
+                "trace": body.get(AMAZON_BEDROCK_TRACE_KEY),
+            }
+
+        return {
+            "signal": False,
+            "reason": None,
+            "trace": None,
+        }
+
+    def _is_guardrails_intervention(self, body: dict) -> bool:
+        return body.get(GUARDRAILS_BODY_KEY) == "GUARDRAIL_INTERVENED"
+
     def _prepare_input_and_invoke_stream(
         self,
         prompt: str,
```
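
The signal contract can be exercised in isolation. A hedged re-implementation (assuming guardrails and tracing are both enabled; the trace payload below is a placeholder, its real shape is defined by the Bedrock service):

```python
from typing import Any, Dict

AMAZON_BEDROCK_TRACE_KEY = "amazon-bedrock-trace"
GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAssessment"


def services_signal(body: Dict[str, Any]) -> Dict[str, Any]:
    # Standalone restatement of _get_bedrock_services_signal, with the
    # guardrails-enabled and trace checks assumed to have passed.
    if body.get(GUARDRAILS_BODY_KEY) == "GUARDRAIL_INTERVENED":
        return {
            "signal": True,
            "reason": "GUARDRAIL_INTERVENED",
            "trace": body.get(AMAZON_BEDROCK_TRACE_KEY),
        }
    return {"signal": False, "reason": None, "trace": None}


intervened = {
    GUARDRAILS_BODY_KEY: "GUARDRAIL_INTERVENED",
    AMAZON_BEDROCK_TRACE_KEY: {"assessment": "..."},  # placeholder trace
}
print(services_signal(intervened))       # signal=True, trace attached
print(services_signal({"results": []}))  # signal=False
```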
```diff
@@ -356,16 +510,28 @@ class BedrockBase(BaseModel, ABC):
             _model_kwargs["stream"] = True
 
         params = {**_model_kwargs, **kwargs}
+
+        if self._guardrails_enabled:
+            params.update(self._get_guardrails_canonical())
+
         input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
         body = json.dumps(input_body)
 
+        request_options = {
+            "body": body,
+            "modelId": self.model_id,
+            "accept": "application/json",
+            "contentType": "application/json",
+        }
+
+        if self._guardrails_enabled:
+            request_options["guardrail"] = "ENABLED"
+            if self.guardrails.get("trace"):
+                request_options["trace"] = "ENABLED"
+
         try:
-            response = self.client.invoke_model_with_response_stream(
-                body=body,
-                modelId=self.model_id,
-                accept="application/json",
-                contentType="application/json",
-            )
+            response = self.client.invoke_model_with_response_stream(**request_options)
         except Exception as e:
             raise ValueError(f"Error raised by bedrock service: {e}")
```
```diff
@@ -373,6 +539,9 @@ class BedrockBase(BaseModel, ABC):
             provider, response, stop
         ):
             yield chunk
 
+            # verify and raise callback error if any middleware intervened
+            self._get_bedrock_services_signal(chunk.generation_info)
+
             if run_manager is not None:
                 run_manager.on_llm_new_token(chunk.text, chunk=chunk)
```
```diff
@@ -536,7 +705,9 @@ class Bedrock(LLM, BedrockBase):
                 completion += chunk.text
             return completion
 
-        return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs)
+        return self._prepare_input_and_invoke(
+            prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
+        )
 
     async def _astream(
         self,
```