diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index dc13609b89f..7f91ce3e1bb 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -820,11 +820,17 @@ class ChatPromptTemplate(BaseChatPromptTemplate): Examples: + .. versionchanged:: 0.2.24 + + You can pass any Message-like formats supported by + ``ChatPromptTemplate.from_messages()`` directly to ``ChatPromptTemplate()`` + init. + .. code-block:: python from langchain_core.prompts import ChatPromptTemplate - template = ChatPromptTemplate.from_messages([ + template = ChatPromptTemplate([ ("system", "You are a helpful AI bot. Your name is {name}."), ("human", "Hello, how are you doing?"), ("ai", "I'm doing well, thanks!"), @@ -855,7 +861,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): # you can initialize the template with a MessagesPlaceholder # either using the class directly or with the shorthand tuple syntax: - template = ChatPromptTemplate.from_messages([ + template = ChatPromptTemplate([ ("system", "You are a helpful AI bot."), # Means the template will receive an optional list of messages under # the "conversation" key @@ -897,7 +903,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): from langchain_core.prompts import ChatPromptTemplate - template = ChatPromptTemplate.from_messages([ + template = ChatPromptTemplate([ ("system", "You are a helpful AI bot. Your name is Carl."), ("human", "{user_input}"), ]) @@ -921,6 +927,86 @@ class ChatPromptTemplate(BaseChatPromptTemplate): validate_template: bool = False """Whether or not to try validating the template.""" + def __init__( + self, + messages: Sequence[MessageLikeRepresentation], + *, + template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", + **kwargs: Any, + ) -> None: + """Create a chat prompt template from a variety of message formats. + + Args: + messages: sequence of message representations. + A message can be represented using the following formats: + (1) BaseMessagePromptTemplate, (2) BaseMessage, (3) 2-tuple of + (message type, template); e.g., ("human", "{user_input}"), + (4) 2-tuple of (message class, template), (4) a string which is + shorthand for ("human", template); e.g., "{user_input}". + template_format: format of the template. Defaults to "f-string". + input_variables: A list of the names of the variables whose values are + required as inputs to the prompt. + optional_variables: A list of the names of the variables that are optional + in the prompt. + partial_variables: A dictionary of the partial variables the prompt + template carries. Partial variables populate the template so that you + don't need to pass them in every time you call the prompt. + validate_template: Whether to validate the template. + input_types: A dictionary of the types of the variables the prompt template + expects. If not provided, all variables are assumed to be strings. + + Returns: + A chat prompt template. + + Examples: + + Instantiation from a list of message templates: + + .. code-block:: python + + template = ChatPromptTemplate([ + ("human", "Hello, how are you?"), + ("ai", "I'm doing well, thanks!"), + ("human", "That's good to hear."), + ]) + + Instantiation from mixed message formats: + + .. code-block:: python + + template = ChatPromptTemplate([ + SystemMessage(content="hello"), + ("human", "Hello, how are you?"), + ]) + + """ + _messages = [ + _convert_to_message(message, template_format) for message in messages + ] + + # Automatically infer input variables from messages + input_vars: Set[str] = set() + optional_variables: Set[str] = set() + partial_vars: Dict[str, Any] = {} + for _message in _messages: + if isinstance(_message, MessagesPlaceholder) and _message.optional: + partial_vars[_message.variable_name] = [] + optional_variables.add(_message.variable_name) + elif isinstance( + _message, (BaseChatPromptTemplate, BaseMessagePromptTemplate) + ): + input_vars.update(_message.input_variables) + + kwargs = { + **dict( + input_variables=sorted(input_vars), + optional_variables=sorted(optional_variables), + partial_variables=partial_vars, + ), + **kwargs, + } + cast(Type[ChatPromptTemplate], super()).__init__(messages=_messages, **kwargs) + @classmethod def get_lc_namespace(cls) -> List[str]: """Get the namespace of the langchain object.""" @@ -1097,29 +1183,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): Returns: a chat prompt template. """ - _messages = [ - _convert_to_message(message, template_format) for message in messages - ] - - # Automatically infer input variables from messages - input_vars: Set[str] = set() - optional_variables: Set[str] = set() - partial_vars: Dict[str, Any] = {} - for _message in _messages: - if isinstance(_message, MessagesPlaceholder) and _message.optional: - partial_vars[_message.variable_name] = [] - optional_variables.add(_message.variable_name) - elif isinstance( - _message, (BaseChatPromptTemplate, BaseMessagePromptTemplate) - ): - input_vars.update(_message.input_variables) - - return cls( - input_variables=sorted(input_vars), - optional_variables=sorted(optional_variables), - messages=_messages, - partial_variables=partial_vars, - ) + return cls(messages, template_format=template_format) def format_messages(self, **kwargs: Any) -> List[BaseMessage]: """Format the chat template into a list of finalized messages. diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 4a8fd4ede3d..a52cf75e138 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -438,7 +438,7 @@ def test_chat_prompt_template_indexing() -> None: message1 = SystemMessage(content="foo") message2 = HumanMessage(content="bar") message3 = HumanMessage(content="baz") - template = ChatPromptTemplate.from_messages([message1, message2, message3]) + template = ChatPromptTemplate([message1, message2, message3]) assert template[0] == message1 assert template[1] == message2 @@ -453,7 +453,7 @@ def test_chat_prompt_template_append_and_extend() -> None: message1 = SystemMessage(content="foo") message2 = HumanMessage(content="bar") message3 = HumanMessage(content="baz") - template = ChatPromptTemplate.from_messages([message1]) + template = ChatPromptTemplate([message1]) template.append(message2) template.append(message3) assert len(template) == 3 @@ -480,7 +480,7 @@ def test_convert_to_message_is_strict() -> None: def test_chat_message_partial() -> None: - template = ChatPromptTemplate.from_messages( + template = ChatPromptTemplate( [ ("system", "You are an AI assistant named {name}."), ("human", "Hi I'm {user}"), @@ -734,14 +734,14 @@ def test_messages_placeholder_with_max() -> None: def test_chat_prompt_message_placeholder_partial() -> None: - prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")]) + prompt = ChatPromptTemplate([MessagesPlaceholder("history")]) prompt = prompt.partial(history=[("system", "foo")]) assert prompt.format_messages() == [SystemMessage(content="foo")] assert prompt.format_messages(history=[("system", "bar")]) == [ SystemMessage(content="bar") ] - prompt = ChatPromptTemplate.from_messages( + prompt = ChatPromptTemplate( [ MessagesPlaceholder("history", optional=True), ] @@ -752,7 +752,7 @@ def test_chat_prompt_message_placeholder_partial() -> None: def test_chat_prompt_message_placeholder_tuple() -> None: - prompt = ChatPromptTemplate.from_messages([("placeholder", "{convo}")]) + prompt = ChatPromptTemplate([("placeholder", "{convo}")]) assert prompt.format_messages(convo=[("user", "foo")]) == [ HumanMessage(content="foo") ] @@ -760,9 +760,7 @@ def test_chat_prompt_message_placeholder_tuple() -> None: assert prompt.format_messages() == [] # Is optional = True - optional_prompt = ChatPromptTemplate.from_messages( - [("placeholder", ["{convo}", False])] - ) + optional_prompt = ChatPromptTemplate([("placeholder", ["{convo}", False])]) assert optional_prompt.format_messages(convo=[("user", "foo")]) == [ HumanMessage(content="foo") ] @@ -771,7 +769,7 @@ def test_chat_prompt_message_placeholder_tuple() -> None: async def test_messages_prompt_accepts_list() -> None: - prompt = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")]) + prompt = ChatPromptTemplate([MessagesPlaceholder("history")]) value = prompt.invoke([("user", "Hi there")]) # type: ignore assert value.to_messages() == [HumanMessage(content="Hi there")] @@ -779,7 +777,7 @@ async def test_messages_prompt_accepts_list() -> None: assert value.to_messages() == [HumanMessage(content="Hi there")] # Assert still raises a nice error - prompt = ChatPromptTemplate.from_messages( + prompt = ChatPromptTemplate( [("system", "You are a {foo}"), MessagesPlaceholder("history")] ) with pytest.raises(TypeError): @@ -790,7 +788,7 @@ async def test_messages_prompt_accepts_list() -> None: def test_chat_input_schema(snapshot: SnapshotAssertion) -> None: - prompt_all_required = ChatPromptTemplate.from_messages( + prompt_all_required = ChatPromptTemplate( messages=[MessagesPlaceholder("history", optional=False), ("user", "${input}")] ) assert set(prompt_all_required.input_variables) == {"input", "history"} @@ -798,7 +796,7 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None: with pytest.raises(ValidationError): prompt_all_required.input_schema(input="") assert prompt_all_required.input_schema.schema() == snapshot(name="required") - prompt_optional = ChatPromptTemplate.from_messages( + prompt_optional = ChatPromptTemplate( messages=[MessagesPlaceholder("history", optional=True), ("user", "${input}")] ) # input variables only lists required variables