Compare commits

...

15 Commits

Author SHA1 Message Date
Erick Friis
a11f6118a5 x 2024-08-20 10:28:32 -07:00
Erick Friis
b11af4551a x 2024-08-20 10:20:42 -07:00
Erick Friis
c867b551eb Merge branch 'master' into bagatur/content_block_template 2024-08-20 10:19:06 -07:00
Bagatur
b9b99ea182 Merge branch 'master' into bagatur/content_block_template 2024-08-19 16:06:37 -07:00
Bagatur
27ce9050f4 fmt 2024-08-19 16:06:14 -07:00
Bagatur
b5ddd2ea23 fmt 2024-08-19 16:02:17 -07:00
Bagatur
1af63cc5b3 Merge branch 'master' into bagatur/content_block_template 2024-08-19 15:27:50 -07:00
Bagatur
222ef967fa fmt 2024-08-19 15:08:41 -07:00
Bagatur
cb168ef981 fmt 2024-08-19 15:02:52 -07:00
Bagatur
644f338a10 fmt 2024-08-19 14:10:48 -07:00
Bagatur
d07dde3463 fmt 2024-08-19 13:43:21 -07:00
Bagatur
6758237b12 fmt 2024-08-19 13:42:41 -07:00
Bagatur
6fe0b99a81 fmt 2024-08-19 12:56:21 -07:00
Bagatur
1b86d9998d Merge branch 'master' into bagatur/content_block_template 2024-08-19 12:03:15 -07:00
Bagatur
a83845bb1f rfc: content block prompt template 2024-08-16 17:37:44 -07:00
8 changed files with 392 additions and 65 deletions

View File

@@ -102,6 +102,13 @@ def test_serializable_mapping() -> None:
"modifier",
"RemoveMessage",
),
# This is not exported from langchain, only langchain_core
("langchain_core", "prompts", "content_block", "ContentBlockPromptTemplate"): (
"langchain_core",
"prompts",
"content_block",
"ContentBlockPromptTemplate",
),
}
serializable_modules = import_all_modules("langchain")

View File

