feat: add custom prompt for QAEvalChain chain (#610)

I originally had only modified the `from_llm` to include the prompt but
I realized that if the prompt keys used on the custom prompt didn't
match the default prompt, it wouldn't work because of how `apply` works.

So I made some changes to the evaluate method to check if the prompt is
the default and if not, it will check if the input keys are the same as
the prompt key and update the inputs appropriately.

Let me know if there is a better way to do this.

Also added the custom prompt to the QA eval notebook.
This commit is contained in:
Nicolas
2023-01-14 12:23:48 -03:00
committed by GitHub
parent 1787c473b8
commit 91d7fd20ae
2 changed files with 73 additions and 5 deletions

View File

@@ -190,6 +190,51 @@
" print()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "782ae8c8",
"metadata": {},
"source": [
"## Customize Prompt\n",
"\n",
"You can also customize the prompt that is used. Here is an example prompting it using a score from 0 to 10.\n",
"The custom prompt requires 3 input variables: \"query\", \"answer\" and \"result\". Where \"query\" is the question, \"answer\" is the ground truth answer, and \"result\" is the predicted answer."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "153425c4",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts.prompt import PromptTemplate\n",
"\n",
"_PROMPT_TEMPLATE = \"\"\"You are an expert professor specialized in grading students' answers to questions.\n",
"You are grading the following question:\n",
"{query}\n",
"Here is the real answer:\n",
"{answer}\n",
"You are grading the following predicted answer:\n",
"{result}\n",
"What grade do you give from 0 to 10, where 0 is the lowest (very low similarity) and 10 is the highest (very high similarity)?\n",
"\"\"\"\n",
"\n",
"PROMPT = PromptTemplate(input_variables=[\"query\", \"answer\", \"result\"], template=_PROMPT_TEMPLATE)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a3b0fb7",
"metadata": {},
"outputs": [],
"source": [
"evalchain = QAEvalChain.from_llm(llm=llm,prompt=PROMPT)\n",
"evalchain.evaluate(examples, predictions, question_key=\"question\", answer_key=\"answer\", prediction_key=\"text\")"
]
},
{
"cell_type": "markdown",
"id": "aaa61f0c",
@@ -271,7 +316,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@@ -285,7 +330,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.9.7 (default, Sep 16 2021, 08:50:36) \n[Clang 10.0.0 ]"
},
"vscode": {
"interpreter": {
"hash": "53f3bc57609c7a84333bb558594977aa5b4026b1d6070b93987956689e367341"
}
}
},
"nbformat": 4,