mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-18 21:09:00 +00:00
core[patch]: fix chat prompt partial messages placeholder var (#16918)
This commit is contained in:
parent
3b0fa9079d
commit
c29e9b6412
@ -47,9 +47,7 @@ class BasePromptTemplate(
|
|||||||
If not provided, all variables are assumed to be strings."""
|
If not provided, all variables are assumed to be strings."""
|
||||||
output_parser: Optional[BaseOutputParser] = None
|
output_parser: Optional[BaseOutputParser] = None
|
||||||
"""How to parse the output of calling an LLM on this formatted prompt."""
|
"""How to parse the output of calling an LLM on this formatted prompt."""
|
||||||
partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
|
partial_variables: Mapping[str, Any] = Field(default_factory=dict)
|
||||||
default_factory=dict
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
@ -143,8 +141,7 @@ class BasePromptTemplate(
|
|||||||
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
|
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
|
||||||
# Get partial params:
|
# Get partial params:
|
||||||
partial_kwargs = {
|
partial_kwargs = {
|
||||||
k: v if isinstance(v, str) else v()
|
k: v if not callable(v) else v() for k, v in self.partial_variables.items()
|
||||||
for k, v in self.partial_variables.items()
|
|
||||||
}
|
}
|
||||||
return {**partial_kwargs, **kwargs}
|
return {**partial_kwargs, **kwargs}
|
||||||
|
|
||||||
|
@ -5,7 +5,6 @@ from abc import ABC, abstractmethod
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
@ -130,13 +129,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
|||||||
f"variable {self.variable_name} should be a list of base messages, "
|
f"variable {self.variable_name} should be a list of base messages, "
|
||||||
f"got {value}"
|
f"got {value}"
|
||||||
)
|
)
|
||||||
for v in convert_to_messages(value):
|
return convert_to_messages(value)
|
||||||
if not isinstance(v, BaseMessage):
|
|
||||||
raise ValueError(
|
|
||||||
f"variable {self.variable_name} should be a list of base messages,"
|
|
||||||
f" got {value}"
|
|
||||||
)
|
|
||||||
return value
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def input_variables(self) -> List[str]:
|
def input_variables(self) -> List[str]:
|
||||||
@ -755,13 +748,20 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
|||||||
|
|
||||||
# Automatically infer input variables from messages
|
# Automatically infer input variables from messages
|
||||||
input_vars: Set[str] = set()
|
input_vars: Set[str] = set()
|
||||||
|
partial_vars: Dict[str, Any] = {}
|
||||||
for _message in _messages:
|
for _message in _messages:
|
||||||
if isinstance(
|
if isinstance(_message, MessagesPlaceholder) and _message.optional:
|
||||||
|
partial_vars[_message.variable_name] = []
|
||||||
|
elif isinstance(
|
||||||
_message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
|
_message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
|
||||||
):
|
):
|
||||||
input_vars.update(_message.input_variables)
|
input_vars.update(_message.input_variables)
|
||||||
|
|
||||||
return cls(input_variables=sorted(input_vars), messages=_messages)
|
return cls(
|
||||||
|
input_variables=sorted(input_vars),
|
||||||
|
messages=_messages,
|
||||||
|
partial_variables=partial_vars,
|
||||||
|
)
|
||||||
|
|
||||||
def format(self, **kwargs: Any) -> str:
|
def format(self, **kwargs: Any) -> str:
|
||||||
"""Format the chat template into a string.
|
"""Format the chat template into a string.
|
||||||
@ -799,7 +799,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
|
|||||||
raise ValueError(f"Unexpected input: {message_template}")
|
raise ValueError(f"Unexpected input: {message_template}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate:
|
def partial(self, **kwargs: Any) -> ChatPromptTemplate:
|
||||||
"""Get a new ChatPromptTemplate with some input variables already filled in.
|
"""Get a new ChatPromptTemplate with some input variables already filled in.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -503,9 +503,27 @@ def test_messages_placeholder() -> None:
|
|||||||
prompt.format_messages()
|
prompt.format_messages()
|
||||||
prompt = MessagesPlaceholder("history", optional=True)
|
prompt = MessagesPlaceholder("history", optional=True)
|
||||||
assert prompt.format_messages() == []
|
assert prompt.format_messages() == []
|
||||||
prompt.format_messages(
|
assert prompt.format_messages(
|
||||||
history=[("system", "You are an AI assistant."), "Hello!"]
|
history=[("system", "You are an AI assistant."), "Hello!"]
|
||||||
) == [
|
) == [
|
||||||
SystemMessage(content="You are an AI assistant."),
|
SystemMessage(content="You are an AI assistant."),
|
||||||
HumanMessage(content="Hello!"),
|
HumanMessage(content="Hello!"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_prompt_message_placeholder_partial() -> None:
|
||||||
|
prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")])
|
||||||
|
prompt = prompt.partial(history=[("system", "foo")])
|
||||||
|
assert prompt.format_messages() == [SystemMessage(content="foo")]
|
||||||
|
assert prompt.format_messages(history=[("system", "bar")]) == [
|
||||||
|
SystemMessage(content="bar")
|
||||||
|
]
|
||||||
|
|
||||||
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
|
[
|
||||||
|
MessagesPlaceholder("history", optional=True),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert prompt.format_messages() == []
|
||||||
|
prompt = prompt.partial(history=[("system", "foo")])
|
||||||
|
assert prompt.format_messages() == [SystemMessage(content="foo")]
|
||||||
|
@ -1544,7 +1544,8 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"partial_variables": {}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"middle": [],
|
"middle": [],
|
||||||
@ -1617,7 +1618,8 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"partial_variables": {}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"middle": [],
|
"middle": [],
|
||||||
|
@ -191,7 +191,8 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
],
|
||||||
|
"partial_variables": {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user