@@ -147,6 +147,12 @@ SERIALIZABLE_MAPPING: Dict[Tuple[str, ...], Tuple[str, ...]] = {
"image",
"ImagePromptTemplate",
),
("langchain_core", "prompts", "content_block", "ContentBlockPromptTemplate"): (
"langchain_core",
"prompts",
"content_block",
"ContentBlockPromptTemplate",
),
("langchain", "schema", "agent", "AgentActionMessageLog"): (
"langchain_core",
"agents",

View File

@@ -29,6 +29,7 @@ from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import ensure_config
from langchain_core.runnables.utils import create_model
from langchain_core.utils.interactive_env import is_interactive_env
if TYPE_CHECKING:
from langchain_core.documents import Document
@@ -352,6 +353,21 @@ class BasePromptTemplate(
else:
raise ValueError(f"{save_path} must be json or yaml")
def pretty_repr(self, html: bool = False) -> str:
"""Get a pretty representation of the prompt.
Args:
html: Whether to return an HTML-formatted string.
Returns:
A pretty representation of the prompt.
"""
raise NotImplementedError()
def pretty_print(self) -> None:
"""Print a pretty representation of the prompt."""
print(self.pretty_repr(html=is_interactive_env())) # noqa: T201
def _get_document_info(doc: Document, prompt: BasePromptTemplate[str]) -> Dict:
base_info = {"page_content": doc.page_content, **doc.metadata}

View File

@@ -33,11 +33,12 @@ from langchain_core.messages import (
convert_to_messages,
)
from langchain_core.messages.base import get_msg_title_repr
from langchain_core.prompt_values import ChatPromptValue, ImageURL, PromptValue
from langchain_core.prompt_values import ChatPromptValue, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.content_block import ContentBlockPromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate, get_template_variables
from langchain_core.prompts.string import StringPromptTemplate
from langchain_core.pydantic_v1 import Field, PositiveInt, root_validator
from langchain_core.utils import get_colored_text
from langchain_core.utils.interactive_env import is_interactive_env
@@ -449,7 +450,7 @@ _StringImageMessagePromptTemplateT = TypeVar(
class _TextTemplateParam(TypedDict, total=False):
text: Union[str, Dict]
text: str
class _ImageTemplateParam(TypedDict, total=False):
@@ -459,12 +460,13 @@ class _ImageTemplateParam(TypedDict, total=False):
class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user."""
prompt: Union[
StringPromptTemplate, List[Union[StringPromptTemplate, ImagePromptTemplate]]
]
prompt: Union[StringPromptTemplate, List[BasePromptTemplate]]
"""Prompt template."""
additional_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the prompt template."""
name: Optional[str] = None
"""An optional name for the participant. Provides the model information to
differentiate between participants of the same role."""
_msg_class: Type[BaseMessage]
@@ -503,61 +505,59 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
template_format=template_format,
partial_variables=partial_variables,
)
return cls(prompt=prompt, **kwargs)
elif isinstance(template, list):
if (partial_variables is not None) and len(partial_variables) > 0:
raise ValueError(
"Partial variables are not supported for list of templates."
)
prompt = []
for tmpl in template:
if isinstance(tmpl, str) or isinstance(tmpl, dict) and "text" in tmpl:
if isinstance(tmpl, str):
text: str = tmpl
else:
text = cast(_TextTemplateParam, tmpl)["text"] # type: ignore[assignment]
for i, tmpl in enumerate(template):
if isinstance(tmpl, str):
prompt.append(
PromptTemplate.from_template(
text, template_format=template_format
tmpl, template_format=template_format
)
)
elif isinstance(tmpl, dict) and "image_url" in tmpl:
img_template = cast(_ImageTemplateParam, tmpl)["image_url"]
input_variables = []
if isinstance(img_template, str):
vars = get_template_variables(img_template, "f-string")
if vars:
if len(vars) > 1:
raise ValueError(
"Only one format variable allowed per image"
f" template.\nGot: {vars}"
f"\nFrom: {tmpl}"
)
input_variables = [vars[0]]
img_template = {"url": img_template}
img_template_obj = ImagePromptTemplate(
input_variables=input_variables, template=img_template
# For backwards compatible ser/des.
# TODO: Refactor in 1.0 so this just uses a ContentBlockPromptTemplate.
elif isinstance(tmpl, dict) and set(tmpl.keys()) <= {
"type",
"text",
}:
prompt.append(
PromptTemplate.from_template(
cast(_TextTemplateParam, tmpl)["text"],
template_format=template_format,
)
elif isinstance(img_template, dict):
img_template = dict(img_template)
for key in ["url", "path", "detail"]:
if key in img_template:
input_variables.extend(
get_template_variables(
img_template[key], "f-string"
)
)
img_template_obj = ImagePromptTemplate(
input_variables=input_variables, template=img_template
)
# For backwards compatible ser/des.
# TODO: Refactor in 1.0 so this just use a ContentBlockPromptTemplate.
elif isinstance(tmpl, dict) and set(tmpl.keys()) <= {
"type",
"image_url",
}:
prompt.append(
ImagePromptTemplate(
template=cast(dict, tmpl), template_format=template_format
)
else:
raise ValueError()
prompt.append(img_template_obj)
)
elif isinstance(tmpl, dict):
prompt.append(
ContentBlockPromptTemplate(
cast(dict, tmpl), template_format=template_format
)
)
else:
raise ValueError()
return cls(prompt=prompt, **kwargs)
raise ValueError(
f"Unsupported template type {type(tmpl)} at index {i}. "
f"Expected a list of strings or dicts."
)
else:
raise ValueError()
raise ValueError(
f"Unsupported template type {type(template)}. Expected either a string "
f"or a list (of strings or dicts)."
)
return cls(prompt=prompt, **kwargs)
@classmethod
def from_template_file(
@@ -626,20 +626,31 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
if isinstance(self.prompt, StringPromptTemplate):
text = self.prompt.format(**kwargs)
return self._msg_class(
content=text, additional_kwargs=self.additional_kwargs
content=text, additional_kwargs=self.additional_kwargs, name=self.name
)
else:
content: List = []
for prompt in self.prompt:
inputs = {var: kwargs[var] for var in prompt.input_variables}
if isinstance(prompt, StringPromptTemplate):
formatted: Union[str, ImageURL] = prompt.format(**inputs)
formatted: Union[str, dict] = prompt.format(**inputs)
content.append({"type": "text", "text": formatted})
elif isinstance(prompt, ImagePromptTemplate):
formatted = prompt.format(**inputs)
formatted = cast(dict, prompt.format(**inputs))
content.append({"type": "image_url", "image_url": formatted})
elif isinstance(prompt, ContentBlockPromptTemplate):
formatted = prompt.format(**inputs)
content.append(formatted)
else:
raise ValueError(
f"Unknown prompt type: {type(prompt)}. Expected "
f"StringPromptTemplate, ImagePromptTemplate, or "
f"ContentBlockPromptTemplate."
)
return self._msg_class(
content=content, additional_kwargs=self.additional_kwargs
content=content,
additional_kwargs=self.additional_kwargs,
name=self.name,
)
async def aformat(self, **kwargs: Any) -> BaseMessage:
@@ -654,20 +665,31 @@ class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
if isinstance(self.prompt, StringPromptTemplate):
text = await self.prompt.aformat(**kwargs)
return self._msg_class(
content=text, additional_kwargs=self.additional_kwargs
content=text, additional_kwargs=self.additional_kwargs, name=self.name
)
else:
content: List = []
for prompt in self.prompt:
inputs = {var: kwargs[var] for var in prompt.input_variables}
if isinstance(prompt, StringPromptTemplate):
formatted: Union[str, ImageURL] = await prompt.aformat(**inputs)
formatted: Union[str, dict] = await prompt.aformat(**inputs)
content.append({"type": "text", "text": formatted})
elif isinstance(prompt, ImagePromptTemplate):
formatted = await prompt.aformat(**inputs)
formatted = cast(dict, await prompt.aformat(**inputs))
content.append({"type": "image_url", "image_url": formatted})
elif isinstance(prompt, ContentBlockPromptTemplate):
formatted = await prompt.aformat(**inputs)
content.append(formatted)
else:
raise ValueError(
f"Unknown prompt type: {type(prompt)}. Expected "
f"StringPromptTemplate, ImagePromptTemplate, or "
f"ContentBlockPromptTemplate."
)
return self._msg_class(
content=content, additional_kwargs=self.additional_kwargs
content=content,
additional_kwargs=self.additional_kwargs,
name=self.name,
)
def pretty_repr(self, html: bool = False) -> str:
@@ -697,6 +719,7 @@ class AIMessagePromptTemplate(_StringImageMessagePromptTemplate):
"""AI message prompt template. This is a message sent from the AI."""
_msg_class: Type[BaseMessage] = AIMessage
# TODO: Add support for tool_calls?
@classmethod
def get_lc_namespace(cls) -> List[str]:
@@ -1091,7 +1114,9 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
return values
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> ChatPromptTemplate:
def from_template(
cls, template: str, *, name: Optional[str] = None, **kwargs: Any
) -> ChatPromptTemplate:
"""Create a chat prompt template from a template string.
Creates a chat template consisting of a single message assumed to be from
@@ -1105,7 +1130,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
A new instance of this class.
"""
prompt_template = PromptTemplate.from_template(template, **kwargs)
message = HumanMessagePromptTemplate(prompt=prompt_template)
message = HumanMessagePromptTemplate(prompt=prompt_template, name=name)
return cls.from_messages([message])
@classmethod

View File

@@ -0,0 +1,91 @@
from typing import Any, Dict, List
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.image import ImagePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import Field
class ContentBlockPromptTemplate(BasePromptTemplate[Dict[str, Any]]):
"""Template for a single content block."""
template: Dict[str, Any] = Field(default_factory=dict)
"""Template for the content block. Expected to be a dictionary of """
def __init__(
self, template: dict, *, template_format: str = "f-string", **kwargs: Any
) -> None:
input_variables = kwargs.pop("input_variables", [])
if "image_url" in template:
if not isinstance(template["image_url"], BasePromptTemplate):
# For backwards compatibility.
if "type" not in template:
template["type"] = "image_url"
template["image_url"] = ImagePromptTemplate(
template["image_url"], template_format=template_format
)
input_variables += template["image_url"].input_variables
if "text" in template:
if not isinstance(template["text"], PromptTemplate):
template["text"] = PromptTemplate.from_template(
template["text"], template_format=template_format
)
input_variables += template["text"].input_variables
super().__init__(
template=template, input_variables=list(set(input_variables)), **kwargs
) # type: ignore[call-arg]
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
return "content-block-prompt"
@classmethod
def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain_core", "prompts", "content_block"]
def format(self, **kwargs: Any) -> Dict[str, Any]:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted content block as a dict.
"""
formatted = {}
for k, v in self.template.items():
if isinstance(v, BasePromptTemplate):
formatted[k] = v.format(**kwargs)
else:
formatted[k] = v
return formatted
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
"""
raise NotImplementedError(
f"{self.__class__} does not support being directly formatted to a "
f"PromptValue. Can only be formatted as part of a ChatPromptTemplate."
)
def pretty_repr(self, html: bool = False) -> str:
"""Return a pretty representation of the prompt.
Args:
html: Whether to return an html formatted string.
Returns:
A pretty representation of the prompt.
"""
raise NotImplementedError("Not implemented yet.")

View File

@@ -1,7 +1,8 @@
from typing import Any, List
from typing import Any, List, Union
from langchain_core.prompt_values import ImagePromptValue, ImageURL, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.string import get_template_variables
from langchain_core.pydantic_v1 import Field
from langchain_core.runnables import run_in_executor
from langchain_core.utils import image as image_utils
@@ -13,18 +14,46 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
template: dict = Field(default_factory=dict)
"""Template for the prompt."""
def __init__(self, **kwargs: Any) -> None:
if "input_variables" not in kwargs:
kwargs["input_variables"] = []
overlap = set(kwargs["input_variables"]) & set(("url", "path", "detail"))
def __init__(
self,
template: Union[dict, str],
*,
template_format: str = "f-string",
**kwargs: Any,
) -> None:
if isinstance(template, dict) and "image_url" in template:
image_url = template["image_url"]
else:
image_url = template
input_variables = kwargs.pop("input_variables", [])
if isinstance(image_url, str):
vars = get_template_variables(image_url, template_format)
if vars:
if len(vars) > 1:
raise ValueError(
"Only one format variable allowed per image"
f" template.\nGot: {vars}"
f"\nFrom: {template}"
)
input_variables = [vars[0]]
image_url = {"url": image_url}
elif isinstance(image_url, dict):
image_url = dict(image_url)
for key in ["url", "path", "detail"]:
if key in image_url:
input_variables.extend(
get_template_variables(image_url[key], template_format)
)
overlap = set(input_variables) & set(("url", "path", "detail"))
if overlap:
raise ValueError(
"input_variables for the image template cannot contain"
" any of 'url', 'path', or 'detail'."
f" Found: {overlap}"
)
super().__init__(**kwargs)
super().__init__(
template=image_url, input_variables=list(set(input_variables)), **kwargs
) # type: ignore[call-arg]
@property
def _prompt_type(self) -> str:

View File

@@ -2269,6 +2269,118 @@
'name': 'ImagePromptTemplate',
'type': 'constructor',
}),
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain_core',
'prompts',
'content_block',
'ContentBlockPromptTemplate',
]),
'name': 'ContentBlockPromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'ContentBlockPromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain_core',
'prompts',
'content_block',
'ContentBlockPromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
]),
'template': dict({
'cache_control': dict({
'type': 'ephemeral',
}),
'text': dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'name': 'PromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'PromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
]),
'template': 'foobar',
'template_format': 'f-string',
}),
'lc': 1,
'name': 'PromptTemplate',
'type': 'constructor',
}),
'type': 'text',
}),
}),
'lc': 1,
'name': 'ContentBlockPromptTemplate',
'type': 'constructor',
}),
]),
}),
'lc': 1,

