mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 08:03:39 +00:00
prompt template from string (#884)
This commit is contained in:
parent
7cc44b3bdb
commit
a2b699dcd2
@ -151,6 +151,47 @@
|
|||||||
"multiple_input_prompt.format(adjective=\"funny\", content=\"chickens\")"
|
"multiple_input_prompt.format(adjective=\"funny\", content=\"chickens\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "72f32ff2",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## From Template\n",
|
||||||
|
"You can also easily load a prompt template by just specifying the template, and not worrying about the input variables."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "2a81f2f8",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"template = \"Tell me a {adjective} joke about {content}.\"\n",
|
||||||
|
"multiple_input_prompt = PromptTemplate.from_template(template)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "d365b144",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"PromptTemplate(input_variables=['adjective', 'content'], output_parser=None, template='Tell me a {adjective} joke about {content}.', template_format='f-string', validate_template=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"multiple_input_prompt"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "b2dd6154",
|
"id": "b2dd6154",
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
"""Chain that just formats a prompt and calls an LLM."""
|
"""Chain that just formats a prompt and calls an LLM."""
|
||||||
from string import Formatter
|
|
||||||
from typing import Any, Dict, List, Sequence, Union
|
from typing import Any, Dict, List, Sequence, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra
|
from pydantic import BaseModel, Extra
|
||||||
@ -132,10 +131,5 @@ class LLMChain(Chain, BaseModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_string(cls, llm: BaseLLM, template: str) -> Chain:
|
def from_string(cls, llm: BaseLLM, template: str) -> Chain:
|
||||||
"""Create LLMChain from LLM and template."""
|
"""Create LLMChain from LLM and template."""
|
||||||
input_variables = {
|
prompt_template = PromptTemplate.from_template(template)
|
||||||
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
|
||||||
}
|
|
||||||
prompt_template = PromptTemplate(
|
|
||||||
input_variables=list(input_variables), template=template
|
|
||||||
)
|
|
||||||
return cls(llm=llm, prompt=prompt_template)
|
return cls(llm=llm, prompt=prompt_template)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Prompt schema definition."""
|
"""Prompt schema definition."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from string import Formatter
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
@ -117,6 +118,14 @@ class PromptTemplate(BasePromptTemplate, BaseModel):
|
|||||||
template = f.read()
|
template = f.read()
|
||||||
return cls(input_variables=input_variables, template=template)
|
return cls(input_variables=input_variables, template=template)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_template(cls, template: str) -> PromptTemplate:
|
||||||
|
"""Load a prompt template from a template."""
|
||||||
|
input_variables = {
|
||||||
|
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
||||||
|
}
|
||||||
|
return cls(input_variables=list(input_variables), template=template)
|
||||||
|
|
||||||
|
|
||||||
# For backwards compatibility.
|
# For backwards compatibility.
|
||||||
Prompt = PromptTemplate
|
Prompt = PromptTemplate
|
||||||
|
@ -13,6 +13,27 @@ def test_prompt_valid() -> None:
|
|||||||
assert prompt.input_variables == input_variables
|
assert prompt.input_variables == input_variables
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_from_template() -> None:
|
||||||
|
"""Test prompts can be constructed from a template."""
|
||||||
|
# Single input variable.
|
||||||
|
template = "This is a {foo} test."
|
||||||
|
prompt = PromptTemplate.from_template(template)
|
||||||
|
expected_prompt = PromptTemplate(template=template, input_variables=["foo"])
|
||||||
|
assert prompt == expected_prompt
|
||||||
|
|
||||||
|
# Multiple input variables.
|
||||||
|
template = "This {bar} is a {foo} test."
|
||||||
|
prompt = PromptTemplate.from_template(template)
|
||||||
|
expected_prompt = PromptTemplate(template=template, input_variables=["bar", "foo"])
|
||||||
|
assert prompt == expected_prompt
|
||||||
|
|
||||||
|
# Multiple input variables with repeats.
|
||||||
|
template = "This {bar} is a {foo} test {foo}."
|
||||||
|
prompt = PromptTemplate.from_template(template)
|
||||||
|
expected_prompt = PromptTemplate(template=template, input_variables=["bar", "foo"])
|
||||||
|
assert prompt == expected_prompt
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_missing_input_variables() -> None:
|
def test_prompt_missing_input_variables() -> None:
|
||||||
"""Test error is raised when input variables are not provided."""
|
"""Test error is raised when input variables are not provided."""
|
||||||
template = "This is a {foo} test."
|
template = "This is a {foo} test."
|
||||||
|
Loading…
Reference in New Issue
Block a user