core[minor]: Image prompt template (#14263)

Builds on Bagatur's (#13227). See unit test for example usage (below)

```python
def test_chat_tmpl_from_messages_multipart_image() -> None:
    base64_image = "abcd123"
    other_base64_image = "abcd123"
    template = ChatPromptTemplate.from_messages(
        [
            ("system", "You are an AI assistant named {name}."),
            (
                "human",
                [
                    {"type": "text", "text": "What's in this image?"},
                    # OAI supports all these structures today
                    {
                        "type": "image_url",
                        "image_url": "data:image/jpeg;base64,{my_image}",
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": "data:image/jpeg;base64,{my_image}"},
                    },
                    {"type": "image_url", "image_url": "{my_other_image}"},
                    {
                        "type": "image_url",
                        "image_url": {"url": "{my_other_image}", "detail": "medium"},
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://www.langchain.com/image.png"},
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": ""},
                    },
                ],
            ),
        ]
    )
    messages = template.format_messages(
        name="R2D2", my_image=base64_image, my_other_image=other_base64_image
    )
    expected = [
        SystemMessage(content="You are an AI assistant named R2D2."),
        HumanMessage(
            content=[
                {"type": "text", "text": "What's in this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{other_base64_image}"
                    },
                },
                {
                    "type": "image_url",
                    "image_url": {"url": f"{other_base64_image}"},
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"{other_base64_image}",
                        "detail": "medium",
                    },
                },
                {
                    "type": "image_url",
                    "image_url": {"url": "https://www.langchain.com/image.png"},
                },
                {
                    "type": "image_url",
                    "image_url": {"url": ""},
                },
            ]
        ),
    ]
    assert messages == expected
```

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Brace Sproul <braceasproul@gmail.com>
This commit is contained in:
William FH 2024-01-27 17:04:29 -08:00 committed by GitHub
parent 3c387bc12d
commit 38425c99d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 454 additions and 49 deletions

View File

@ -3,6 +3,8 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Literal, Sequence from typing import List, Literal, Sequence
from typing_extensions import TypedDict
from langchain_core.load.serializable import Serializable from langchain_core.load.serializable import Serializable
from langchain_core.messages import ( from langchain_core.messages import (
AnyMessage, AnyMessage,
@ -82,6 +84,30 @@ class ChatPromptValue(PromptValue):
return ["langchain", "prompts", "chat"] return ["langchain", "prompts", "chat"]
class ImageURL(TypedDict, total=False):
detail: Literal["auto", "low", "high"]
"""Specifies the detail level of the image."""
url: str
"""Either a URL of the image or the base64 encoded image data."""
class ImagePromptValue(PromptValue):
"""Image prompt value."""
image_url: ImageURL
"""Prompt image."""
type: Literal["ImagePromptValue"] = "ImagePromptValue"
def to_string(self) -> str:
"""Return prompt as string."""
return self.image_url["url"]
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
return [HumanMessage(content=[self.image_url])]
class ChatPromptValueConcrete(ChatPromptValue): class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts. """Chat prompt value which explicitly lists out the message types it accepts.
For use in external schemas.""" For use in external schemas."""

View File

@ -8,10 +8,12 @@ from typing import (
Any, Any,
Callable, Callable,
Dict, Dict,
Generic,
List, List,
Mapping, Mapping,
Optional, Optional,
Type, Type,
TypeVar,
Union, Union,
) )
@ -30,7 +32,12 @@ if TYPE_CHECKING:
from langchain_core.documents import Document from langchain_core.documents import Document
class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC): FormatOutputType = TypeVar("FormatOutputType")
class BasePromptTemplate(
RunnableSerializable[Dict, PromptValue], Generic[FormatOutputType], ABC
):
"""Base class for all prompt templates, returning a prompt.""" """Base class for all prompt templates, returning a prompt."""
input_variables: List[str] input_variables: List[str]
@ -142,7 +149,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
return {**partial_kwargs, **kwargs} return {**partial_kwargs, **kwargs}
@abstractmethod @abstractmethod
def format(self, **kwargs: Any) -> str: def format(self, **kwargs: Any) -> FormatOutputType:
"""Format the prompt with the inputs. """Format the prompt with the inputs.
Args: Args:
@ -210,7 +217,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
raise ValueError(f"{save_path} must be json or yaml") raise ValueError(f"{save_path} must be json or yaml")
def format_document(doc: Document, prompt: BasePromptTemplate) -> str: def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
"""Format a document into a string based on a prompt template. """Format a document into a string based on a prompt template.
First, this pulls information from the document from two sources: First, this pulls information from the document from two sources:
@ -236,7 +243,7 @@ def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
Example: Example:
.. code-block:: python .. code-block:: python
from langchain_core import Document from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
doc = Document(page_content="This is a joke", metadata={"page": "1"}) doc = Document(page_content="This is a joke", metadata={"page": "1"})

View File

@ -13,8 +13,10 @@ from typing import (
Set, Set,
Tuple, Tuple,
Type, Type,
TypedDict,
TypeVar, TypeVar,
Union, Union,
cast,
overload, overload,
) )
@ -30,10 +32,11 @@ from langchain_core.messages import (
convert_to_messages, convert_to_messages,
) )
from langchain_core.messages.base import get_msg_title_repr from langchain_core.messages.base import get_msg_title_repr
from langchain_core.prompt_values import ChatPromptValue, PromptValue from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
from langchain_core.pydantic_v1 import Field, root_validator from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_colored_text from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env from langchain_core.utils.interactive_env import is_interactive_env
@ -288,14 +291,153 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
) )
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate): _StringImageMessagePromptTemplateT = TypeVar(
"_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate"
)
class _TextTemplateParam(TypedDict, total=False):
text: Union[str, Dict]
class _ImageTemplateParam(TypedDict, total=False):
image_url: Union[str, Dict]
class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user.""" """Human message prompt template. This is a message sent from the user."""
prompt: Union[
StringPromptTemplate, List[Union[StringPromptTemplate, ImagePromptTemplate]]
]
"""Prompt template."""
additional_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the prompt template."""
_msg_class: Type[BaseMessage]
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"] return ["langchain", "prompts", "chat"]
@classmethod
def from_template(
cls: Type[_StringImageMessagePromptTemplateT],
template: Union[str, List[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
template_format: str = "f-string",
**kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
"""Create a class from a string template.
Args:
template: a template.
template_format: format of the template.
**kwargs: keyword arguments to pass to the constructor.
Returns:
A new instance of this class.
"""
if isinstance(template, str):
prompt: Union[StringPromptTemplate, List] = PromptTemplate.from_template(
template, template_format=template_format
)
return cls(prompt=prompt, **kwargs)
elif isinstance(template, list):
prompt = []
for tmpl in template:
if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
if isinstance(tmpl, str):
text: str = tmpl
else:
text = cast(_TextTemplateParam, tmpl)["text"] # type: ignore[assignment] # noqa: E501
prompt.append(
PromptTemplate.from_template(
text, template_format=template_format
)
)
elif isinstance(tmpl, dict) and "image_url" in tmpl:
img_template = cast(_ImageTemplateParam, tmpl)["image_url"]
if isinstance(img_template, str):
vars = get_template_variables(img_template, "f-string")
if vars:
if len(vars) > 1:
raise ValueError(
"Only one format variable allowed per image"
f" template.\nGot: {vars}"
f"\nFrom: {tmpl}"
)
input_variables = [vars[0]]
else:
input_variables = None
img_template = {"url": img_template}
img_template_obj = ImagePromptTemplate(
input_variables=input_variables, template=img_template
)
elif isinstance(img_template, dict):
img_template = dict(img_template)
if "url" in img_template:
input_variables = get_template_variables(
img_template["url"], "f-string"
)
else:
input_variables = None
img_template_obj = ImagePromptTemplate(
input_variables=input_variables, template=img_template
)
else:
raise ValueError()
prompt.append(img_template_obj)
else:
raise ValueError()
return cls(prompt=prompt, **kwargs)
else:
raise ValueError()
@classmethod
def from_template_file(
cls: Type[_StringImageMessagePromptTemplateT],
template_file: Union[str, Path],
input_variables: List[str],
**kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
"""Create a class from a template file.
Args:
template_file: path to a template file. String or Path.
input_variables: list of input variables.
**kwargs: keyword arguments to pass to the constructor.
Returns:
A new instance of this class.
"""
with open(str(template_file), "r") as f:
template = f.read()
return cls.from_template(template, input_variables=input_variables, **kwargs)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
"""Format messages from kwargs.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
List of BaseMessages.
"""
return [self.format(**kwargs)]
@property
def input_variables(self) -> List[str]:
"""
Input variables for this prompt template.
Returns:
List of input variable names.
"""
prompts = self.prompt if isinstance(self.prompt, list) else [self.prompt]
input_variables = [iv for prompt in prompts for iv in prompt.input_variables]
return input_variables
def format(self, **kwargs: Any) -> BaseMessage: def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template. """Format the prompt template.
@ -305,53 +447,55 @@ class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
Returns: Returns:
Formatted message. Formatted message.
""" """
if isinstance(self.prompt, StringPromptTemplate):
text = self.prompt.format(**kwargs) text = self.prompt.format(**kwargs)
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs) return self._msg_class(
content=text, additional_kwargs=self.additional_kwargs
)
else:
content = []
for prompt in self.prompt:
inputs = {var: kwargs[var] for var in prompt.input_variables}
if isinstance(prompt, StringPromptTemplate):
formatted: Union[str, ImageURL] = prompt.format(**inputs)
content.append({"type": "text", "text": formatted})
elif isinstance(prompt, ImagePromptTemplate):
formatted = prompt.format(**inputs)
content.append({"type": "image_url", "image_url": formatted})
return self._msg_class(
content=content, additional_kwargs=self.additional_kwargs
)
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate): class HumanMessagePromptTemplate(_StringImageMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user."""
_msg_class: Type[BaseMessage] = HumanMessage
class AIMessagePromptTemplate(_StringImageMessagePromptTemplate):
"""AI message prompt template. This is a message sent from the AI.""" """AI message prompt template. This is a message sent from the AI."""
_msg_class: Type[BaseMessage] = AIMessage
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"] return ["langchain", "prompts", "chat"]
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args: class SystemMessagePromptTemplate(_StringImageMessagePromptTemplate):
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
"""System message prompt template. """System message prompt template.
This is a message that is not sent to the user. This is a message that is not sent to the user.
""" """
_msg_class: Type[BaseMessage] = SystemMessage
@classmethod @classmethod
def get_lc_namespace(cls) -> List[str]: def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "prompts", "chat"] return ["langchain", "prompts", "chat"]
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
class BaseChatPromptTemplate(BasePromptTemplate, ABC): class BaseChatPromptTemplate(BasePromptTemplate, ABC):
"""Base class for chat prompt templates.""" """Base class for chat prompt templates."""
@ -405,8 +549,7 @@ MessageLike = Union[BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTempla
MessageLikeRepresentation = Union[ MessageLikeRepresentation = Union[
MessageLike, MessageLike,
Tuple[str, str], Tuple[Union[str, Type], Union[str, List[dict], List[object]]],
Tuple[Type, str],
str, str,
] ]
@ -738,7 +881,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
def _create_template_from_message_type( def _create_template_from_message_type(
message_type: str, template: str message_type: str, template: Union[str, list]
) -> BaseMessagePromptTemplate: ) -> BaseMessagePromptTemplate:
"""Create a message prompt template from a message type and template string. """Create a message prompt template from a message type and template string.
@ -754,9 +897,9 @@ def _create_template_from_message_type(
template template
) )
elif message_type in ("ai", "assistant"): elif message_type in ("ai", "assistant"):
message = AIMessagePromptTemplate.from_template(template) message = AIMessagePromptTemplate.from_template(cast(str, template))
elif message_type == "system": elif message_type == "system":
message = SystemMessagePromptTemplate.from_template(template) message = SystemMessagePromptTemplate.from_template(cast(str, template))
else: else:
raise ValueError( raise ValueError(
f"Unexpected message type: {message_type}. Use one of 'human'," f"Unexpected message type: {message_type}. Use one of 'human',"
@ -799,7 +942,9 @@ def _convert_to_message(
if isinstance(message_type_str, str): if isinstance(message_type_str, str):
_message = _create_template_from_message_type(message_type_str, template) _message = _create_template_from_message_type(message_type_str, template)
else: else:
_message = message_type_str(prompt=PromptTemplate.from_template(template)) _message = message_type_str(
prompt=PromptTemplate.from_template(cast(str, template))
)
else: else:
raise NotImplementedError(f"Unsupported message type: {type(message)}") raise NotImplementedError(f"Unsupported message type: {type(message)}")

View File

@ -0,0 +1,76 @@
from typing import Any
from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import Field
from langchain_core.utils import image as image_utils
class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
"""An image prompt template for a multimodal model."""
template: dict = Field(default_factory=dict)
"""Template for the prompt."""
def __init__(self, **kwargs: Any) -> None:
if "input_variables" not in kwargs:
kwargs["input_variables"] = []
overlap = set(kwargs["input_variables"]) & set(("url", "path", "detail"))
if overlap:
raise ValueError(
"input_variables for the image template cannot contain"
" any of 'url', 'path', or 'detail'."
f" Found: {overlap}"
)
super().__init__(**kwargs)
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
return "image-prompt"
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
return ImagePromptValue(image_url=self.format(**kwargs))
def format(
self,
**kwargs: Any,
) -> ImageURL:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
formatted = {}
for k, v in self.template.items():
if isinstance(v, str):
formatted[k] = v.format(**kwargs)
else:
formatted[k] = v
url = kwargs.get("url") or formatted.get("url")
path = kwargs.get("path") or formatted.get("path")
detail = kwargs.get("detail") or formatted.get("detail")
if not url and not path:
raise ValueError("Must provide either url or path.")
if not url:
if not isinstance(path, str):
raise ValueError("path must be a string.")
url = image_utils.image_to_data_url(path)
if not isinstance(url, str):
raise ValueError("url must be a string.")
output: ImageURL = {"url": url}
if detail:
# Don't check literal values here: let the API check them
output["detail"] = detail # type: ignore[typeddict-item]
return output

View File

@ -4,6 +4,7 @@
These functions do not depend on any other LangChain module. These functions do not depend on any other LangChain module.
""" """
from langchain_core.utils import image
from langchain_core.utils.env import get_from_dict_or_env, get_from_env from langchain_core.utils.env import get_from_dict_or_env, get_from_env
from langchain_core.utils.formatting import StrictFormatter, formatter from langchain_core.utils.formatting import StrictFormatter, formatter
from langchain_core.utils.input import ( from langchain_core.utils.input import (
@ -41,6 +42,7 @@ __all__ = [
"xor_args", "xor_args",
"try_load_from_hub", "try_load_from_hub",
"build_extra_kwargs", "build_extra_kwargs",
"image",
"get_from_env", "get_from_env",
"get_from_dict_or_env", "get_from_dict_or_env",
"stringify_dict", "stringify_dict",

View File

@ -0,0 +1,14 @@
import base64
import mimetypes
def encode_image(image_path: str) -> str:
"""Get base64 string from image URI."""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def image_to_data_url(image_path: str) -> str:
encoding = encode_image(image_path)
mime_type = mimetypes.guess_type(image_path)[0]
return f"data:{mime_type};base64,{encoding}"

View File

@ -3,6 +3,9 @@ from typing import Any, List, Union
import pytest import pytest
from langchain_core._api.deprecation import (
LangChainPendingDeprecationWarning,
)
from langchain_core.messages import ( from langchain_core.messages import (
AIMessage, AIMessage,
BaseMessage, BaseMessage,
@ -243,6 +246,7 @@ def test_chat_valid_infer_variables() -> None:
def test_chat_from_role_strings() -> None: def test_chat_from_role_strings() -> None:
"""Test instantiation of chat template from role strings.""" """Test instantiation of chat template from role strings."""
with pytest.warns(LangChainPendingDeprecationWarning):
template = ChatPromptTemplate.from_role_strings( template = ChatPromptTemplate.from_role_strings(
[ [
("system", "You are a bot."), ("system", "You are a bot."),
@ -363,6 +367,136 @@ def test_chat_message_partial() -> None:
assert template2.format(input="hello") == get_buffer_string(expected) assert template2.format(input="hello") == get_buffer_string(expected)
def test_chat_tmpl_from_messages_multipart_text() -> None:
template = ChatPromptTemplate.from_messages(
[
("system", "You are an AI assistant named {name}."),
(
"human",
[
{"type": "text", "text": "What's in this image?"},
{"type": "text", "text": "Oh nvm"},
],
),
]
)
messages = template.format_messages(name="R2D2")
expected = [
SystemMessage(content="You are an AI assistant named R2D2."),
HumanMessage(
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "text", "text": "Oh nvm"},
]
),
]
assert messages == expected
def test_chat_tmpl_from_messages_multipart_text_with_template() -> None:
template = ChatPromptTemplate.from_messages(
[
("system", "You are an AI assistant named {name}."),
(
"human",
[
{"type": "text", "text": "What's in this {object_name}?"},
{"type": "text", "text": "Oh nvm"},
],
),
]
)
messages = template.format_messages(name="R2D2", object_name="image")
expected = [
SystemMessage(content="You are an AI assistant named R2D2."),
HumanMessage(
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "text", "text": "Oh nvm"},
]
),
]
assert messages == expected
def test_chat_tmpl_from_messages_multipart_image() -> None:
base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
other_base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
template = ChatPromptTemplate.from_messages(
[
("system", "You are an AI assistant named {name}."),
(
"human",
[
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": "data:image/jpeg;base64,{my_image}",
},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,{my_image}"},
},
{"type": "image_url", "image_url": "{my_other_image}"},
{
"type": "image_url",
"image_url": {"url": "{my_other_image}", "detail": "medium"},
},
{
"type": "image_url",
"image_url": {"url": "https://www.langchain.com/image.png"},
},
{
"type": "image_url",
"image_url": {"url": ""},
},
],
),
]
)
messages = template.format_messages(
name="R2D2", my_image=base64_image, my_other_image=other_base64_image
)
expected = [
SystemMessage(content="You are an AI assistant named R2D2."),
HumanMessage(
content=[
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{other_base64_image}"
},
},
{
"type": "image_url",
"image_url": {"url": f"{other_base64_image}"},
},
{
"type": "image_url",
"image_url": {
"url": f"{other_base64_image}",
"detail": "medium",
},
},
{
"type": "image_url",
"image_url": {"url": "https://www.langchain.com/image.png"},
},
{
"type": "image_url",
"image_url": {"url": ""},
},
]
),
]
assert messages == expected
def test_messages_placeholder() -> None: def test_messages_placeholder() -> None:
prompt = MessagesPlaceholder("history") prompt = MessagesPlaceholder("history")
with pytest.raises(KeyError): with pytest.raises(KeyError):

View File

@ -16,6 +16,7 @@ EXPECTED_ALL = [
"xor_args", "xor_args",
"try_load_from_hub", "try_load_from_hub",
"build_extra_kwargs", "build_extra_kwargs",
"image",
"get_from_dict_or_env", "get_from_dict_or_env",
"get_from_env", "get_from_env",
"stringify_dict", "stringify_dict",