View File

@@ -594,6 +594,7 @@ async def test_chat_tmpl_from_messages_multipart_image() -> None:
"type": "image_url",
"image_url": {"url": ""},
},
{"image_url": "data:image/jpeg;base64,{my_other_image}"},
],
),
]
@@ -630,6 +631,12 @@ async def test_chat_tmpl_from_messages_multipart_image() -> None:
"type": "image_url",
"image_url": {"url": ""},
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{other_base64_image}"
},
},
]
),
]
@@ -814,7 +821,7 @@ def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) ->
assert load(dumpd(prompt)) == prompt
async def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
"""Test chat prompt template ser/des."""
template = ChatPromptTemplate(
[
@@ -854,6 +861,11 @@ async def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
"image_url": {"url": ""},
},
{"image_url": {"url": ""}},
{
"type": "text",
"text": "foobar",
"cache_control": {"type": "ephemeral"},
},
],
),
),
@@ -863,3 +875,32 @@ async def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
)
assert dumpd(template) == snapshot()
assert load(dumpd(template)) == template
def test_chat_content_block() -> None:
template = ChatPromptTemplate(
[
(
"human",
[
{
"type": "text",
"text": "foobar",
"cache_control": {"type": "ephemeral"},
}
],
),
(
"ai",
[{"text": "how are {name}", "cache_control": {"type": "ephemeral"}}],
),
]
)
expected = [
HumanMessage(
[{"type": "text", "text": "foobar", "cache_control": {"type": "ephemeral"}}]
),
AIMessage([{"text": "how are you", "cache_control": {"type": "ephemeral"}}]),
]
actual = template.invoke({"name": "you"}).to_messages()
assert actual == expected