From 58a88f39119cbfbeca90bd65b5e3fbc8498b6579 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 4 Oct 2023 18:54:53 +0100 Subject: [PATCH] Add optional input_types to prompt template (#11385) - default MessagesPlaceholder one to list of messages --- libs/langchain/langchain/prompts/chat.py | 5 + .../langchain/schema/prompt_template.py | 6 +- libs/langchain/pyproject.toml | 2 +- .../runnable/__snapshots__/test_runnable.ambr | 262 +++++++++++++++++- .../schema/runnable/test_runnable.py | 188 ++++++++++++- 5 files changed, 449 insertions(+), 14 deletions(-) diff --git a/libs/langchain/langchain/prompts/chat.py b/libs/langchain/langchain/prompts/chat.py index bffb0aaa7b8..b93fb028bdf 100644 --- a/libs/langchain/langchain/prompts/chat.py +++ b/libs/langchain/langchain/prompts/chat.py @@ -420,9 +420,13 @@ class ChatPromptTemplate(BaseChatPromptTemplate): """ messages = values["messages"] input_vars = set() + input_types: Dict[str, Any] = values.get("input_types", {}) for message in messages: if isinstance(message, (BaseMessagePromptTemplate, BaseChatPromptTemplate)): input_vars.update(message.input_variables) + if isinstance(message, MessagesPlaceholder): + if message.variable_name not in input_types: + input_types[message.variable_name] = List[AnyMessage] if "partial_variables" in values: input_vars = input_vars - set(values["partial_variables"]) if "input_variables" in values: @@ -434,6 +438,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate): ) else: values["input_variables"] = sorted(input_vars) + values["input_types"] = input_types return values @classmethod diff --git a/libs/langchain/langchain/schema/prompt_template.py b/libs/langchain/langchain/schema/prompt_template.py index b72e2fe55ec..c9dabd572ac 100644 --- a/libs/langchain/langchain/schema/prompt_template.py +++ b/libs/langchain/langchain/schema/prompt_template.py @@ -19,6 +19,9 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC): input_variables: List[str] """A list of the names of the variables the prompt template expects.""" + input_types: Dict[str, Any] = Field(default_factory=dict) + """A dictionary of the types of the variables the prompt template expects. + If not provided, all variables are assumed to be strings.""" output_parser: Optional[BaseOutputParser] = None """How to parse the output of calling an LLM on this formatted prompt.""" partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field( @@ -46,7 +49,8 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC): def input_schema(self) -> type[BaseModel]: # This is correct, but pydantic typings/mypy don't think so. return create_model( # type: ignore[call-overload] - "PromptInput", **{k: (Any, None) for k in self.input_variables} + "PromptInput", + **{k: (self.input_types.get(k, str), None) for k in self.input_variables}, ) def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue: diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index b269ce228df..f862b2e6add 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -392,7 +392,7 @@ build-backend = "poetry.core.masonry.api" # # https://github.com/tophat/syrupy # --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite. -addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused" +addopts = "--strict-markers --strict-config --durations=5 --snapshot-warn-unused -vv" # Registering custom markers. # https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers markers = [ diff --git a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr index 4e7b4c53de9..7fc8ba8c8d5 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr +++ b/libs/langchain/tests/unit_tests/schema/runnable/__snapshots__/test_runnable.ambr @@ -2676,10 +2676,268 @@ 'type': 'object', }), }), - 'title': 'PromptTemplateOutput', + 'title': 'ChatPromptTemplateOutput', }) # --- # name: test_schemas.4 + dict({ + 'anyOf': list([ + dict({ + '$ref': '#/definitions/StringPromptValue', + }), + dict({ + '$ref': '#/definitions/ChatPromptValueConcrete', + }), + ]), + 'definitions': dict({ + 'AIMessage': dict({ + 'description': 'A Message from an AI.', + 'properties': dict({ + 'additional_kwargs': dict({ + 'title': 'Additional Kwargs', + 'type': 'object', + }), + 'content': dict({ + 'title': 'Content', + 'type': 'string', + }), + 'example': dict({ + 'default': False, + 'title': 'Example', + 'type': 'boolean', + }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), + 'type': dict({ + 'default': 'ai', + 'enum': list([ + 'ai', + ]), + 'title': 'Type', + 'type': 'string', + }), + }), + 'required': list([ + 'content', + ]), + 'title': 'AIMessage', + 'type': 'object', + }), + 'ChatMessage': dict({ + 'description': 'A Message that can be assigned an arbitrary speaker (i.e. role).', + 'properties': dict({ + 'additional_kwargs': dict({ + 'title': 'Additional Kwargs', + 'type': 'object', + }), + 'content': dict({ + 'title': 'Content', + 'type': 'string', + }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), + 'role': dict({ + 'title': 'Role', + 'type': 'string', + }), + 'type': dict({ + 'default': 'chat', + 'enum': list([ + 'chat', + ]), + 'title': 'Type', + 'type': 'string', + }), + }), + 'required': list([ + 'content', + 'role', + ]), + 'title': 'ChatMessage', + 'type': 'object', + }), + 'ChatPromptValueConcrete': dict({ + 'description': ''' + Chat prompt value which explicitly lists out the message types it accepts. + For use in external schemas. + ''', + 'properties': dict({ + 'messages': dict({ + 'items': dict({ + 'anyOf': list([ + dict({ + '$ref': '#/definitions/AIMessage', + }), + dict({ + '$ref': '#/definitions/HumanMessage', + }), + dict({ + '$ref': '#/definitions/ChatMessage', + }), + dict({ + '$ref': '#/definitions/SystemMessage', + }), + dict({ + '$ref': '#/definitions/FunctionMessage', + }), + ]), + }), + 'title': 'Messages', + 'type': 'array', + }), + }), + 'required': list([ + 'messages', + ]), + 'title': 'ChatPromptValueConcrete', + 'type': 'object', + }), + 'FunctionMessage': dict({ + 'description': 'A Message for passing the result of executing a function back to a model.', + 'properties': dict({ + 'additional_kwargs': dict({ + 'title': 'Additional Kwargs', + 'type': 'object', + }), + 'content': dict({ + 'title': 'Content', + 'type': 'string', + }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), + 'name': dict({ + 'title': 'Name', + 'type': 'string', + }), + 'type': dict({ + 'default': 'function', + 'enum': list([ + 'function', + ]), + 'title': 'Type', + 'type': 'string', + }), + }), + 'required': list([ + 'content', + 'name', + ]), + 'title': 'FunctionMessage', + 'type': 'object', + }), + 'HumanMessage': dict({ + 'description': 'A Message from a human.', + 'properties': dict({ + 'additional_kwargs': dict({ + 'title': 'Additional Kwargs', + 'type': 'object', + }), + 'content': dict({ + 'title': 'Content', + 'type': 'string', + }), + 'example': dict({ + 'default': False, + 'title': 'Example', + 'type': 'boolean', + }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), + 'type': dict({ + 'default': 'human', + 'enum': list([ + 'human', + ]), + 'title': 'Type', + 'type': 'string', + }), + }), + 'required': list([ + 'content', + ]), + 'title': 'HumanMessage', + 'type': 'object', + }), + 'StringPromptValue': dict({ + 'description': 'String prompt value.', + 'properties': dict({ + 'text': dict({ + 'title': 'Text', + 'type': 'string', + }), + }), + 'required': list([ + 'text', + ]), + 'title': 'StringPromptValue', + 'type': 'object', + }), + 'SystemMessage': dict({ + 'description': ''' + A Message for priming AI behavior, usually passed in as the first of a sequence + of input messages. + ''', + 'properties': dict({ + 'additional_kwargs': dict({ + 'title': 'Additional Kwargs', + 'type': 'object', + }), + 'content': dict({ + 'title': 'Content', + 'type': 'string', + }), + 'is_chunk': dict({ + 'default': False, + 'enum': list([ + False, + ]), + 'title': 'Is Chunk', + 'type': 'boolean', + }), + 'type': dict({ + 'default': 'system', + 'enum': list([ + 'system', + ]), + 'title': 'Type', + 'type': 'string', + }), + }), + 'required': list([ + 'content', + ]), + 'title': 'SystemMessage', + 'type': 'object', + }), + }), + 'title': 'PromptTemplateOutput', + }) +# --- +# name: test_schemas.5 dict({ 'definitions': dict({ 'AIMessage': dict({ @@ -2944,7 +3202,7 @@ 'type': 'array', }) # --- -# name: test_schemas.5 +# name: test_schemas.6 dict({ 'anyOf': list([ dict({ diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 1fae13287f9..875e6965bf5 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -35,6 +35,7 @@ from langchain.prompts.chat import ( ChatPromptTemplate, ChatPromptValue, HumanMessagePromptTemplate, + MessagesPlaceholder, SystemMessagePromptTemplate, ) from langchain.schema.document import Document @@ -241,12 +242,179 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: assert fake_chat.input_schema.schema() == snapshot assert fake_chat.output_schema.schema() == snapshot + chat_prompt = ChatPromptTemplate.from_messages( + [ + MessagesPlaceholder(variable_name="history"), + ("human", "Hello, how are you?"), + ] + ) + + assert chat_prompt.input_schema.schema() == { + "title": "PromptInput", + "type": "object", + "properties": { + "history": { + "title": "History", + "type": "array", + "items": { + "anyOf": [ + {"$ref": "#/definitions/AIMessage"}, + {"$ref": "#/definitions/HumanMessage"}, + {"$ref": "#/definitions/ChatMessage"}, + {"$ref": "#/definitions/SystemMessage"}, + {"$ref": "#/definitions/FunctionMessage"}, + ] + }, + } + }, + "definitions": { + "AIMessage": { + "title": "AIMessage", + "description": "A Message from an AI.", + "type": "object", + "properties": { + "content": {"title": "Content", "type": "string"}, + "additional_kwargs": { + "title": "Additional Kwargs", + "type": "object", + }, + "type": { + "title": "Type", + "default": "ai", + "enum": ["ai"], + "type": "string", + }, + "example": { + "title": "Example", + "default": False, + "type": "boolean", + }, + "is_chunk": { + "title": "Is Chunk", + "default": False, + "enum": [False], + "type": "boolean", + }, + }, + "required": ["content"], + }, + "HumanMessage": { + "title": "HumanMessage", + "description": "A Message from a human.", + "type": "object", + "properties": { + "content": {"title": "Content", "type": "string"}, + "additional_kwargs": { + "title": "Additional Kwargs", + "type": "object", + }, + "type": { + "title": "Type", + "default": "human", + "enum": ["human"], + "type": "string", + }, + "example": { + "title": "Example", + "default": False, + "type": "boolean", + }, + "is_chunk": { + "title": "Is Chunk", + "default": False, + "enum": [False], + "type": "boolean", + }, + }, + "required": ["content"], + }, + "ChatMessage": { + "title": "ChatMessage", + "description": "A Message that can be assigned an arbitrary speaker (i.e. role).", # noqa: E501 + "type": "object", + "properties": { + "content": {"title": "Content", "type": "string"}, + "additional_kwargs": { + "title": "Additional Kwargs", + "type": "object", + }, + "type": { + "title": "Type", + "default": "chat", + "enum": ["chat"], + "type": "string", + }, + "role": {"title": "Role", "type": "string"}, + "is_chunk": { + "title": "Is Chunk", + "default": False, + "enum": [False], + "type": "boolean", + }, + }, + "required": ["content", "role"], + }, + "SystemMessage": { + "title": "SystemMessage", + "description": "A Message for priming AI behavior, usually passed in as the first of a sequence\nof input messages.", # noqa: E501 + "type": "object", + "properties": { + "content": {"title": "Content", "type": "string"}, + "additional_kwargs": { + "title": "Additional Kwargs", + "type": "object", + }, + "type": { + "title": "Type", + "default": "system", + "enum": ["system"], + "type": "string", + }, + "is_chunk": { + "title": "Is Chunk", + "default": False, + "enum": [False], + "type": "boolean", + }, + }, + "required": ["content"], + }, + "FunctionMessage": { + "title": "FunctionMessage", + "description": "A Message for passing the result of executing a function back to a model.", # noqa: E501 + "type": "object", + "properties": { + "content": {"title": "Content", "type": "string"}, + "additional_kwargs": { + "title": "Additional Kwargs", + "type": "object", + }, + "type": { + "title": "Type", + "default": "function", + "enum": ["function"], + "type": "string", + }, + "name": {"title": "Name", "type": "string"}, + "is_chunk": { + "title": "Is Chunk", + "default": False, + "enum": [False], + "type": "boolean", + }, + }, + "required": ["content", "name"], + }, + }, + } + assert chat_prompt.output_schema.schema() == snapshot + prompt = PromptTemplate.from_template("Hello, {name}!") assert prompt.input_schema.schema() == { "title": "PromptInput", "type": "object", - "properties": {"name": {"title": "Name"}}, + "properties": {"name": {"title": "Name", "type": "string"}}, } assert prompt.output_schema.schema() == snapshot @@ -255,7 +423,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: assert prompt_mapper.input_schema.schema() == { "definitions": { "PromptInput": { - "properties": {"name": {"title": "Name"}}, + "properties": {"name": {"title": "Name", "type": "string"}}, "title": "PromptInput", "type": "object", } @@ -280,7 +448,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: assert seq.input_schema.schema() == { "title": "PromptInput", "type": "object", - "properties": {"name": {"title": "Name"}}, + "properties": {"name": {"title": "Name", "type": "string"}}, } assert seq.output_schema.schema() == { "type": "array", @@ -320,7 +488,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: assert seq_w_map.input_schema.schema() == { "title": "PromptInput", "type": "object", - "properties": {"name": {"title": "Name"}}, + "properties": {"name": {"title": "Name", "type": "string"}}, } assert seq_w_map.output_schema.schema() == { "title": "RunnableMapOutput", @@ -428,7 +596,7 @@ def test_schema_complex_seq() -> None: "title": "RunnableMapInput", "type": "object", "properties": { - "person": {"title": "Person"}, + "person": {"title": "Person", "type": "string"}, "language": {"title": "Language"}, }, } @@ -2318,7 +2486,7 @@ def test_deep_stream_assign() -> None: assert chain_with_assign.input_schema.schema() == { "title": "PromptInput", "type": "object", - "properties": {"question": {"title": "Question"}}, + "properties": {"question": {"title": "Question", "type": "string"}}, } assert chain_with_assign.output_schema.schema() == { "title": "RunnableAssignOutput", @@ -2368,7 +2536,7 @@ def test_deep_stream_assign() -> None: assert chain_with_assign_shadow.input_schema.schema() == { "title": "PromptInput", "type": "object", - "properties": {"question": {"title": "Question"}}, + "properties": {"question": {"title": "Question", "type": "string"}}, } assert chain_with_assign_shadow.output_schema.schema() == { "title": "RunnableAssignOutput", @@ -2444,7 +2612,7 @@ async def test_deep_astream_assign() -> None: assert chain_with_assign.input_schema.schema() == { "title": "PromptInput", "type": "object", - "properties": {"question": {"title": "Question"}}, + "properties": {"question": {"title": "Question", "type": "string"}}, } assert chain_with_assign.output_schema.schema() == { "title": "RunnableAssignOutput", @@ -2494,7 +2662,7 @@ async def test_deep_astream_assign() -> None: assert chain_with_assign_shadow.input_schema.schema() == { "title": "PromptInput", "type": "object", - "properties": {"question": {"title": "Question"}}, + "properties": {"question": {"title": "Question", "type": "string"}}, } assert chain_with_assign_shadow.output_schema.schema() == { "title": "RunnableAssignOutput", @@ -3290,7 +3458,7 @@ async def test_tool_from_runnable() -> None: assert chain_tool.description.endswith(repr(chain)) assert chain_tool.args_schema.schema() == chain.input_schema.schema() assert chain_tool.args_schema.schema() == { - "properties": {"question": {"title": "Question"}}, + "properties": {"question": {"title": "Question", "type": "string"}}, "title": "PromptInput", "type": "object", }