mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 16:36:06 +00:00
core[patch]: rm image prompt file loading
This commit is contained in:
parent
bc6600d86f
commit
7b214ee83d
@ -4,7 +4,6 @@ from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
from langchain_core.runnables import run_in_executor
|
||||
from langchain_core.utils import image as image_utils
|
||||
|
||||
|
||||
class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
||||
@ -54,6 +53,11 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
||||
Returns:
|
||||
A formatted string.
|
||||
|
||||
Raises:
|
||||
ValueError: If the url is not provided.
|
||||
ValueError: If the url is not a string.
|
||||
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
@ -67,23 +71,38 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
||||
else:
|
||||
formatted[k] = v
|
||||
url = kwargs.get("url") or formatted.get("url")
|
||||
path = kwargs.get("path") or formatted.get("path")
|
||||
if kwargs.get("path") or formatted.get("path"):
|
||||
msg = (
|
||||
"Loading images from 'path' has been removed as of 0.3.15 for security "
|
||||
"reasons. Please specify images by 'url'."
|
||||
)
|
||||
raise ValueError(msg)
|
||||
detail = kwargs.get("detail") or formatted.get("detail")
|
||||
if not url and not path:
|
||||
raise ValueError("Must provide either url or path.")
|
||||
if not url:
|
||||
if not isinstance(path, str):
|
||||
raise ValueError("path must be a string.")
|
||||
url = image_utils.image_to_data_url(path)
|
||||
if not isinstance(url, str):
|
||||
raise ValueError("url must be a string.")
|
||||
output: ImageURL = {"url": url}
|
||||
if detail:
|
||||
# Don't check literal values here: let the API check them
|
||||
output["detail"] = detail # type: ignore[typeddict-item]
|
||||
msg = "Must provide url."
|
||||
raise ValueError(msg)
|
||||
elif not isinstance(url, str):
|
||||
msg = "url must be a string."
|
||||
raise ValueError(msg)
|
||||
else:
|
||||
output: ImageURL = {"url": url}
|
||||
if detail:
|
||||
# Don't check literal values here: let the API check them
|
||||
output["detail"] = detail # type: ignore[typeddict-item]
|
||||
return output
|
||||
|
||||
async def aformat(self, **kwargs: Any) -> ImageURL:
|
||||
"""Async format the prompt with the inputs.
|
||||
|
||||
Args:
|
||||
kwargs: Any arguments to be passed to the prompt template.
|
||||
|
||||
Returns:
|
||||
A formatted string.
|
||||
|
||||
Raises:
|
||||
ValueError: If the path or url is not a string.
|
||||
"""
|
||||
return await run_in_executor(None, self.format, **kwargs)
|
||||
|
||||
def pretty_repr(self, html: bool = False) -> str:
|
||||
|
@ -1,14 +1,8 @@
|
||||
import base64
|
||||
import mimetypes
|
||||
from typing import Any
|
||||
|
||||
|
||||
def encode_image(image_path: str) -> str:
|
||||
"""Get base64 string from image URI."""
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
|
||||
def image_to_data_url(image_path: str) -> str:
|
||||
encoding = encode_image(image_path)
|
||||
mime_type = mimetypes.guess_type(image_path)[0]
|
||||
return f"data:{mime_type};base64,{encoding}"
|
||||
def __getattr__(name: str) -> Any:
|
||||
if name in ("encode_image", "image_to_data_url"):
|
||||
msg = f"'{name}' has been removed for security reasons."
|
||||
raise ValueError(msg)
|
||||
raise AttributeError(name)
|
||||
|
@ -1,3 +1,5 @@
|
||||
import base64
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Union
|
||||
|
||||
@ -568,6 +570,49 @@ async def test_chat_tmpl_from_messages_multipart_image() -> None:
|
||||
assert messages == expected
|
||||
|
||||
|
||||
async def test_chat_tmpl_from_messages_multipart_formatting_with_path() -> None:
|
||||
"""Verify that we cannot pass `path` for an image as a variable."""
|
||||
in_mem = "base64mem"
|
||||
in_file_data = "base64file01"
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=True, suffix=".jpg") as temp_file:
|
||||
temp_file.write(base64.b64decode(in_file_data))
|
||||
temp_file.flush()
|
||||
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", "You are an AI assistant named {name}."),
|
||||
(
|
||||
"human",
|
||||
[
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": "data:image/jpeg;base64,{in_mem}",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"path": "{file_path}"},
|
||||
},
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
with pytest.raises(ValueError):
|
||||
template.format_messages(
|
||||
name="R2D2",
|
||||
in_mem=in_mem,
|
||||
file_path=temp_file.name,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await template.aformat_messages(
|
||||
name="R2D2",
|
||||
in_mem=in_mem,
|
||||
file_path=temp_file.name,
|
||||
)
|
||||
|
||||
|
||||
def test_messages_placeholder() -> None:
|
||||
prompt = MessagesPlaceholder("history")
|
||||
with pytest.raises(KeyError):
|
||||
|
Loading…
Reference in New Issue
Block a user