mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 11:02:37 +00:00
community[minor]: add support for Guardrails for Amazon Bedrock (#15099)
Added support for optionally supplying 'Guardrails for Amazon Bedrock' on both types of model invocations (batch/regular and streaming) and for all models supported by the Amazon Bedrock service. @baskaryan @hwchase17 ```python llm = Bedrock(model_id="<model_id>", client=bedrock, model_kwargs={}, guardrails={"id": " <guardrail_id>", "version": "<guardrail_version>", "trace": True}, callbacks=[BedrockAsyncCallbackHandler()]) class BedrockAsyncCallbackHandler(AsyncCallbackHandler): """Async callback handler that can be used to handle callbacks from langchain.""" async def on_llm_error( self, error: BaseException, **kwargs: Any, ) -> Any: reason = kwargs.get("reason") if reason == "GUARDRAIL_INTERVENED": # kwargs contains additional trace information sent by 'Guardrails for Bedrock' service. print(f"""Guardrails: {kwargs}""") # streaming llm = Bedrock(model_id="<model_id>", client=bedrock, model_kwargs={}, streaming=True, guardrails={"id": "<guardrail_id>", "version": "<guardrail_version>"}) ``` --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
136
libs/community/tests/integration_tests/llms/test_bedrock.py
Normal file
136
libs/community/tests/integration_tests/llms/test_bedrock.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Test Amazon Bedrock API wrapper and services i.e 'Guardrails for Amazon Bedrock'.
|
||||
You can get a list of models from the bedrock client by running 'bedrock_models()'
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import AsyncCallbackHandler
|
||||
|
||||
from langchain_community.llms.bedrock import Bedrock
|
||||
|
||||
# this is the guardrails id for the model you want to test
|
||||
GUARDRAILS_ID = os.environ.get("GUARDRAILS_ID", "7jarelix77")
|
||||
# this is the guardrails version for the model you want to test
|
||||
GUARDRAILS_VERSION = os.environ.get("GUARDRAILS_VERSION", "1")
|
||||
# this should trigger the guardrails - you can change this to any text you want which
|
||||
# will trigger the guardrails
|
||||
GUARDRAILS_TRIGGER = os.environ.get(
|
||||
"GUARDRAILS_TRIGGERING_QUERY", "I want to talk about politics."
|
||||
)
|
||||
|
||||
|
||||
class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
|
||||
"""Async callback handler that can be used to handle callbacks from langchain."""
|
||||
|
||||
guardrails_intervened = False
|
||||
|
||||
async def on_llm_error(
|
||||
self,
|
||||
error: BaseException,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
reason = kwargs.get("reason")
|
||||
if reason == "GUARDRAIL_INTERVENED":
|
||||
self.guardrails_intervened = True
|
||||
|
||||
def get_response(self):
|
||||
return self.guardrails_intervened
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bedrock_runtime_client():
|
||||
import boto3
|
||||
|
||||
try:
|
||||
client = boto3.client(
|
||||
"bedrock-runtime",
|
||||
region_name=os.environ.get("AWS_REGION", "us-east-1"),
|
||||
)
|
||||
return client
|
||||
except Exception as e:
|
||||
pytest.fail(f"can not connect to bedrock-runtime client: {e}", pytrace=False)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def bedrock_client():
|
||||
import boto3
|
||||
|
||||
try:
|
||||
client = boto3.client(
|
||||
"bedrock",
|
||||
region_name=os.environ.get("AWS_REGION", "us-east-1"),
|
||||
)
|
||||
return client
|
||||
except Exception as e:
|
||||
pytest.fail(f"can not connect to bedrock client: {e}", pytrace=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def bedrock_models(bedrock_client):
|
||||
"""List bedrock models."""
|
||||
response = bedrock_client.list_foundation_models().get("modelSummaries")
|
||||
models = {}
|
||||
for model in response:
|
||||
models[model.get("modelId")] = model.get("modelName")
|
||||
return models
|
||||
|
||||
|
||||
def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):
|
||||
try:
|
||||
llm = Bedrock(
|
||||
model_id="anthropic.claude-instant-v1",
|
||||
client=bedrock_runtime_client,
|
||||
model_kwargs={},
|
||||
)
|
||||
output = llm("Say something positive:")
|
||||
assert isinstance(output, str)
|
||||
except Exception as e:
|
||||
pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
|
||||
|
||||
|
||||
def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(
|
||||
bedrock_runtime_client, bedrock_models
|
||||
):
|
||||
try:
|
||||
llm = Bedrock(
|
||||
model_id="anthropic.claude-instant-v1",
|
||||
client=bedrock_runtime_client,
|
||||
model_kwargs={},
|
||||
guardrails={
|
||||
"id": GUARDRAILS_ID,
|
||||
"version": GUARDRAILS_VERSION,
|
||||
"trace": False,
|
||||
},
|
||||
)
|
||||
output = llm("Say something positive:")
|
||||
assert isinstance(output, str)
|
||||
except Exception as e:
|
||||
pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
|
||||
|
||||
|
||||
def test_amazon_bedrock_guardrails_intervention_for_invalid_query(
|
||||
bedrock_runtime_client, bedrock_models
|
||||
):
|
||||
try:
|
||||
handler = BedrockAsyncCallbackHandler()
|
||||
llm = Bedrock(
|
||||
model_id="anthropic.claude-instant-v1",
|
||||
client=bedrock_runtime_client,
|
||||
model_kwargs={},
|
||||
guardrails={
|
||||
"id": GUARDRAILS_ID,
|
||||
"version": GUARDRAILS_VERSION,
|
||||
"trace": True,
|
||||
},
|
||||
callbacks=[handler],
|
||||
)
|
||||
except Exception as e:
|
||||
pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
|
||||
else:
|
||||
llm(GUARDRAILS_TRIGGER)
|
||||
guardrails_intervened = handler.get_response()
|
||||
assert guardrails_intervened is True
|
Reference in New Issue
Block a user