core[patch]: fix chat prompt partial messages placeholder var (#16918)

Bagatur 2024-02-02 10:23:37 -08:00 committed by GitHub
parent 3b0fa9079d
commit c29e9b6412
5 changed files with 38 additions and 20 deletions
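
In short: a MessagesPlaceholder variable in a ChatPromptTemplate can now be pre-filled via partial(). A minimal usage sketch mirroring the new unit test added in this diff (import paths assumed to be the usual langchain_core ones):

    from langchain_core.messages import SystemMessage
    from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

    prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")])

    # Pre-fill the placeholder variable with a list of messages.
    partial_prompt = prompt.partial(history=[("system", "foo")])
    assert partial_prompt.format_messages() == [SystemMessage(content="foo")]

    # A value passed at format time still overrides the partial.
    assert partial_prompt.format_messages(history=[("system", "bar")]) == [
        SystemMessage(content="bar")
    ]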

View File

@@ -47,9 +47,7 @@ class BasePromptTemplate(
     If not provided, all variables are assumed to be strings."""
     output_parser: Optional[BaseOutputParser] = None
     """How to parse the output of calling an LLM on this formatted prompt."""
-    partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
-        default_factory=dict
-    )
+    partial_variables: Mapping[str, Any] = Field(default_factory=dict)

     @classmethod
     def get_lc_namespace(cls) -> List[str]:
@@ -143,8 +141,7 @@ class BasePromptTemplate(
     def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
         # Get partial params:
         partial_kwargs = {
-            k: v if isinstance(v, str) else v()
-            for k, v in self.partial_variables.items()
+            k: v if not callable(v) else v() for k, v in self.partial_variables.items()
         }
         return {**partial_kwargs, **kwargs}
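
For reference, the merge logic above now treats any callable partial as a zero-argument factory and passes every other value (string, message list, ...) through unchanged. A standalone sketch of that behavior with made-up example values, not the library API itself:

    from typing import Any, Dict, Mapping

    def merge_partial_and_user_variables(
        partial_variables: Mapping[str, Any], **kwargs: Any
    ) -> Dict[str, Any]:
        # Callables are invoked; everything else is passed through as-is.
        partial_kwargs = {
            k: v() if callable(v) else v for k, v in partial_variables.items()
        }
        return {**partial_kwargs, **kwargs}

    # User-supplied kwargs still win over partials of the same name.
    merged = merge_partial_and_user_variables(
        {"today": lambda: "2024-02-02", "history": [("system", "foo")]},
        history=[("system", "bar")],
    )
    assert merged == {"today": "2024-02-02", "history": [("system", "bar")]}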

View File

@@ -5,7 +5,6 @@ from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import (
     Any,
-    Callable,
     Dict,
     List,
     Optional,
@@ -130,13 +129,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
                 f"variable {self.variable_name} should be a list of base messages, "
                 f"got {value}"
             )
-        for v in convert_to_messages(value):
-            if not isinstance(v, BaseMessage):
-                raise ValueError(
-                    f"variable {self.variable_name} should be a list of base messages,"
-                    f" got {value}"
-                )
-        return value
+        return convert_to_messages(value)

     @property
     def input_variables(self) -> List[str]:
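
With the change above, MessagesPlaceholder.format_messages returns the converted messages rather than the raw input value. A small sketch, mirroring the updated test further down:

    from langchain_core.messages import HumanMessage, SystemMessage
    from langchain_core.prompts import MessagesPlaceholder

    placeholder = MessagesPlaceholder("history")
    # Tuples and bare strings are coerced into message objects by convert_to_messages.
    assert placeholder.format_messages(
        history=[("system", "You are an AI assistant."), "Hello!"]
    ) == [
        SystemMessage(content="You are an AI assistant."),
        HumanMessage(content="Hello!"),
    ]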
@@ -755,13 +748,20 @@ class ChatPromptTemplate(BaseChatPromptTemplate):

         # Automatically infer input variables from messages
         input_vars: Set[str] = set()
+        partial_vars: Dict[str, Any] = {}
         for _message in _messages:
-            if isinstance(
+            if isinstance(_message, MessagesPlaceholder) and _message.optional:
+                partial_vars[_message.variable_name] = []
+            elif isinstance(
                 _message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
             ):
                 input_vars.update(_message.input_variables)

-        return cls(input_variables=sorted(input_vars), messages=_messages)
+        return cls(
+            input_variables=sorted(input_vars),
+            messages=_messages,
+            partial_variables=partial_vars,
+        )

     def format(self, **kwargs: Any) -> str:
         """Format the chat template into a string.
@@ -799,7 +799,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
                 raise ValueError(f"Unexpected input: {message_template}")
         return result

-    def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate:
+    def partial(self, **kwargs: Any) -> ChatPromptTemplate:
         """Get a new ChatPromptTemplate with some input variables already filled in.

         Args:
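
With the from_messages change above, an optional MessagesPlaceholder is now pre-filled with an empty message list through partial_variables. A short sketch of the expected behavior; the partial_variables assertion is an inference from the diff above rather than something stated elsewhere:

    from langchain_core.messages import SystemMessage
    from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

    prompt = ChatPromptTemplate.from_messages(
        [MessagesPlaceholder("history", optional=True)]
    )
    # The optional placeholder defaults to an empty list of messages...
    assert prompt.partial_variables == {"history": []}
    assert prompt.format_messages() == []
    # ...and can still be filled in later via partial().
    assert prompt.partial(history=[("system", "foo")]).format_messages() == [
        SystemMessage(content="foo")
    ]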

View File

@@ -503,9 +503,27 @@ def test_messages_placeholder() -> None:
         prompt.format_messages()
     prompt = MessagesPlaceholder("history", optional=True)
     assert prompt.format_messages() == []
-    prompt.format_messages(
+    assert prompt.format_messages(
         history=[("system", "You are an AI assistant."), "Hello!"]
     ) == [
         SystemMessage(content="You are an AI assistant."),
         HumanMessage(content="Hello!"),
     ]
+
+
+def test_chat_prompt_message_placeholder_partial() -> None:
+    prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")])
+    prompt = prompt.partial(history=[("system", "foo")])
+    assert prompt.format_messages() == [SystemMessage(content="foo")]
+    assert prompt.format_messages(history=[("system", "bar")]) == [
+        SystemMessage(content="bar")
+    ]
+
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            MessagesPlaceholder("history", optional=True),
+        ]
+    )
+    assert prompt.format_messages() == []
+    prompt = prompt.partial(history=[("system", "foo")])
+    assert prompt.format_messages() == [SystemMessage(content="foo")]

View File

@@ -1544,7 +1544,8 @@
              }
            }
          }
-        ]
+        ],
+        "partial_variables": {}
      }
    },
    "middle": [],
@@ -1617,7 +1618,8 @@
              }
            }
          }
-        ]
+        ],
+        "partial_variables": {}
      }
    },
    "middle": [],

View File

@@ -191,7 +191,8 @@
              }
            }
          }
-        ]
+        ],
+        "partial_variables": {}
      }
    }
  }