diff --git a/libs/partners/prompty/langchain_prompty/langchain.py b/libs/partners/prompty/langchain_prompty/langchain.py
index 51595e6b657..52eac52c0da 100644
--- a/libs/partners/prompty/langchain_prompty/langchain.py
+++ b/libs/partners/prompty/langchain_prompty/langchain.py
@@ -1,33 +1,34 @@
-from typing import Any, Dict, Literal
+from typing import Any, Dict
 
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain_core.runnables import Runnable, RunnableLambda
 
+from .parsers import RoleMap
 from .utils import load, prepare
 
 
 def create_chat_prompt(
     path: str,
     input_name_agent_scratchpad: str = "agent_scratchpad",
-    template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
 ) -> Runnable[Dict[str, Any], ChatPromptTemplate]:
     """Create a chat prompt from a Langchain schema."""
 
     def runnable_chat_lambda(inputs: Dict[str, Any]) -> ChatPromptTemplate:
         p = load(path)
         parsed = prepare(p, inputs)
+        # Parsed messages have been templated
+        # Convert to Message objects to avoid templating attempts in ChatPromptTemplate
         lc_messages = []
         for message in parsed:
-            lc_messages.append((message["role"], message["content"]))
+            message_class = RoleMap.get_message_class(message["role"])
+            lc_messages.append(message_class(content=message["content"]))
 
         lc_messages.append(
             MessagesPlaceholder(
                 variable_name=input_name_agent_scratchpad, optional=True
             )  # type: ignore[arg-type]
         )
-        lc_p = ChatPromptTemplate.from_messages(
-            lc_messages, template_format=template_format
-        )
+        lc_p = ChatPromptTemplate.from_messages(lc_messages)
         lc_p = lc_p.partial(**p.inputs)
         return lc_p
 
diff --git a/libs/partners/prompty/langchain_prompty/parsers.py b/libs/partners/prompty/langchain_prompty/parsers.py
index 5f596a662fe..592ddd80136 100644
--- a/libs/partners/prompty/langchain_prompty/parsers.py
+++ b/libs/partners/prompty/langchain_prompty/parsers.py
@@ -1,18 +1,41 @@
 import base64
 import re
-from typing import List, Union
+from typing import Dict, List, Type, Union
 
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    FunctionMessage,
+    HumanMessage,
+    SystemMessage,
+)
 from pydantic import BaseModel
 
 from .core import Invoker, Prompty, SimpleModel
 
 
+class RoleMap:
+    _ROLE_MAP: Dict[str, Type[BaseMessage]] = {
+        "system": SystemMessage,
+        "user": HumanMessage,
+        "human": HumanMessage,
+        "assistant": AIMessage,
+        "ai": AIMessage,
+        "function": FunctionMessage,
+    }
+    ROLES = _ROLE_MAP.keys()
+
+    @classmethod
+    def get_message_class(cls, role: str) -> Type[BaseMessage]:
+        return cls._ROLE_MAP[role]
+
+
 class PromptyChatParser(Invoker):
     """Parse a chat prompt into a list of messages."""
 
     def __init__(self, prompty: Prompty) -> None:
         self.prompty = prompty
-        self.roles = ["assistant", "function", "system", "user", "human", "ai"]
+        self.roles = RoleMap.ROLES
         self.path = self.prompty.file.parent
 
     def inline_image(self, image_item: str) -> str:
diff --git a/libs/partners/prompty/tests/unit_tests/prompts/double_templating.prompty b/libs/partners/prompty/tests/unit_tests/prompts/double_templating.prompty
new file mode 100644
index 00000000000..f5fbc610c97
--- /dev/null
+++ b/libs/partners/prompty/tests/unit_tests/prompts/double_templating.prompty
@@ -0,0 +1,10 @@
+---
+name: IssuePrompt
+description: A prompt used to detect if double templating occurs
+model:
+  api: chat
+template: mustache
+---
+
+user:
+{{user_input}}
\ No newline at end of file
diff --git a/libs/partners/prompty/tests/unit_tests/test_templating.py b/libs/partners/prompty/tests/unit_tests/test_templating.py
new file mode 100644
index 00000000000..ba95d3ddb4d
--- /dev/null
+++ b/libs/partners/prompty/tests/unit_tests/test_templating.py
@@ -0,0 +1,23 @@
+from pathlib import Path
+
+import pytest
+
+from langchain_prompty import create_chat_prompt
+
+PROMPT_DIR = Path(__file__).parent / "prompts"
+
+
+def test_double_templating() -> None:
+    """
+    Assess whether double templating occurs when invoking a chat prompt.
+    If it does, an error is thrown and the test fails.
+    """
+
+    prompt_path = PROMPT_DIR / "double_templating.prompty"
+    templated_prompt = create_chat_prompt(str(prompt_path))
+    query = "What do you think of this JSON object: {'key': 7}?"
+
+    try:
+        templated_prompt.invoke(input={"user_input": query})
+    except KeyError as e:
+        pytest.fail("Double templating occurred: " + str(e))
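
A minimal sketch of the behavior this patch pins down, assuming the fixture above is reachable from the working directory (the path and the "user_input" key come from the new test, not from the library API):

    from langchain_prompty import create_chat_prompt

    # Templating now happens exactly once, inside prepare(); the resulting
    # ChatPromptTemplate holds Message objects, so invoking it does not
    # try to format the message content a second time.
    prompt = create_chat_prompt("tests/unit_tests/prompts/double_templating.prompty")

    # Brace-heavy input such as a JSON fragment used to raise KeyError when
    # the templated content was re-parsed as an f-string template; it now
    # passes through verbatim into the rendered messages.
    result = prompt.invoke(input={"user_input": "What about {'key': 7}?"})
    print(result)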