mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 02:50:47 +00:00
Few Shot Chat Prompt (#8038)
Proposal for a few shot chat message example selector --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
@@ -15,7 +15,10 @@ from langchain.prompts.example_selector import (
|
||||
NGramOverlapExampleSelector,
|
||||
SemanticSimilarityExampleSelector,
|
||||
)
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts.few_shot import (
|
||||
FewShotChatMessagePromptTemplate,
|
||||
FewShotPromptTemplate,
|
||||
)
|
||||
from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates
|
||||
from langchain.prompts.loading import load_prompt
|
||||
from langchain.prompts.pipeline import PipelinePromptTemplate
|
||||
@@ -42,4 +45,5 @@ __all__ = [
|
||||
"StringPromptTemplate",
|
||||
"SystemMessagePromptTemplate",
|
||||
"load_prompt",
|
||||
"FewShotChatMessagePromptTemplate",
|
||||
]
|
||||
|
@@ -318,14 +318,18 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
|
||||
input_variables: List[str]
|
||||
"""List of input variables."""
|
||||
messages: List[Union[BaseMessagePromptTemplate, BaseMessage]]
|
||||
messages: List[
|
||||
Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate]
|
||||
]
|
||||
"""List of messages consisting of either message prompt templates or messages."""
|
||||
|
||||
def __add__(self, other: Any) -> ChatPromptTemplate:
|
||||
# Allow for easy combining
|
||||
if isinstance(other, ChatPromptTemplate):
|
||||
return ChatPromptTemplate(messages=self.messages + other.messages)
|
||||
elif isinstance(other, (BaseMessagePromptTemplate, BaseMessage)):
|
||||
elif isinstance(
|
||||
other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
|
||||
):
|
||||
return ChatPromptTemplate(messages=self.messages + [other])
|
||||
elif isinstance(other, str):
|
||||
prompt = HumanMessagePromptTemplate.from_template(other)
|
||||
@@ -349,7 +353,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
messages = values["messages"]
|
||||
input_vars = set()
|
||||
for message in messages:
|
||||
if isinstance(message, BaseMessagePromptTemplate):
|
||||
if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)):
|
||||
input_vars.update(message.input_variables)
|
||||
if "partial_variables" in values:
|
||||
input_vars = input_vars - set(values["partial_variables"])
|
||||
@@ -475,7 +479,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
|
||||
for message_template in self.messages:
|
||||
if isinstance(message_template, BaseMessage):
|
||||
result.extend([message_template])
|
||||
elif isinstance(message_template, BaseMessagePromptTemplate):
|
||||
elif isinstance(
|
||||
message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
|
||||
):
|
||||
rel_params = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
|
@@ -1,24 +1,24 @@
|
||||
"""Prompt template that contains few shot examples."""
|
||||
from typing import Any, Dict, List, Optional
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import Extra, root_validator
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.prompts.base import (
|
||||
DEFAULT_FORMATTER_MAPPING,
|
||||
StringPromptTemplate,
|
||||
check_valid_template,
|
||||
)
|
||||
from langchain.prompts.chat import BaseChatPromptTemplate, BaseMessagePromptTemplate
|
||||
from langchain.prompts.example_selector.base import BaseExampleSelector
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||
|
||||
|
||||
class FewShotPromptTemplate(StringPromptTemplate):
|
||||
class _FewShotPromptTemplateMixin(BaseModel):
|
||||
"""Prompt template that contains few shot examples."""
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return False
|
||||
|
||||
examples: Optional[List[dict]] = None
|
||||
"""Examples to format into the prompt.
|
||||
Either this or example_selector should be provided."""
|
||||
@@ -27,26 +27,11 @@ class FewShotPromptTemplate(StringPromptTemplate):
|
||||
"""ExampleSelector to choose the examples to format into the prompt.
|
||||
Either this or examples should be provided."""
|
||||
|
||||
example_prompt: PromptTemplate
|
||||
"""PromptTemplate used to format an individual example."""
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
suffix: str
|
||||
"""A prompt template string to put after the examples."""
|
||||
|
||||
input_variables: List[str]
|
||||
"""A list of the names of the variables the prompt template expects."""
|
||||
|
||||
example_separator: str = "\n\n"
|
||||
"""String separator used to join the prefix, the examples, and suffix."""
|
||||
|
||||
prefix: str = ""
|
||||
"""A prompt template string to put before the examples."""
|
||||
|
||||
template_format: str = "f-string"
|
||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||
|
||||
validate_template: bool = True
|
||||
"""Whether or not to try validating the template."""
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@root_validator(pre=True)
|
||||
def check_examples_and_selector(cls, values: Dict) -> Dict:
|
||||
@@ -65,6 +50,58 @@ class FewShotPromptTemplate(StringPromptTemplate):
|
||||
|
||||
return values
|
||||
|
||||
def _get_examples(self, **kwargs: Any) -> List[dict]:
|
||||
"""Get the examples to use for formatting the prompt.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments to be passed to the example selector.
|
||||
|
||||
Returns:
|
||||
List of examples.
|
||||
"""
|
||||
if self.examples is not None:
|
||||
return self.examples
|
||||
elif self.example_selector is not None:
|
||||
return self.example_selector.select_examples(kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
"One of 'examples' and 'example_selector' should be provided"
|
||||
)
|
||||
|
||||
|
||||
class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate):
|
||||
"""Prompt template that contains few shot examples."""
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
"""Return whether the prompt template is lc_serializable.
|
||||
|
||||
Returns:
|
||||
Boolean indicating whether the prompt template is lc_serializable.
|
||||
"""
|
||||
return False
|
||||
|
||||
validate_template: bool = True
|
||||
"""Whether or not to try validating the template."""
|
||||
|
||||
input_variables: List[str]
|
||||
"""A list of the names of the variables the prompt template expects."""
|
||||
|
||||
example_prompt: PromptTemplate
|
||||
"""PromptTemplate used to format an individual example."""
|
||||
|
||||
suffix: str
|
||||
"""A prompt template string to put after the examples."""
|
||||
|
||||
example_separator: str = "\n\n"
|
||||
"""String separator used to join the prefix, the examples, and suffix."""
|
||||
|
||||
prefix: str = ""
|
||||
"""A prompt template string to put before the examples."""
|
||||
|
||||
template_format: str = "f-string"
|
||||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
|
||||
|
||||
@root_validator()
|
||||
def template_is_valid(cls, values: Dict) -> Dict:
|
||||
"""Check that prefix, suffix, and input variables are consistent."""
|
||||
@@ -82,19 +119,11 @@ class FewShotPromptTemplate(StringPromptTemplate):
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _get_examples(self, **kwargs: Any) -> List[dict]:
|
||||
if self.examples is not None:
|
||||
return self.examples
|
||||
elif self.example_selector is not None:
|
||||
return self.example_selector.select_examples(kwargs)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the prompt with the inputs.
|
||||
|
||||
Args:
|
||||
kwargs: Any arguments to be passed to the prompt template.
|
||||
**kwargs: Any arguments to be passed to the prompt template.
|
||||
|
||||
Returns:
|
||||
A formatted string.
|
||||
@@ -132,3 +161,184 @@ class FewShotPromptTemplate(StringPromptTemplate):
|
||||
if self.example_selector:
|
||||
raise ValueError("Saving an example selector is not currently supported")
|
||||
return super().dict(**kwargs)
|
||||
|
||||
|
||||
class FewShotChatMessagePromptTemplate(
|
||||
BaseChatPromptTemplate, _FewShotPromptTemplateMixin
|
||||
):
|
||||
"""Chat prompt template that supports few-shot examples.
|
||||
|
||||
The high level structure of produced by this prompt template is a list of messages
|
||||
consisting of prefix message(s), example message(s), and suffix message(s).
|
||||
|
||||
This structure enables creating a conversation with intermediate examples like:
|
||||
|
||||
System: You are a helpful AI Assistant
|
||||
Human: What is 2+2?
|
||||
AI: 4
|
||||
Human: What is 2+3?
|
||||
AI: 5
|
||||
Human: What is 4+4?
|
||||
|
||||
This prompt template can be used to generate a fixed list of examples or else
|
||||
to dynamically select examples based on the input.
|
||||
|
||||
Examples:
|
||||
|
||||
Prompt template with a fixed list of examples (matching the sample
|
||||
conversation above):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.schema import SystemMessage
|
||||
from langchain.prompts import (
|
||||
FewShotChatMessagePromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
SystemMessagePromptTemplate,
|
||||
AIMessagePromptTemplate
|
||||
)
|
||||
|
||||
examples = [
|
||||
{"input": "2+2", "output": "4"},
|
||||
{"input": "2+3", "output": "5"},
|
||||
]
|
||||
|
||||
few_shot_prompt = FewShotChatMessagePromptTemplate(
|
||||
examples=examples,
|
||||
# This is a prompt template used to format each individual example.
|
||||
example_prompt=(
|
||||
HumanMessagePromptTemplate.from_template("{input}")
|
||||
+ AIMessagePromptTemplate.from_template("{output}")
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
final_prompt = (
|
||||
SystemMessagePromptTemplate.from_template(
|
||||
"You are a helpful AI Assistant"
|
||||
)
|
||||
+ few_shot_prompt
|
||||
+ HumanMessagePromptTemplate.from_template("{input}")
|
||||
)
|
||||
|
||||
final_prompt.format(input="What is 4+4?")
|
||||
|
||||
Prompt template with dynamically selected examples:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.prompts import SemanticSimilarityExampleSelector
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.vectorstores import Chroma
|
||||
|
||||
examples = [
|
||||
{"input": "2+2", "output": "4"},
|
||||
{"input": "2+3", "output": "5"},
|
||||
{"input": "2+4", "output": "6"},
|
||||
# ...
|
||||
]
|
||||
|
||||
to_vectorize = [
|
||||
" ".join(example.values())
|
||||
for example in examples
|
||||
]
|
||||
embeddings = OpenAIEmbeddings()
|
||||
vectorstore = Chroma.from_texts(
|
||||
to_vectorize, embeddings, metadatas=examples
|
||||
)
|
||||
example_selector = SemanticSimilarityExampleSelector(
|
||||
vectorstore=vectorstore
|
||||
)
|
||||
|
||||
from langchain.schema import SystemMessage
|
||||
from langchain.prompts import HumanMessagePromptTemplate
|
||||
from langchain.prompts.few_shot import FewShotChatMessagePromptTemplate
|
||||
|
||||
few_shot_prompt = FewShotChatMessagePromptTemplate(
|
||||
# Which variable(s) will be passed to the example selector.
|
||||
input_variables=["input"],
|
||||
example_selector=example_selector,
|
||||
# Define how each example will be formatted.
|
||||
# In this case, each example will become 2 messages:
|
||||
# 1 human, and 1 AI
|
||||
example_prompt=(
|
||||
HumanMessagePromptTemplate.from_template("{input}")
|
||||
+ AIMessagePromptTemplate.from_template("{output}")
|
||||
),
|
||||
)
|
||||
# Define the overall prompt.
|
||||
final_prompt = (
|
||||
SystemMessagePromptTemplate.from_template(
|
||||
"You are a helpful AI Assistant"
|
||||
)
|
||||
+ few_shot_prompt
|
||||
+ HumanMessagePromptTemplate.from_template("{input}")
|
||||
)
|
||||
# Show the prompt
|
||||
print(final_prompt.format_messages(input="What's 3+3?"))
|
||||
|
||||
# Use within an LLM
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
chain = final_prompt | ChatAnthropic()
|
||||
chain.invoke({"input": "What's 3+3?"})
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
"""Return whether the prompt template is lc_serializable.
|
||||
|
||||
Returns:
|
||||
Boolean indicating whether the prompt template is lc_serializable.
|
||||
"""
|
||||
return False
|
||||
|
||||
input_variables: List[str] = Field(default_factory=list)
|
||||
"""A list of the names of the variables the prompt template will use
|
||||
to pass to the example_selector, if provided."""
|
||||
example_prompt: Union[BaseMessagePromptTemplate, BaseChatPromptTemplate]
|
||||
"""The class to format each example."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
|
||||
"""Format kwargs into a list of messages.
|
||||
|
||||
Args:
|
||||
**kwargs: keyword arguments to use for filling in templates in messages.
|
||||
|
||||
Returns:
|
||||
A list of formatted messages with all template variables filled in.
|
||||
"""
|
||||
# Get the examples to use.
|
||||
examples = self._get_examples(**kwargs)
|
||||
examples = [
|
||||
{k: e[k] for k in self.example_prompt.input_variables} for e in examples
|
||||
]
|
||||
# Format the examples.
|
||||
messages = [
|
||||
message
|
||||
for example in examples
|
||||
for message in self.example_prompt.format_messages(**example)
|
||||
]
|
||||
return messages
|
||||
|
||||
def format(self, **kwargs: Any) -> str:
|
||||
"""Format the prompt with inputs generating a string.
|
||||
|
||||
Use this method to generate a string representation of a prompt consisting
|
||||
of chat messages.
|
||||
|
||||
Useful for feeding into a string based completion language model or debugging.
|
||||
|
||||
Args:
|
||||
**kwargs: keyword arguments to use for formatting.
|
||||
|
||||
Returns:
|
||||
A string representation of the prompt
|
||||
"""
|
||||
messages = self.format_messages(**kwargs)
|
||||
return get_buffer_string(messages)
|
||||
|
@@ -1,10 +1,21 @@
|
||||
"""Test few shot prompt template."""
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Any, Dict, List, Sequence, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||
from langchain.prompts import (
|
||||
AIMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
)
|
||||
from langchain.prompts.chat import SystemMessagePromptTemplate
|
||||
from langchain.prompts.example_selector.base import BaseExampleSelector
|
||||
from langchain.prompts.few_shot import (
|
||||
FewShotChatMessagePromptTemplate,
|
||||
FewShotPromptTemplate,
|
||||
)
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
EXAMPLE_PROMPT = PromptTemplate(
|
||||
input_variables=["question", "answer"], template="{question}: {answer}"
|
||||
@@ -267,3 +278,93 @@ def test_prompt_jinja2_extra_input_variables(
|
||||
example_prompt=example_jinja2_prompt[0],
|
||||
template_format="jinja2",
|
||||
)
|
||||
|
||||
|
||||
def test_few_shot_chat_message_prompt_template() -> None:
|
||||
"""Tests for few shot chat message template."""
|
||||
examples = [
|
||||
{"input": "2+2", "output": "4"},
|
||||
{"input": "2+3", "output": "5"},
|
||||
]
|
||||
|
||||
example_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
HumanMessagePromptTemplate.from_template("{input}"),
|
||||
AIMessagePromptTemplate.from_template("{output}"),
|
||||
]
|
||||
)
|
||||
|
||||
few_shot_prompt = FewShotChatMessagePromptTemplate(
|
||||
input_variables=["input"],
|
||||
example_prompt=example_prompt,
|
||||
examples=examples,
|
||||
)
|
||||
final_prompt: ChatPromptTemplate = (
|
||||
SystemMessagePromptTemplate.from_template("You are a helpful AI Assistant")
|
||||
+ few_shot_prompt
|
||||
+ HumanMessagePromptTemplate.from_template("{input}")
|
||||
)
|
||||
|
||||
messages = final_prompt.format_messages(input="100 + 1")
|
||||
assert messages == [
|
||||
SystemMessage(content="You are a helpful AI Assistant", additional_kwargs={}),
|
||||
HumanMessage(content="2+2", additional_kwargs={}, example=False),
|
||||
AIMessage(content="4", additional_kwargs={}, example=False),
|
||||
HumanMessage(content="2+3", additional_kwargs={}, example=False),
|
||||
AIMessage(content="5", additional_kwargs={}, example=False),
|
||||
HumanMessage(content="100 + 1", additional_kwargs={}, example=False),
|
||||
]
|
||||
|
||||
|
||||
class AsIsSelector(BaseExampleSelector):
|
||||
"""An example selector for testing purposes.
|
||||
|
||||
This selector returns the examples as-is.
|
||||
"""
|
||||
|
||||
def __init__(self, examples: Sequence[Dict[str, str]]) -> None:
|
||||
"""Initializes the selector."""
|
||||
self.examples = examples
|
||||
|
||||
def add_example(self, example: Dict[str, str]) -> Any:
|
||||
"""Adds an example to the selector."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
|
||||
"""Select which examples to use based on the inputs."""
|
||||
return list(self.examples)
|
||||
|
||||
|
||||
def test_few_shot_chat_message_prompt_template_with_selector() -> None:
|
||||
"""Tests for few shot chat message template with an example selector."""
|
||||
examples = [
|
||||
{"input": "2+2", "output": "4"},
|
||||
{"input": "2+3", "output": "5"},
|
||||
]
|
||||
example_selector = AsIsSelector(examples)
|
||||
example_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
HumanMessagePromptTemplate.from_template("{input}"),
|
||||
AIMessagePromptTemplate.from_template("{output}"),
|
||||
]
|
||||
)
|
||||
|
||||
few_shot_prompt = FewShotChatMessagePromptTemplate(
|
||||
input_variables=["input"],
|
||||
example_prompt=example_prompt,
|
||||
example_selector=example_selector,
|
||||
)
|
||||
final_prompt: ChatPromptTemplate = (
|
||||
SystemMessagePromptTemplate.from_template("You are a helpful AI Assistant")
|
||||
+ few_shot_prompt
|
||||
+ HumanMessagePromptTemplate.from_template("{input}")
|
||||
)
|
||||
messages = final_prompt.format_messages(input="100 + 1")
|
||||
assert messages == [
|
||||
SystemMessage(content="You are a helpful AI Assistant", additional_kwargs={}),
|
||||
HumanMessage(content="2+2", additional_kwargs={}, example=False),
|
||||
AIMessage(content="4", additional_kwargs={}, example=False),
|
||||
HumanMessage(content="2+3", additional_kwargs={}, example=False),
|
||||
AIMessage(content="5", additional_kwargs={}, example=False),
|
||||
HumanMessage(content="100 + 1", additional_kwargs={}, example=False),
|
||||
]
|
||||
|
Reference in New Issue
Block a user