diff --git a/docs/docs/integrations/llms/bedrock.ipynb b/docs/docs/integrations/llms/bedrock.ipynb
index d2d6cd2cd99..0c1748cc47f 100644
--- a/docs/docs/integrations/llms/bedrock.ipynb
+++ b/docs/docs/integrations/llms/bedrock.ipynb
@@ -106,6 +106,45 @@
     "\n",
     "conversation.predict(input=\"Hi there!\")"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Guardrails for Amazon Bedrock example\n",
+    "\n",
+    "In this section, we set up a Bedrock language model with guardrails and tracing enabled."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import Any\n",
+    "\n",
+    "from langchain_core.callbacks import AsyncCallbackHandler\n",
+    "\n",
+    "from langchain_community.llms import Bedrock\n",
+    "\n",
+    "\n",
+    "class BedrockAsyncCallbackHandler(AsyncCallbackHandler):\n",
+    "    \"\"\"Async callback handler that can be used to handle callbacks from langchain.\"\"\"\n",
+    "\n",
+    "    async def on_llm_error(self, error: BaseException, **kwargs: Any) -> Any:\n",
+    "        reason = kwargs.get(\"reason\")\n",
+    "        if reason == \"GUARDRAIL_INTERVENED\":\n",
+    "            print(f\"Guardrails: {kwargs}\")\n",
+    "\n",
+    "\n",
+    "# Guardrails for Amazon Bedrock with trace\n",
+    "llm = Bedrock(\n",
+    "    credentials_profile_name=\"bedrock-admin\",\n",
+    "    model_id=\"\",\n",
+    "    model_kwargs={},\n",
+    "    guardrails={\"id\": \"\", \"version\": \"\", \"trace\": True},\n",
+    "    callbacks=[BedrockAsyncCallbackHandler()],\n",
+    ")"
+   ]
  }
 ],
 "metadata": {
diff --git a/libs/community/langchain_community/llms/bedrock.py b/libs/community/langchain_community/llms/bedrock.py
index 5ec60e84967..c66d0d071a8 100644
--- a/libs/community/langchain_community/llms/bedrock.py
+++ b/libs/community/langchain_community/llms/bedrock.py
@@ -34,6 +34,8 @@ from langchain_community.utilities.anthropic import (
 if TYPE_CHECKING:
     from botocore.config import Config
 
+AMAZON_BEDROCK_TRACE_KEY = "amazon-bedrock-trace"
+GUARDRAILS_BODY_KEY = "amazon-bedrock-guardrailAssessment"
 HUMAN_PROMPT = "\n\nHuman:"
 ASSISTANT_PROMPT = "\n\nAssistant:"
 ALTERNATION_ERROR = (
@@ -117,21 +119,26 @@ class LLMInputOutputAdapter:
 
         return input_body
 
     @classmethod
-    def prepare_output(cls, provider: str, response: Any) -> str:
+    def prepare_output(cls, provider: str, response: Any) -> dict:
         if provider == "anthropic":
             response_body = json.loads(response.get("body").read().decode())
-            return response_body.get("completion")
+            text = response_body.get("completion")
         else:
             response_body = json.loads(response.get("body").read())
 
-            if provider == "ai21":
-                return response_body.get("completions")[0].get("data").get("text")
-            elif provider == "cohere":
-                return response_body.get("generations")[0].get("text")
-            elif provider == "meta":
-                return response_body.get("generation")
-            else:
-                return response_body.get("results")[0].get("outputText")
+            if provider == "ai21":
+                text = response_body.get("completions")[0].get("data").get("text")
+            elif provider == "cohere":
+                text = response_body.get("generations")[0].get("text")
+            elif provider == "meta":
+                text = response_body.get("generation")
+            else:
+                text = response_body.get("results")[0].get("outputText")
+
+        return {
+            "text": text,
+            "body": response_body,
+        }
 
     @classmethod
     def prepare_output_stream(
@@ -160,8 +167,15 @@ class LLMInputOutputAdapter:
                     chunk_obj["is_finished"] or chunk_obj[output_key] == ""
                 ):
                     return
-
-            yield GenerationChunk(text=chunk_obj[output_key])
+            # Chunk object format varies with the provider
+            yield GenerationChunk(
+                text=chunk_obj[output_key],
+                generation_info={
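+                    # Surface any guardrail assessment emitted with this chunk
+                    # (None when absent) so downstream callbacks can inspect it.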
+                    GUARDRAILS_BODY_KEY: chunk_obj.get(GUARDRAILS_BODY_KEY),
+                },
+            )
 
     @classmethod
     async def aprepare_output_stream(
@@ -235,6 +249,53 @@ class BedrockBase(BaseModel, ABC):
         "cohere": "stop_sequences",
     }
 
+    guardrails: Optional[Mapping[str, Any]] = {
+        "id": None,
+        "version": None,
+        "trace": False,
+    }
+    """
+    An optional dictionary to configure guardrails for Bedrock.
+
+    This field 'guardrails' consists of three keys: 'id' and 'version',
+    which should be strings but are initialized to None, and 'trace', a
+    boolean. It is used to determine whether guardrails are enabled and
+    properly configured.
+
+    Type:
+        Optional[Mapping[str, Any]]: A mapping with 'id', 'version' and
+        'trace' keys.
+
+    Example:
+    llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
+                  model_kwargs={},
+                  guardrails={
+                        "id": "<guardrail_id>",
+                        "version": "<guardrail_version>"})
+
+    To enable tracing for guardrails, set the 'trace' key to True and pass a
+    callback handler to the 'run_manager' parameter of the 'generate' or
+    '_call' methods.
+
+    Example:
+    llm = Bedrock(model_id="<model_id>", client=<bedrock_client>,
+                  model_kwargs={},
+                  guardrails={
+                        "id": "<guardrail_id>",
+                        "version": "<guardrail_version>",
+                        "trace": True},
+                  callbacks=[BedrockAsyncCallbackHandler()])
+
+    See [https://python.langchain.com/docs/modules/callbacks/] for more
+    information on callback handlers.
+
+    class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
+        async def on_llm_error(
+            self,
+            error: BaseException,
+            **kwargs: Any,
+        ) -> Any:
+            reason = kwargs.get("reason")
+            if reason == "GUARDRAIL_INTERVENED":
+                ...Logic to handle guardrail intervention...
+    """  # noqa: E501
+
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that AWS credentials and python package exist in environment."""
@@ -298,6 +359,47 @@ class BedrockBase(BaseModel, ABC):
     def _model_is_anthropic(self) -> bool:
         return self._get_provider() == "anthropic"
 
+    @property
+    def _guardrails_enabled(self) -> bool:
+        """
+        Determine whether guardrails are enabled and correctly configured,
+        i.e. whether 'guardrails' is a dictionary with non-empty 'id' and
+        'version' values. The optional 'trace' key only controls tracing
+        and is not checked here.
+
+        Returns:
+            bool: True if guardrails are correctly configured, False otherwise.
+        Raises:
+            TypeError: If 'guardrails' lacks 'id' or 'version' keys.
+        """
+        try:
+            return (
+                isinstance(self.guardrails, dict)
+                and bool(self.guardrails["id"])
+                and bool(self.guardrails["version"])
+            )
+
+        except KeyError as e:
+            raise TypeError(
+                "Guardrails must be a dictionary with 'id' and 'version' keys."
+            ) from e
+
+    def _get_guardrails_canonical(self) -> Dict[str, Any]:
+        """
+        The canonical way to pass guardrails to the bedrock service
+        adheres to the following format:
+
+        "amazon-bedrock-guardrailDetails": {
+            "guardrailId": "string",
+            "guardrailVersion": "string"
+        }
+        """
+        return {
+            "amazon-bedrock-guardrailDetails": {
+                "guardrailId": self.guardrails.get("id"),
+                "guardrailVersion": self.guardrails.get("version"),
+            }
+        }
+
     def _prepare_input_and_invoke(
         self,
         prompt: str,
@@ -309,29 +411,81 @@ class BedrockBase(BaseModel, ABC):
 
         provider = self._get_provider()
         params = {**_model_kwargs, **kwargs}
+        if self._guardrails_enabled:
+            params.update(self._get_guardrails_canonical())
         input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
         body = json.dumps(input_body)
         accept = "application/json"
         contentType = "application/json"
 
+        request_options = {
+            "body": body,
+            "modelId": self.model_id,
+            "accept": accept,
+            "contentType": contentType,
+        }
+
+        if self._guardrails_enabled:
+            request_options["guardrail"] = "ENABLED"
+            if self.guardrails.get("trace"):
+                request_options["trace"] = "ENABLED"
+
         try:
-            response = self.client.invoke_model(
-                body=body,
-                modelId=self.model_id,
-                accept=accept,
-                contentType=contentType,
-            )
-            text = LLMInputOutputAdapter.prepare_output(provider, response)
+            response = self.client.invoke_model(**request_options)
+
+            output = LLMInputOutputAdapter.prepare_output(provider, response)
+            text, body = output["text"], output["body"]
+
         except Exception as e:
-            raise ValueError(f"Error raised by bedrock service: {e}").with_traceback(
-                e.__traceback__
-            )
+            raise ValueError(f"Error raised by bedrock service: {e}")
 
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
 
+        # Notify the callback handler if a Bedrock service intervened or sent
+        # a signal, e.g. when guardrails are triggered.
+        services_trace = self._get_bedrock_services_signal(body)
+
+        if services_trace.get("signal") and run_manager is not None:
+            run_manager.on_llm_error(
+                Exception(
+                    f"Error raised by bedrock service: {services_trace.get('reason')}"
+                ),
+                **services_trace,
+            )
+
         return text
 
    def _get_bedrock_services_signal(self, body: dict) -> dict:
+        """
+        Check the response body for a flag or message indicating that a
+        Bedrock service intervened in the processing flow. This is primarily
+        used to identify modifications or interruptions imposed by these
+        services during the request-response cycle with the LLM.
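+
+        The returned mapping always carries 'signal', 'reason' and 'trace'
+        keys; for example, when a guardrail with tracing enabled intervenes,
+        it looks like:
+
+            {
+                "signal": True,
+                "reason": "GUARDRAIL_INTERVENED",
+                "trace": {...},
+            }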
+ """ # noqa: E501 + + if ( + self._guardrails_enabled + and self.guardrails.get("trace") + and self._is_guardrails_intervention(body) + ): + return { + "signal": True, + "reason": "GUARDRAIL_INTERVENED", + "trace": body.get(AMAZON_BEDROCK_TRACE_KEY), + } + + return { + "signal": False, + "reason": None, + "trace": None, + } + + def _is_guardrails_intervention(self, body: dict) -> bool: + return body.get(GUARDRAILS_BODY_KEY) == "GUARDRAIL_INTERVENED" + def _prepare_input_and_invoke_stream( self, prompt: str, @@ -356,16 +510,28 @@ class BedrockBase(BaseModel, ABC): _model_kwargs["stream"] = True params = {**_model_kwargs, **kwargs} + + if self._guardrails_enabled: + params.update(self._get_guardrails_canonical()) + input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params) body = json.dumps(input_body) + request_options = { + "body": body, + "modelId": self.model_id, + "accept": "application/json", + "contentType": "application/json", + } + + if self._guardrails_enabled: + request_options["guardrail"] = "ENABLED" + if self.guardrails.get("trace"): + request_options["trace"] = "ENABLED" + try: - response = self.client.invoke_model_with_response_stream( - body=body, - modelId=self.model_id, - accept="application/json", - contentType="application/json", - ) + response = self.client.invoke_model_with_response_stream(**request_options) + except Exception as e: raise ValueError(f"Error raised by bedrock service: {e}") @@ -373,6 +539,9 @@ class BedrockBase(BaseModel, ABC): provider, response, stop ): yield chunk + # verify and raise callback error if any middleware intervened + self._get_bedrock_services_signal(chunk.generation_info) + if run_manager is not None: run_manager.on_llm_new_token(chunk.text, chunk=chunk) @@ -536,7 +705,9 @@ class Bedrock(LLM, BedrockBase): completion += chunk.text return completion - return self._prepare_input_and_invoke(prompt=prompt, stop=stop, **kwargs) + return self._prepare_input_and_invoke( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ) async def _astream( self, diff --git a/libs/community/tests/integration_tests/llms/test_bedrock.py b/libs/community/tests/integration_tests/llms/test_bedrock.py new file mode 100644 index 00000000000..45a7bcd0bfa --- /dev/null +++ b/libs/community/tests/integration_tests/llms/test_bedrock.py @@ -0,0 +1,136 @@ +""" +Test Amazon Bedrock API wrapper and services i.e 'Guardrails for Amazon Bedrock'. +You can get a list of models from the bedrock client by running 'bedrock_models()' + +""" + +import os +from typing import Any + +import pytest +from langchain_core.callbacks import AsyncCallbackHandler + +from langchain_community.llms.bedrock import Bedrock + +# this is the guardrails id for the model you want to test +GUARDRAILS_ID = os.environ.get("GUARDRAILS_ID", "7jarelix77") +# this is the guardrails version for the model you want to test +GUARDRAILS_VERSION = os.environ.get("GUARDRAILS_VERSION", "1") +# this should trigger the guardrails - you can change this to any text you want which +# will trigger the guardrails +GUARDRAILS_TRIGGER = os.environ.get( + "GUARDRAILS_TRIGGERING_QUERY", "I want to talk about politics." 
+)
+
+
+class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
+    """Async callback handler that can be used to handle callbacks from langchain."""
+
+    guardrails_intervened = False
+
+    async def on_llm_error(
+        self,
+        error: BaseException,
+        **kwargs: Any,
+    ) -> Any:
+        reason = kwargs.get("reason")
+        if reason == "GUARDRAIL_INTERVENED":
+            self.guardrails_intervened = True
+
+    def get_response(self) -> bool:
+        return self.guardrails_intervened
+
+
+@pytest.fixture
+def bedrock_runtime_client():
+    import boto3
+
+    try:
+        client = boto3.client(
+            "bedrock-runtime",
+            region_name=os.environ.get("AWS_REGION", "us-east-1"),
+        )
+        return client
+    except Exception as e:
+        pytest.fail(f"cannot connect to bedrock-runtime client: {e}", pytrace=False)
+
+
+@pytest.fixture
+def bedrock_client():
+    import boto3
+
+    try:
+        client = boto3.client(
+            "bedrock",
+            region_name=os.environ.get("AWS_REGION", "us-east-1"),
+        )
+        return client
+    except Exception as e:
+        pytest.fail(f"cannot connect to bedrock client: {e}", pytrace=False)
+
+
+@pytest.fixture
+def bedrock_models(bedrock_client):
+    """List bedrock models."""
+    response = bedrock_client.list_foundation_models().get("modelSummaries")
+    models = {}
+    for model in response:
+        models[model.get("modelId")] = model.get("modelName")
+    return models
+
+
+def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):
+    try:
+        llm = Bedrock(
+            model_id="anthropic.claude-instant-v1",
+            client=bedrock_runtime_client,
+            model_kwargs={},
+        )
+        output = llm("Say something positive:")
+        assert isinstance(output, str)
+    except Exception as e:
+        pytest.fail(f"cannot instantiate claude-instant-v1: {e}", pytrace=False)
+
+
+def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(
+    bedrock_runtime_client, bedrock_models
+):
+    try:
+        llm = Bedrock(
+            model_id="anthropic.claude-instant-v1",
+            client=bedrock_runtime_client,
+            model_kwargs={},
+            guardrails={
+                "id": GUARDRAILS_ID,
+                "version": GUARDRAILS_VERSION,
+                "trace": False,
+            },
+        )
+        output = llm("Say something positive:")
+        assert isinstance(output, str)
+    except Exception as e:
+        pytest.fail(f"cannot instantiate claude-instant-v1: {e}", pytrace=False)
+
+
+def test_amazon_bedrock_guardrails_intervention_for_invalid_query(
+    bedrock_runtime_client, bedrock_models
+):
+    try:
+        handler = BedrockAsyncCallbackHandler()
+        llm = Bedrock(
+            model_id="anthropic.claude-instant-v1",
+            client=bedrock_runtime_client,
+            model_kwargs={},
+            guardrails={
+                "id": GUARDRAILS_ID,
+                "version": GUARDRAILS_VERSION,
+                "trace": True,
+            },
+            callbacks=[handler],
+        )
+    except Exception as e:
+        pytest.fail(f"cannot instantiate claude-instant-v1: {e}", pytrace=False)
+    else:
+        llm(GUARDRAILS_TRIGGER)
+        guardrails_intervened = handler.get_response()
+        assert guardrails_intervened is True
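+
+
+# A minimal sketch of how these integration tests might be run locally. It
+# assumes valid AWS credentials and a guardrail already configured in your
+# account; the id and version values below are placeholders:
+#
+#   export AWS_REGION=us-east-1
+#   export GUARDRAILS_ID=<your-guardrail-id>
+#   export GUARDRAILS_VERSION=<your-guardrail-version>
+#   pytest libs/community/tests/integration_tests/llms/test_bedrock.py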