prompt template from string (#884)

This commit is contained in:
Harrison Chase 2023-02-04 17:04:58 -08:00 committed by GitHub
parent 7cc44b3bdb
commit a2b699dcd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 72 additions and 7 deletions

View File

@ -151,6 +151,47 @@
"multiple_input_prompt.format(adjective=\"funny\", content=\"chickens\")" "multiple_input_prompt.format(adjective=\"funny\", content=\"chickens\")"
] ]
}, },
{
"cell_type": "markdown",
"id": "72f32ff2",
"metadata": {},
"source": [
"## From Template\n",
"You can also easily load a prompt template by just specifying the template, and not worrying about the input variables."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2a81f2f8",
"metadata": {},
"outputs": [],
"source": [
"template = \"Tell me a {adjective} joke about {content}.\"\n",
"multiple_input_prompt = PromptTemplate.from_template(template)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "d365b144",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PromptTemplate(input_variables=['adjective', 'content'], output_parser=None, template='Tell me a {adjective} joke about {content}.', template_format='f-string', validate_template=True)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"multiple_input_prompt"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "b2dd6154", "id": "b2dd6154",

View File

@ -1,5 +1,4 @@
"""Chain that just formats a prompt and calls an LLM.""" """Chain that just formats a prompt and calls an LLM."""
from string import Formatter
from typing import Any, Dict, List, Sequence, Union from typing import Any, Dict, List, Sequence, Union
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
@ -132,10 +131,5 @@ class LLMChain(Chain, BaseModel):
@classmethod @classmethod
def from_string(cls, llm: BaseLLM, template: str) -> Chain: def from_string(cls, llm: BaseLLM, template: str) -> Chain:
"""Create LLMChain from LLM and template.""" """Create LLMChain from LLM and template."""
input_variables = { prompt_template = PromptTemplate.from_template(template)
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
prompt_template = PromptTemplate(
input_variables=list(input_variables), template=template
)
return cls(llm=llm, prompt=prompt_template) return cls(llm=llm, prompt=prompt_template)

View File

@ -1,6 +1,7 @@
"""Prompt schema definition.""" """Prompt schema definition."""
from __future__ import annotations from __future__ import annotations
from string import Formatter
from typing import Any, Dict, List from typing import Any, Dict, List
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
@ -117,6 +118,14 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
template = f.read() template = f.read()
return cls(input_variables=input_variables, template=template) return cls(input_variables=input_variables, template=template)
@classmethod
def from_template(cls, template: str) -> PromptTemplate:
"""Load a prompt template from a template."""
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
return cls(input_variables=list(input_variables), template=template)
# For backwards compatibility. # For backwards compatibility.
Prompt = PromptTemplate Prompt = PromptTemplate

View File

@ -13,6 +13,27 @@ def test_prompt_valid() -> None:
assert prompt.input_variables == input_variables assert prompt.input_variables == input_variables
def test_prompt_from_template() -> None:
"""Test prompts can be constructed from a template."""
# Single input variable.
template = "This is a {foo} test."
prompt = PromptTemplate.from_template(template)
expected_prompt = PromptTemplate(template=template, input_variables=["foo"])
assert prompt == expected_prompt
# Multiple input variables.
template = "This {bar} is a {foo} test."
prompt = PromptTemplate.from_template(template)
expected_prompt = PromptTemplate(template=template, input_variables=["bar", "foo"])
assert prompt == expected_prompt
# Multiple input variables with repeats.
template = "This {bar} is a {foo} test {foo}."
prompt = PromptTemplate.from_template(template)
expected_prompt = PromptTemplate(template=template, input_variables=["bar", "foo"])
assert prompt == expected_prompt
def test_prompt_missing_input_variables() -> None: def test_prompt_missing_input_variables() -> None:
"""Test error is raised when input variables are not provided.""" """Test error is raised when input variables are not provided."""
template = "This is a {foo} test." template = "This is a {foo} test."