guards changes

This commit is contained in:
Harrison Chase
2023-03-18 21:59:15 -07:00
parent 6bf2b70331
commit a20ff254e5
8 changed files with 105 additions and 90 deletions

View File

@@ -34,7 +34,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "d64f114c",
"metadata": {},
"outputs": [],
@@ -45,12 +45,13 @@
"from langchain.chains import LLMChain\n",
"from langchain.llms import OpenAI\n",
"from langchain.output_parsers import PydanticOutputParser\n",
"from langchain.prompts import PromptTemplate"
"from langchain.prompts import PromptTemplate\n",
"from langchain.guardrails import FormatInstructionsGuard"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"id": "ae3ec42d",
"metadata": {},
"outputs": [],
@@ -80,7 +81,7 @@
" prompt=prompt,\n",
" llm=OpenAI(model_name=\"text-curie-001\"), # Use a smaller model.\n",
" output_parser=parser,\n",
" output_parser_retry_enabled=True, # For retry, escalate to a larger model. We use DaVinci zero-shot naively.\n",
" output_parser_guard=FormatInstructionsGuard.from_llm(OpenAI(model_name='text-davinci-003')),\n",
" verbose=True)"
]
},
@@ -107,8 +108,6 @@
"```\n",
"Write out a few terms of fiboacci.\n",
"\u001b[0m\n",
"Uh-oh! Got Failed to parse FloatArray from completion \n",
"1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144.. Got: Expecting value: line 1 column 1 (char 0). Retrying with DaVinci.\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n"
]
@@ -116,7 +115,7 @@
{
"data": {
"text/plain": [
"FloatArray(values=[1.0, 1.0, 2.0, 3.0, 5.0, 8.0, 13.0, 21.0, 34.0, 55.0, 89.0, 144.0])"
"FloatArray(values=[0.0, 1.0, 1.0, 2.0, 3.0, 5.0, 8.0, 13.0, 21.0, 34.0, 55.0, 89.0, 144.0, 233.0, 377.0, 610.0, 987.0, 1597.0, 2584.0, 4181.0, 6765.0])"
]
},
"execution_count": 6,
@@ -154,7 +153,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 7,
"id": "362678a4",
"metadata": {},
"outputs": [
@@ -185,10 +184,10 @@
{
"data": {
"text/plain": [
"'LOL\\nText: \"I was really happy with the gift!\" Label: Positive\\nText: \"I am unhappy because of the rain.\" Label: Negative\\nText: \"I am excited to eat ice cream on Sunday\" Label: Positive'"
"'LOL\\nText: \"I was really happy with the gift!\"\\nLabel: Positive\\nText: \"I am unhappy because of the rain.\"\\nLabel: Negative\\nText: \"I am excited to eat ice cream on Sunday\"\\nLabel: Positive'"
]
},
"execution_count": 4,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -232,17 +231,17 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 10,
"id": "0291a831",
"metadata": {},
"outputs": [],
"source": [
"from langchain.guardrails import Guardrail, GuardrailEvaluation\n",
"from langchain.schema import ModelOutputGuard\n",
"from typing import Tuple\n",
"\n",
"from distutils.util import strtobool\n",
"\n",
"class PromptLeakageGuardrail(Guardrail):\n",
"class PromptLeakageGuardrail(ModelOutputGuard):\n",
" \n",
" language_model = OpenAI(model_name='text-davinci-003') \n",
" prompt_template = PromptTemplate(\n",
@@ -250,20 +249,22 @@
" input_variables=[\"prompt\", \"completion\"] \n",
" )\n",
"\n",
" def evaluate(self, prompt: str, completion: str) -> Tuple[Guardrail, bool]:\n",
" def evaluate(self, prompt: str, completion: str) -> str:\n",
" prompt = self.prompt_template.format_prompt(prompt=prompt, completion=completion).to_string()\n",
" completion = self.language_model(prompt)\n",
" binary_value = strtobool(completion.strip().split()[0])\n",
" if binary_value:\n",
" return GuardrailEvaluation(error_msg = 'Prompt leaked!'), False\n",
" return None, True"
" raise ValueError\n",
" return completion"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 12,
"id": "89bf0615",
"metadata": {},
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
@@ -288,22 +289,23 @@
]
},
{
"ename": "RuntimeError",
"evalue": "Prompt leaked!",
"ename": "ValueError",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[6], line 6\u001b[0m\n\u001b[1;32m 1\u001b[0m llm_chain \u001b[38;5;241m=\u001b[39m LLMChain(\n\u001b[1;32m 2\u001b[0m prompt\u001b[38;5;241m=\u001b[39mprompt_to_leak,\n\u001b[1;32m 3\u001b[0m llm\u001b[38;5;241m=\u001b[39mOpenAI(model_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext-davinci-003\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 4\u001b[0m guardrails\u001b[38;5;241m=\u001b[39m[PromptLeakageGuardrail()],\n\u001b[1;32m 5\u001b[0m verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m----> 6\u001b[0m \u001b[43mllm_chain\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpredict\u001b[49m\u001b[43m(\u001b[49m\u001b[43munseen_example\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43madversarial_instruction\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/src/langchain/langchain/chains/llm.py:201\u001b[0m, in \u001b[0;36mLLMChain.predict\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mpredict\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mstr\u001b[39m:\n\u001b[1;32m 188\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Format prompt with kwargs and pass to LLM.\u001b[39;00m\n\u001b[1;32m 189\u001b[0m \n\u001b[1;32m 190\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[38;5;124;03m completion = llm.predict(adjective=\"funny\")\u001b[39;00m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 201\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_key]\n",
"File \u001b[0;32m~/src/langchain/langchain/chains/base.py:116\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, \u001b[38;5;167;01mException\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_end(outputs, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprep_outputs(inputs, outputs, return_only_outputs)\n",
"File \u001b[0;32m~/src/langchain/langchain/chains/base.py:113\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_start(\n\u001b[1;32m 108\u001b[0m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m},\n\u001b[1;32m 109\u001b[0m inputs,\n\u001b[1;32m 110\u001b[0m verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose,\n\u001b[1;32m 111\u001b[0m )\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 113\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, \u001b[38;5;167;01mException\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n",
"File \u001b[0;32m~/src/langchain/langchain/chains/llm.py:67\u001b[0m, in \u001b[0;36mLLMChain._call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_call\u001b[39m(\u001b[38;5;28mself\u001b[39m, inputs: Dict[\u001b[38;5;28mstr\u001b[39m, Any]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Dict[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mstr\u001b[39m]:\n\u001b[0;32m---> 67\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n",
"File \u001b[0;32m~/src/langchain/langchain/chains/llm.py:133\u001b[0m, in \u001b[0;36mLLMChain.apply\u001b[0;34m(self, input_list)\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Utilize the LLM generate method for speed gains.\"\"\"\u001b[39;00m\n\u001b[1;32m 132\u001b[0m response, prompts \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgenerate(input_list)\n\u001b[0;32m--> 133\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate_outputs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mresponse\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprompts\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/src/langchain/langchain/chains/llm.py:178\u001b[0m, in \u001b[0;36mLLMChain.create_outputs\u001b[0;34m(self, response, prompts)\u001b[0m\n\u001b[1;32m 174\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_outputs\u001b[39m(\n\u001b[1;32m 175\u001b[0m \u001b[38;5;28mself\u001b[39m, response: LLMResult, prompts: List[PromptValue]\n\u001b[1;32m 176\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m List[Dict[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mstr\u001b[39m]]:\n\u001b[1;32m 177\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Create outputs from response.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 178\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [\n\u001b[1;32m 179\u001b[0m \u001b[38;5;66;03m# Get the text of the top generated string.\u001b[39;00m\n\u001b[1;32m 180\u001b[0m {\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_key: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_final_output(generation[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mtext, prompts[i])}\n\u001b[1;32m 181\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, generation \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(response\u001b[38;5;241m.\u001b[39mgenerations)\n\u001b[1;32m 182\u001b[0m ]\n",
"File \u001b[0;32m~/src/langchain/langchain/chains/llm.py:180\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 174\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_outputs\u001b[39m(\n\u001b[1;32m 175\u001b[0m \u001b[38;5;28mself\u001b[39m, response: LLMResult, prompts: List[PromptValue]\n\u001b[1;32m 176\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m List[Dict[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mstr\u001b[39m]]:\n\u001b[1;32m 177\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Create outputs from response.\"\"\"\u001b[39;00m\n\u001b[1;32m 178\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [\n\u001b[1;32m 179\u001b[0m \u001b[38;5;66;03m# Get the text of the top generated string.\u001b[39;00m\n\u001b[0;32m--> 180\u001b[0m {\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_key: \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_final_output\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgeneration\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprompts\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m}\n\u001b[1;32m 181\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, generation \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(response\u001b[38;5;241m.\u001b[39mgenerations)\n\u001b[1;32m 182\u001b[0m ]\n",
"File \u001b[0;32m~/src/langchain/langchain/chains/llm.py:151\u001b[0m, in \u001b[0;36mLLMChain._get_final_output\u001b[0;34m(self, completion, prompt_value)\u001b[0m\n\u001b[1;32m 147\u001b[0m evaluation, ok \u001b[38;5;241m=\u001b[39m guardrail\u001b[38;5;241m.\u001b[39mevaluate(prompt_value, completion)\n\u001b[1;32m 148\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m ok:\n\u001b[1;32m 149\u001b[0m \u001b[38;5;66;03m# TODO: consider associating customer exception w/ guardrail\u001b[39;00m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;66;03m# as suggested in https://github.com/hwchase17/langchain/pull/1683/files#r1139987185\u001b[39;00m\n\u001b[0;32m--> 151\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(evaluation\u001b[38;5;241m.\u001b[39merror_msg)\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m evaluation\u001b[38;5;241m.\u001b[39mrevised_output:\n\u001b[1;32m 153\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(evaluation\u001b[38;5;241m.\u001b[39mrevised_output, \u001b[38;5;28mstr\u001b[39m)\n",
"\u001b[0;31mRuntimeError\u001b[0m: Prompt leaked!"
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[12], line 6\u001b[0m\n\u001b[1;32m 1\u001b[0m llm_chain \u001b[38;5;241m=\u001b[39m LLMChain(\n\u001b[1;32m 2\u001b[0m prompt\u001b[38;5;241m=\u001b[39mprompt_to_leak,\n\u001b[1;32m 3\u001b[0m llm\u001b[38;5;241m=\u001b[39mOpenAI(model_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext-davinci-003\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 4\u001b[0m guards\u001b[38;5;241m=\u001b[39m[PromptLeakageGuardrail()],\n\u001b[1;32m 5\u001b[0m verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m----> 6\u001b[0m \u001b[43mllm_chain\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpredict\u001b[49m\u001b[43m(\u001b[49m\u001b[43munseen_example\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43madversarial_instruction\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/workplace/langchain/langchain/chains/llm.py:188\u001b[0m, in \u001b[0;36mLLMChain.predict\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 174\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mpredict\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mstr\u001b[39m:\n\u001b[1;32m 175\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Format prompt with kwargs and pass to LLM.\u001b[39;00m\n\u001b[1;32m 176\u001b[0m \n\u001b[1;32m 177\u001b[0m \u001b[38;5;124;03m Args:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;124;03m completion = llm.predict(adjective=\"funny\")\u001b[39;00m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 188\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_key]\n",
"File \u001b[0;32m~/workplace/langchain/langchain/chains/base.py:116\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, \u001b[38;5;167;01mException\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_end(outputs, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n\u001b[1;32m 118\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprep_outputs(inputs, outputs, return_only_outputs)\n",
"File \u001b[0;32m~/workplace/langchain/langchain/chains/base.py:113\u001b[0m, in \u001b[0;36mChain.__call__\u001b[0;34m(self, inputs, return_only_outputs)\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_start(\n\u001b[1;32m 108\u001b[0m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m},\n\u001b[1;32m 109\u001b[0m inputs,\n\u001b[1;32m 110\u001b[0m verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose,\n\u001b[1;32m 111\u001b[0m )\n\u001b[1;32m 112\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 113\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m (\u001b[38;5;167;01mKeyboardInterrupt\u001b[39;00m, \u001b[38;5;167;01mException\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_manager\u001b[38;5;241m.\u001b[39mon_chain_error(e, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mverbose)\n",
"File \u001b[0;32m~/workplace/langchain/langchain/chains/llm.py:67\u001b[0m, in \u001b[0;36mLLMChain._call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 66\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_call\u001b[39m(\u001b[38;5;28mself\u001b[39m, inputs: Dict[\u001b[38;5;28mstr\u001b[39m, Any]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Dict[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mstr\u001b[39m]:\n\u001b[0;32m---> 67\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n",
"File \u001b[0;32m~/workplace/langchain/langchain/chains/llm.py:133\u001b[0m, in \u001b[0;36mLLMChain.apply\u001b[0;34m(self, input_list)\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Utilize the LLM generate method for speed gains.\"\"\"\u001b[39;00m\n\u001b[1;32m 132\u001b[0m response, prompts \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgenerate(input_list)\n\u001b[0;32m--> 133\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcreate_outputs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mresponse\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprompts\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/workplace/langchain/langchain/chains/llm.py:165\u001b[0m, in \u001b[0;36mLLMChain.create_outputs\u001b[0;34m(self, response, prompts)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_outputs\u001b[39m(\n\u001b[1;32m 162\u001b[0m \u001b[38;5;28mself\u001b[39m, response: LLMResult, prompts: List[PromptValue]\n\u001b[1;32m 163\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m List[Dict[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mstr\u001b[39m]]:\n\u001b[1;32m 164\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Create outputs from response.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 165\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [\n\u001b[1;32m 166\u001b[0m \u001b[38;5;66;03m# Get the text of the top generated string.\u001b[39;00m\n\u001b[1;32m 167\u001b[0m {\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_key: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_final_output(generation[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mtext, prompts[i])}\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, generation \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(response\u001b[38;5;241m.\u001b[39mgenerations)\n\u001b[1;32m 169\u001b[0m ]\n",
"File \u001b[0;32m~/workplace/langchain/langchain/chains/llm.py:167\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcreate_outputs\u001b[39m(\n\u001b[1;32m 162\u001b[0m \u001b[38;5;28mself\u001b[39m, response: LLMResult, prompts: List[PromptValue]\n\u001b[1;32m 163\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m List[Dict[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mstr\u001b[39m]]:\n\u001b[1;32m 164\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Create outputs from response.\"\"\"\u001b[39;00m\n\u001b[1;32m 165\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m [\n\u001b[1;32m 166\u001b[0m \u001b[38;5;66;03m# Get the text of the top generated string.\u001b[39;00m\n\u001b[0;32m--> 167\u001b[0m {\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_key: \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_final_output\u001b[49m\u001b[43m(\u001b[49m\u001b[43mgeneration\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprompts\u001b[49m\u001b[43m[\u001b[49m\u001b[43mi\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m}\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m i, generation \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(response\u001b[38;5;241m.\u001b[39mgenerations)\n\u001b[1;32m 169\u001b[0m ]\n",
"File \u001b[0;32m~/workplace/langchain/langchain/chains/llm.py:149\u001b[0m, in \u001b[0;36mLLMChain._get_final_output\u001b[0;34m(self, completion, prompt_value)\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Validate raw completion (guardrails) + extract structured data from it (parser).\u001b[39;00m\n\u001b[1;32m 142\u001b[0m \n\u001b[1;32m 143\u001b[0m \u001b[38;5;124;03mWe may want to apply guardrails not just to raw string completions, but also to structured parsed completions.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;124;03mTODO: actually, reasonable to do the parsing first so that guardrail.evaluate gets the reified data structure.\u001b[39;00m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 148\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m guard \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mguards:\n\u001b[0;32m--> 149\u001b[0m completion \u001b[38;5;241m=\u001b[39m \u001b[43mguard\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mevaluate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprompt_value\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcompletion\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 151\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_parser:\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_parser_guard:\n",
"Cell \u001b[0;32mIn[10], line 19\u001b[0m, in \u001b[0;36mPromptLeakageGuardrail.evaluate\u001b[0;34m(self, prompt, completion)\u001b[0m\n\u001b[1;32m 17\u001b[0m binary_value \u001b[38;5;241m=\u001b[39m strtobool(completion\u001b[38;5;241m.\u001b[39mstrip()\u001b[38;5;241m.\u001b[39msplit()[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m binary_value:\n\u001b[0;32m---> 19\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m\n\u001b[1;32m 20\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m completion\n",
"\u001b[0;31mValueError\u001b[0m: "
]
}
],
@@ -311,7 +313,7 @@
"llm_chain = LLMChain(\n",
" prompt=prompt_to_leak,\n",
" llm=OpenAI(model_name=\"text-davinci-003\"),\n",
" guardrails=[PromptLeakageGuardrail()],\n",
" guards=[PromptLeakageGuardrail()],\n",
" verbose=True)\n",
"llm_chain.predict(unseen_example=adversarial_instruction)"
]
@@ -363,7 +365,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
"version": "3.9.1"
}
},
"nbformat": 4,

View File

@@ -13,10 +13,10 @@ from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
BaseLanguageModel,
LLMResult,
ModelOutputGuard,
ModelOutputParserGuard,
PromptValue,
)
from langchain.guardrails import Guardrail
from langchain.guardrails.utils import dumb_davinci_retry
class LLMChain(Chain, BaseModel):
@@ -38,8 +38,8 @@ class LLMChain(Chain, BaseModel):
llm: BaseLanguageModel
output_key: str = "text" #: :meta private:
output_parser: Optional[BaseOutputParser] = None
output_parser_retry_enabled: bool = False
guardrails: List[Guardrail] = Field(default_factory=list)
guards: List[ModelOutputGuard] = Field(default_factory=list)
output_parser_guard: Optional[ModelOutputParserGuard] = None
class Config:
"""Configuration for this pydantic object."""
@@ -145,34 +145,19 @@ class LLMChain(Chain, BaseModel):
TODO: actually, reasonable to do the parsing first so that guardrail.evaluate gets the reified data structure.
"""
for guardrail in self.guardrails:
evaluation, ok = guardrail.evaluate(prompt_value, completion)
if evaluation.revised_output:
assert isinstance(evaluation.revised_output, str)
completion = evaluation.revised_output
elif not ok and not evaluation.revised_output:
# TODO: consider associating customer exception w/ guardrail
# as suggested in https://github.com/hwchase17/langchain/pull/1683/files#r1139987185
raise RuntimeError(evaluation.error_msg)
for guard in self.guards:
completion = guard.evaluate(prompt_value, completion)
if self.output_parser:
try:
parsed_completion = self.output_parser.parse(completion)
except OutputParserException as e:
if self.output_parser_retry_enabled:
_text = f"Uh-oh! Got {e}. Retrying with DaVinci."
self.callback_manager.on_text(_text, end="\n", verbose=self.verbose)
retried_completion = dumb_davinci_retry(prompt_value.to_string(), completion)
parsed_completion = self.output_parser.parse(retried_completion)
else:
raise e
completion = parsed_completion
if self.output_parser_guard:
completion = self.output_parser_guard.evaluate(
prompt_value, completion, self.output_parser
)
else:
completion = self.output_parser.parse(completion)
return completion
def create_outputs(
self, response: LLMResult, prompts: List[PromptValue]
) -> List[Dict[str, str]]:

View File

@@ -1 +1 @@
from langchain.guardrails.base import Guardrail, GuardrailEvaluation
from langchain.guardrails.format_instructions import FormatInstructionsGuard

View File

@@ -1,23 +1 @@
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple
from pydantic import BaseModel
class GuardrailEvaluation(BaseModel):
"""Hm want to encapsulate the result of applying a guardrail
"""
error_msg: str # Indicate why initial output validation failed.
revised_output: Optional[Any] # Optionally, try to fix the output.
class Guardrail(ABC, BaseModel):
@abstractmethod
def evaluate(self, input: Any, output: Any) -> Tuple[Optional[GuardrailEvaluation], bool]:
"""A generic guardrail on any function (a function that gets human input, an LM call, a chain, an agent, etc.)
is evaluated against that function's input and output.
Evaluation includes a validation/verification step. It may also include a retry to generate a satisfactory revised output.
These steps are encapsulated jointly, as a single LM call may succeed in both.
"""

View File

@@ -0,0 +1,34 @@
from __future__ import annotations
from typing import Any
from langchain.chains.llm import LLMChain
from langchain.output_parsers.base import BaseOutputParser, OutputParserException
from langchain.prompts.base import PromptValue
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseLanguageModel, ModelOutputParserGuard
TEMPLATE = "Prompt:\n{prompt}\nCompletion:\n{completion}\n\nAbove, the Completion did not satisfy the constraints given in the Prompt. Please try again:"
PROMPT = PromptTemplate.from_template(TEMPLATE)
class FormatInstructionsGuard(ModelOutputParserGuard):
fixer_chain: LLMChain
@classmethod
def from_llm(cls, llm: BaseLanguageModel) -> FormatInstructionsGuard:
return cls(fixer_chain=LLMChain(llm=llm, prompt=PROMPT))
def evaluate(
self, prompt_value: PromptValue, output: str, output_parser: BaseOutputParser
) -> Any:
try:
result = output_parser.parse(output)
except OutputParserException as e:
new_result = self.fixer_chain.run(
prompt=prompt_value.to_string(), completion=output
)
result = output_parser.parse(new_result)
return result

View File

@@ -4,13 +4,13 @@ from langchain.prompts.prompt import PromptTemplate
# TODO: perhaps prompt str -> PromptValue
def dumb_davinci_retry(prompt: str, completion: str) -> str:
"""Big model go brrrr.
"""
davinci = OpenAI(model_name='text-davinci-003', temperature=0.5)
"""Big model go brrrr."""
davinci = OpenAI(model_name="text-davinci-003", temperature=0.5)
retry_prompt = PromptTemplate(
template="Prompt:\n{prompt}\nCompletion:\n{completion}\n\nAbove, the Completion did not satisfy the constraints given in the Prompt. Please try again:",
input_variables=["prompt", "completion"]
input_variables=["prompt", "completion"],
)
retry_prompt_str = retry_prompt.format_prompt(prompt=prompt, completion=completion).to_string()
retry_prompt_str = retry_prompt.format_prompt(
prompt=prompt, completion=completion
).to_string()
return davinci(retry_prompt_str)

View File

@@ -29,4 +29,4 @@ class BaseOutputParser(BaseModel, ABC):
class OutputParserException(Exception):
pass
pass

View File

@@ -6,6 +6,8 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.output_parsers import BaseOutputParser
def get_buffer_string(
messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
@@ -244,3 +246,17 @@ class BaseMemory(BaseModel, ABC):
Memory = BaseMemory
class ModelOutputGuard(ABC, BaseModel):
@abstractmethod
def evaluate(self, prompt_value: PromptValue, output: str) -> str:
"""Evaluate and fix model output. Should still return a string."""
class ModelOutputParserGuard(ABC, BaseModel):
@abstractmethod
def evaluate(
self, prompt_value: PromptValue, output: str, output_parser: BaseOutputParser
) -> Any:
"""Evaluate and fix model output. Should parse output."""