mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-04 10:42:55 +00:00
fyi @eyurtsev was failing a unit test
This commit is contained in:
parent
e66759cc9d
commit
877d384bc9
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from string import Formatter
|
from string import Formatter
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
from pydantic import root_validator
|
from pydantic import root_validator
|
||||||
|
|
||||||
@ -16,24 +16,12 @@ from langchain.prompts.base import (
|
|||||||
|
|
||||||
|
|
||||||
class PromptTemplate(StringPromptTemplate):
|
class PromptTemplate(StringPromptTemplate):
|
||||||
"""A prompt template for a language model.
|
"""Schema to represent a prompt for an LLM.
|
||||||
|
|
||||||
A prompt template consists of a string template. It accepts a set of parameters
|
|
||||||
from the user that can be used to generate a prompt for a language model.
|
|
||||||
|
|
||||||
The template can be formatted using either f-strings (default) or jinja2 syntax.
|
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain import PromptTemplate
|
from langchain import PromptTemplate
|
||||||
|
|
||||||
# Instantiation using from_template (recommended)
|
|
||||||
prompt = PromptTemplate.from_template("Say {foo}")
|
|
||||||
prompt.format(foo="bar")
|
|
||||||
|
|
||||||
# Instantiation using initializer
|
|
||||||
prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
|
prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -56,7 +44,6 @@ class PromptTemplate(StringPromptTemplate):
|
|||||||
"""Whether or not to try validating the template."""
|
"""Whether or not to try validating the template."""
|
||||||
|
|
||||||
def __add__(self, other: Any) -> PromptTemplate:
|
def __add__(self, other: Any) -> PromptTemplate:
|
||||||
"""Override the + operator to allow for combining prompt templates."""
|
|
||||||
# Allow for easy combining
|
# Allow for easy combining
|
||||||
if isinstance(other, PromptTemplate):
|
if isinstance(other, PromptTemplate):
|
||||||
if self.template_format != "f-string":
|
if self.template_format != "f-string":
|
||||||
@ -166,7 +153,6 @@ class PromptTemplate(StringPromptTemplate):
|
|||||||
template_file: The path to the file containing the prompt template.
|
template_file: The path to the file containing the prompt template.
|
||||||
input_variables: A list of variable names the final prompt template
|
input_variables: A list of variable names the final prompt template
|
||||||
will expect.
|
will expect.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The prompt loaded from the file.
|
The prompt loaded from the file.
|
||||||
"""
|
"""
|
||||||
@ -175,52 +161,25 @@ class PromptTemplate(StringPromptTemplate):
|
|||||||
return cls(input_variables=input_variables, template=template, **kwargs)
|
return cls(input_variables=input_variables, template=template, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_template(
|
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
|
||||||
cls,
|
"""Load a prompt template from a template."""
|
||||||
template: str,
|
if "template_format" in kwargs and kwargs["template_format"] == "jinja2":
|
||||||
*,
|
|
||||||
template_format: str = "f-string",
|
|
||||||
partial_variables: Optional[Dict[str, Any]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> PromptTemplate:
|
|
||||||
"""Load a prompt template from a template.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
template: The template to load.
|
|
||||||
template_format: The format of the template. Use `jinja2` for jinja2,
|
|
||||||
and `f-string` or None for f-strings.
|
|
||||||
partial_variables: A dictionary of variables that can be used to partially
|
|
||||||
fill in the template. For example, if the template is
|
|
||||||
`"{variable1} {variable2}"`, and `partial_variables` is
|
|
||||||
`{"variable1": "foo"}`, then the final prompt will be
|
|
||||||
`"foo {variable2}"`.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The prompt template loaded from the template.
|
|
||||||
"""
|
|
||||||
if template_format == "jinja2":
|
|
||||||
# Get the variables for the template
|
# Get the variables for the template
|
||||||
input_variables = _get_jinja2_variables_from_template(template)
|
input_variables = _get_jinja2_variables_from_template(template)
|
||||||
elif template_format == "f-string":
|
|
||||||
|
else:
|
||||||
input_variables = {
|
input_variables = {
|
||||||
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
v for _, v, _, _ in Formatter().parse(template) if v is not None
|
||||||
}
|
}
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported template format: {template_format}")
|
|
||||||
|
|
||||||
_partial_variables = partial_variables or {}
|
if "partial_variables" in kwargs:
|
||||||
|
partial_variables = kwargs["partial_variables"]
|
||||||
if _partial_variables:
|
|
||||||
input_variables = {
|
input_variables = {
|
||||||
var for var in input_variables if var not in _partial_variables
|
var for var in input_variables if var not in partial_variables
|
||||||
}
|
}
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
input_variables=sorted(input_variables),
|
input_variables=list(sorted(input_variables)), template=template, **kwargs
|
||||||
template=template,
|
|
||||||
template_format=template_format,
|
|
||||||
partial_variables=_partial_variables,
|
|
||||||
**kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -98,8 +98,7 @@
|
|||||||
"name"
|
"name"
|
||||||
],
|
],
|
||||||
"template": "hello {name}!",
|
"template": "hello {name}!",
|
||||||
"template_format": "f-string",
|
"template_format": "f-string"
|
||||||
"partial_variables": {}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -177,8 +176,7 @@
|
|||||||
"name"
|
"name"
|
||||||
],
|
],
|
||||||
"template": "hello {name}!",
|
"template": "hello {name}!",
|
||||||
"template_format": "f-string",
|
"template_format": "f-string"
|
||||||
"partial_variables": {}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -247,8 +245,7 @@
|
|||||||
"name"
|
"name"
|
||||||
],
|
],
|
||||||
"template": "hello {name}!",
|
"template": "hello {name}!",
|
||||||
"template_format": "f-string",
|
"template_format": "f-string"
|
||||||
"partial_variables": {}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -280,25 +277,3 @@
|
|||||||
}
|
}
|
||||||
'''
|
'''
|
||||||
# ---
|
# ---
|
||||||
# name: test_serialize_prompt
|
|
||||||
'''
|
|
||||||
{
|
|
||||||
"lc": 1,
|
|
||||||
"type": "constructor",
|
|
||||||
"id": [
|
|
||||||
"langchain",
|
|
||||||
"prompts",
|
|
||||||
"prompt",
|
|
||||||
"PromptTemplate"
|
|
||||||
],
|
|
||||||
"kwargs": {
|
|
||||||
"input_variables": [
|
|
||||||
"name"
|
|
||||||
],
|
|
||||||
"template": "hello {name}!",
|
|
||||||
"template_format": "f-string",
|
|
||||||
"partial_variables": {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
'''
|
|
||||||
# ---
|
|
||||||
|
@ -129,12 +129,6 @@ def test_serialize_llmchain_chat(snapshot: Any) -> None:
|
|||||||
del os.environ["OPENAI_API_KEY"]
|
del os.environ["OPENAI_API_KEY"]
|
||||||
|
|
||||||
|
|
||||||
def test_serialize_prompt(snapshot: Any) -> None:
|
|
||||||
"""Test that prompt is serialized correctly"""
|
|
||||||
prompt = PromptTemplate.from_template("hello {name}!")
|
|
||||||
assert dumps(prompt, pretty=True) == snapshot
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("openai")
|
@pytest.mark.requires("openai")
|
||||||
def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None:
|
def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None:
|
||||||
llm = OpenAI(
|
llm = OpenAI(
|
||||||
|
@ -161,10 +161,6 @@ Will it get confused{ }?
|
|||||||
)
|
)
|
||||||
assert prompt == expected_prompt
|
assert prompt == expected_prompt
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("jinja2")
|
|
||||||
def test_prompt_from_jinja2_template_multiple_inputs() -> None:
|
|
||||||
"""Test with multiple input variables."""
|
|
||||||
# Multiple input variables.
|
# Multiple input variables.
|
||||||
template = """\
|
template = """\
|
||||||
Hello world
|
Hello world
|
||||||
@ -190,10 +186,7 @@ You just set bar boolean variable to true
|
|||||||
|
|
||||||
assert prompt == expected_prompt
|
assert prompt == expected_prompt
|
||||||
|
|
||||||
|
# Multiple input variables with repeats.
|
||||||
@pytest.mark.requires("jinja2")
|
|
||||||
def test_prompt_from_jinja2_template_multiple_inputs_with_repeats() -> None:
|
|
||||||
"""Test with multiple input variables and repeats."""
|
|
||||||
template = """\
|
template = """\
|
||||||
Hello world
|
Hello world
|
||||||
|
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user