From 39b19cf76401a4c9e6c8e59e4790fd3b019f41fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Deschamps?= <44435634+thdesc@users.noreply.github.com> Date: Wed, 3 Jul 2024 11:58:42 -0700 Subject: [PATCH] core[patch]: extract input variables for `path` and `detail` keys in order to format an `ImagePromptTemplate` (#22613) - Description: Add support for `path` and `detail` keys in `ImagePromptTemplate`. Previously, only variables associated with the `url` key were considered. This PR allows for the inclusion of a local image path and a detail parameter as input to the format method. - Issues: - fixes #20820 - related to #22024 - Dependencies: None - Twitter handle: @DeschampsTho5 --------- Co-authored-by: tdeschamps Co-authored-by: Eugene Yurtsev Co-authored-by: Eugene Yurtsev --- libs/core/langchain_core/prompts/chat.py | 16 ++--- .../tests/unit_tests/prompts/test_chat.py | 62 +++++++++++++++++++ 2 files changed, 70 insertions(+), 8 deletions(-) diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index f0937f1a3cc..9d27aedce8d 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -473,6 +473,7 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): ) elif isinstance(tmpl, dict) and "image_url" in tmpl: img_template = cast(_ImageTemplateParam, tmpl)["image_url"] + input_variables = [] if isinstance(img_template, str): vars = get_template_variables(img_template, "f-string") if vars: @@ -483,20 +484,19 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): f"\nFrom: {tmpl}" ) input_variables = [vars[0]] - else: - input_variables = None img_template = {"url": img_template} img_template_obj = ImagePromptTemplate( input_variables=input_variables, template=img_template ) elif isinstance(img_template, dict): img_template = dict(img_template) - if "url" in img_template: - input_variables = get_template_variables( - img_template["url"], "f-string" - ) - else: - input_variables = None + for key in ["url", "path", "detail"]: + if key in img_template: + input_variables.extend( + get_template_variables( + img_template[key], "f-string" + ) + ) img_template_obj = ImagePromptTemplate( input_variables=input_variables, template=img_template ) diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 5811480f045..86f1cc1954b 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -1,3 +1,5 @@ +import base64 +import tempfile from pathlib import Path from typing import Any, List, Union @@ -559,6 +561,7 @@ async def test_chat_tmpl_from_messages_multipart_text_with_template() -> None: async def test_chat_tmpl_from_messages_multipart_image() -> None: + """Test multipart image URL formatting.""" base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA" other_base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA" template = ChatPromptTemplate.from_messages( @@ -641,6 +644,65 @@ async def test_chat_tmpl_from_messages_multipart_image() -> None: assert messages == expected +async def test_chat_tmpl_from_messages_multipart_formatting_with_path() -> None: + """Verify that we can pass `path` for an image as a variable.""" + in_mem = "base64mem" + in_file_data = "base64file01" + + with tempfile.NamedTemporaryFile(delete=True, suffix=".jpg") as temp_file: + temp_file.write(base64.b64decode(in_file_data)) + temp_file.flush() + + template = ChatPromptTemplate.from_messages( + [ + ("system", "You are an AI assistant named {name}."), + ( + "human", + [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": "data:image/jpeg;base64,{in_mem}", + }, + { + "type": "image_url", + "image_url": {"path": "{file_path}"}, + }, + ], + ), + ] + ) + expected = [ + SystemMessage(content="You are an AI assistant named R2D2."), + HumanMessage( + content=[ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{in_mem}"}, + }, + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{in_file_data}"}, + }, + ] + ), + ] + messages = template.format_messages( + name="R2D2", + in_mem=in_mem, + file_path=temp_file.name, + ) + assert messages == expected + + messages = await template.aformat_messages( + name="R2D2", + in_mem=in_mem, + file_path=temp_file.name, + ) + assert messages == expected + + def test_messages_placeholder() -> None: prompt = MessagesPlaceholder("history") with pytest.raises(KeyError):