Few Shot Chat Prompt (#8038)

Proposal for a few shot chat message example selector

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
William FH
2023-07-27 18:46:10 -07:00
committed by GitHub
parent 6dd18eee26
commit ecd4aae818
6 changed files with 735 additions and 138 deletions

View File

@@ -15,7 +15,10 @@ from langchain.prompts.example_selector import (
NGramOverlapExampleSelector,
SemanticSimilarityExampleSelector,
)
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.few_shot import (
FewShotChatMessagePromptTemplate,
FewShotPromptTemplate,
)
from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates
from langchain.prompts.loading import load_prompt
from langchain.prompts.pipeline import PipelinePromptTemplate
@@ -42,4 +45,5 @@ __all__ = [
"StringPromptTemplate",
"SystemMessagePromptTemplate",
"load_prompt",
"FewShotChatMessagePromptTemplate",
]

View File

@@ -318,14 +318,18 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
input_variables: List[str]
"""List of input variables."""
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
messages: List[
Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]
]
"""List of messages consisting of either message prompt templates or messages."""
def __add__(self, other: Any) -> ChatPromptTemplate:
# Allow for easy combining
if isinstance(other, ChatPromptTemplate):
return ChatPromptTemplate(messages=self.messages + other.messages)
elif isinstance(other, (BaseMessagePromptTemplate, BaseMessage)):
elif isinstance(
other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
):
return ChatPromptTemplate(messages=self.messages + [other])
elif isinstance(other, str):
prompt = HumanMessagePromptTemplate.from_template(other)
@@ -349,7 +353,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
messages = values["messages"]
input_vars = set()
for message in messages:
if isinstance(message, BaseMessagePromptTemplate):
if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
input_vars.update(message.input_variables)
if "partial_variables" in values:
input_vars = input_vars - set(values["partial_variables"])
@@ -475,7 +479,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
for message_template in self.messages:
if isinstance(message_template, BaseMessage):
result.extend([message_template])
elif isinstance(message_template, BaseMessagePromptTemplate):
elif isinstance(
message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
):
rel_params = {
k: v
for k, v in kwargs.items()

View File

@@ -1,24 +1,24 @@
"""Prompt template that contains few shot examples."""
from typing import Any, Dict, List, Optional
from __future__ import annotations
from pydantic import Extra, root_validator
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.prompts.base import (
DEFAULT_FORMATTER_MAPPING,
StringPromptTemplate,
check_valid_template,
)
from langchain.prompts.chat import BaseChatPromptTemplate, BaseMessagePromptTemplate
from langchain.prompts.example_selector.base import BaseExampleSelector
from langchain.prompts.prompt import PromptTemplate
from langchain.schema.messages import BaseMessage, get_buffer_string
class FewShotPromptTemplate(StringPromptTemplate):
class _FewShotPromptTemplateMixin(BaseModel):
"""Prompt template that contains few shot examples."""
@property
def lc_serializable(self) -> bool:
return False
examples: Optional[List[dict]] = None
"""Examples to format into the prompt.
Either this or example_selector should be provided."""
@@ -27,26 +27,11 @@ class FewShotPromptTemplate(StringPromptTemplate):
"""ExampleSelector to choose the examples to format into the prompt.
Either this or examples should be provided."""
example_prompt: PromptTemplate
"""PromptTemplate used to format an individual example."""
class Config:
"""Configuration for this pydantic object."""
suffix: str
"""A prompt template string to put after the examples."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
example_separator: str = "\n\n"
"""String separator used to join the prefix, the examples, and suffix."""
prefix: str = ""
"""A prompt template string to put before the examples."""
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
validate_template: bool = True
"""Whether or not to try validating the template."""
extra = Extra.forbid
arbitrary_types_allowed = True
@root_validator(pre=True)
def check_examples_and_selector(cls, values: Dict) -> Dict:
@@ -65,6 +50,58 @@ class FewShotPromptTemplate(StringPromptTemplate):
return values
def _get_examples(self, **kwargs: Any) -> List[dict]:
"""Get the examples to use for formatting the prompt.
Args:
**kwargs: Keyword arguments to be passed to the example selector.
Returns:
List of examples.
"""
if self.examples is not None:
return self.examples
elif self.example_selector is not None:
return self.example_selector.select_examples(kwargs)
else:
raise ValueError(
"One of 'examples' and 'example_selector' should be provided"
)
class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
"""Prompt template that contains few shot examples."""
@property
def lc_serializable(self) -> bool:
"""Return whether the prompt template is lc_serializable.
Returns:
Boolean indicating whether the prompt template is lc_serializable.
"""
return False
validate_template: bool = True
"""Whether or not to try validating the template."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
example_prompt: PromptTemplate
"""PromptTemplate used to format an individual example."""
suffix: str
"""A prompt template string to put after the examples."""
example_separator: str = "\n\n"
"""String separator used to join the prefix, the examples, and suffix."""
prefix: str = ""
"""A prompt template string to put before the examples."""
template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix, and input variables are consistent."""
@@ -82,19 +119,11 @@ class FewShotPromptTemplate(StringPromptTemplate):
extra = Extra.forbid
arbitrary_types_allowed = True
def _get_examples(self, **kwargs: Any) -> List[dict]:
if self.examples is not None:
return self.examples
elif self.example_selector is not None:
return self.example_selector.select_examples(kwargs)
else:
raise ValueError
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
**kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
@@ -132,3 +161,184 @@ class FewShotPromptTemplate(StringPromptTemplate):
if self.example_selector:
raise ValueError("Saving an example selector is not currently supported")
return super().dict(**kwargs)
class FewShotChatMessagePromptTemplate(
BaseChatPromptTemplate, _FewShotPromptTemplateMixin
):
"""Chat prompt template that supports few-shot examples.
The high level structure of produced by this prompt template is a list of messages
consisting of prefix message(s), example message(s), and suffix message(s).
This structure enables creating a conversation with intermediate examples like:
System: You are a helpful AI Assistant
Human: What is 2+2?
AI: 4
Human: What is 2+3?
AI: 5
Human: What is 4+4?
This prompt template can be used to generate a fixed list of examples or else
to dynamically select examples based on the input.
Examples:
Prompt template with a fixed list of examples (matching the sample
conversation above):
.. code-block:: python
from langchain.schema import SystemMessage
from langchain.prompts import (
FewShotChatMessagePromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
AIMessagePromptTemplate
)
examples = [
{"input": "2+2", "output": "4"},
{"input": "2+3", "output": "5"},
]
few_shot_prompt = FewShotChatMessagePromptTemplate(
examples=examples,
# This is a prompt template used to format each individual example.
example_prompt=(
HumanMessagePromptTemplate.from_template("{input}")
+ AIMessagePromptTemplate.from_template("{output}")
),
)
final_prompt = (
SystemMessagePromptTemplate.from_template(
"You are a helpful AI Assistant"
)
+ few_shot_prompt
+ HumanMessagePromptTemplate.from_template("{input}")
)
final_prompt.format(input="What is 4+4?")
Prompt template with dynamically selected examples:
.. code-block:: python
from langchain.prompts import SemanticSimilarityExampleSelector
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
examples = [
{"input": "2+2", "output": "4"},
{"input": "2+3", "output": "5"},
{"input": "2+4", "output": "6"},
# ...
]
to_vectorize = [
" ".join(example.values())
for example in examples
]
embeddings = OpenAIEmbeddings()
vectorstore = Chroma.from_texts(
to_vectorize, embeddings, metadatas=examples
)
example_selector = SemanticSimilarityExampleSelector(
vectorstore=vectorstore
)
from langchain.schema import SystemMessage
from langchain.prompts import HumanMessagePromptTemplate
from langchain.prompts.few_shot import FewShotChatMessagePromptTemplate
few_shot_prompt = FewShotChatMessagePromptTemplate(
# Which variable(s) will be passed to the example selector.
input_variables=["input"],
example_selector=example_selector,
# Define how each example will be formatted.
# In this case, each example will become 2 messages:
# 1 human, and 1 AI
example_prompt=(
HumanMessagePromptTemplate.from_template("{input}")
+ AIMessagePromptTemplate.from_template("{output}")
),
)
# Define the overall prompt.
final_prompt = (
SystemMessagePromptTemplate.from_template(
"You are a helpful AI Assistant"
)
+ few_shot_prompt
+ HumanMessagePromptTemplate.from_template("{input}")
)
# Show the prompt
print(final_prompt.format_messages(input="What's 3+3?"))
# Use within an LLM
from langchain.chat_models import ChatAnthropic
chain = final_prompt | ChatAnthropic()
chain.invoke({"input": "What's 3+3?"})
"""
@property
def lc_serializable(self) -> bool:
"""Return whether the prompt template is lc_serializable.
Returns:
Boolean indicating whether the prompt template is lc_serializable.
"""
return False
input_variables: List[str] = Field(default_factory=list)
"""A list of the names of the variables the prompt template will use
to pass to the example_selector, if provided."""
example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate]
"""The class to format each example."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format kwargs into a list of messages.
Args:
**kwargs: keyword arguments to use for filling in templates in messages.
Returns:
A list of formatted messages with all template variables filled in.
"""
# Get the examples to use.
examples = self._get_examples(**kwargs)
examples = [
{k: e[k] for k in self.example_prompt.input_variables} for e in examples
]
# Format the examples.
messages = [
message
for example in examples
for message in self.example_prompt.format_messages(**example)
]
return messages
def format(self, **kwargs: Any) -> str:
"""Format the prompt with inputs generating a string.
Use this method to generate a string representation of a prompt consisting
of chat messages.
Useful for feeding into a string based completion language model or debugging.
Args:
**kwargs: keyword arguments to use for formatting.
Returns:
A string representation of the prompt
"""
messages = self.format_messages(**kwargs)
return get_buffer_string(messages)

View File

@@ -1,10 +1,21 @@
"""Test few shot prompt template."""
from typing import Dict, List, Tuple
from typing import Any, Dict, List, Sequence, Tuple
import pytest
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts import (
AIMessagePromptTemplate,
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.prompts.chat import SystemMessagePromptTemplate
from langchain.prompts.example_selector.base import BaseExampleSelector
from langchain.prompts.few_shot import (
FewShotChatMessagePromptTemplate,
FewShotPromptTemplate,
)
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import AIMessage, HumanMessage, SystemMessage
EXAMPLE_PROMPT = PromptTemplate(
input_variables=["question", "answer"], template="{question}: {answer}"
@@ -267,3 +278,93 @@ def test_prompt_jinja2_extra_input_variables(
example_prompt=example_jinja2_prompt[0],
template_format="jinja2",
)
def test_few_shot_chat_message_prompt_template() -> None:
"""Tests for few shot chat message template."""
examples = [
{"input": "2+2", "output": "4"},
{"input": "2+3", "output": "5"},
]
example_prompt = ChatPromptTemplate.from_messages(
[
HumanMessagePromptTemplate.from_template("{input}"),
AIMessagePromptTemplate.from_template("{output}"),
]
)
few_shot_prompt = FewShotChatMessagePromptTemplate(
input_variables=["input"],
example_prompt=example_prompt,
examples=examples,
)
final_prompt: ChatPromptTemplate = (
SystemMessagePromptTemplate.from_template("You are a helpful AI Assistant")
+ few_shot_prompt
+ HumanMessagePromptTemplate.from_template("{input}")
)
messages = final_prompt.format_messages(input="100 + 1")
assert messages == [
SystemMessage(content="You are a helpful AI Assistant", additional_kwargs={}),
HumanMessage(content="2+2", additional_kwargs={}, example=False),
AIMessage(content="4", additional_kwargs={}, example=False),
HumanMessage(content="2+3", additional_kwargs={}, example=False),
AIMessage(content="5", additional_kwargs={}, example=False),
HumanMessage(content="100 + 1", additional_kwargs={}, example=False),
]
class AsIsSelector(BaseExampleSelector):
"""An example selector for testing purposes.
This selector returns the examples as-is.
"""
def __init__(self, examples: Sequence[Dict[str, str]]) -> None:
"""Initializes the selector."""
self.examples = examples
def add_example(self, example: Dict[str, str]) -> Any:
"""Adds an example to the selector."""
raise NotImplementedError()
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
"""Select which examples to use based on the inputs."""
return list(self.examples)
def test_few_shot_chat_message_prompt_template_with_selector() -> None:
"""Tests for few shot chat message template with an example selector."""
examples = [
{"input": "2+2", "output": "4"},
{"input": "2+3", "output": "5"},
]
example_selector = AsIsSelector(examples)
example_prompt = ChatPromptTemplate.from_messages(
[
HumanMessagePromptTemplate.from_template("{input}"),
AIMessagePromptTemplate.from_template("{output}"),
]
)
few_shot_prompt = FewShotChatMessagePromptTemplate(
input_variables=["input"],
example_prompt=example_prompt,
example_selector=example_selector,
)
final_prompt: ChatPromptTemplate = (
SystemMessagePromptTemplate.from_template("You are a helpful AI Assistant")
+ few_shot_prompt
+ HumanMessagePromptTemplate.from_template("{input}")
)
messages = final_prompt.format_messages(input="100 + 1")
assert messages == [
SystemMessage(content="You are a helpful AI Assistant", additional_kwargs={}),
HumanMessage(content="2+2", additional_kwargs={}, example=False),
AIMessage(content="4", additional_kwargs={}, example=False),
HumanMessage(content="2+3", additional_kwargs={}, example=False),
AIMessage(content="5", additional_kwargs={}, example=False),
HumanMessage(content="100 + 1", additional_kwargs={}, example=False),
]