Compare commits

...

7 Commits

Author SHA1 Message Date
Bagatur
d1027b1d7f cr 2023-11-10 17:58:43 -08:00
Bagatur
2afa070d34 Merge branch 'master' into bagautr/rfc_image_template 2023-11-10 17:16:10 -08:00
Bagatur
186fb7adaf fmt 2023-11-10 16:00:25 -08:00
Bagatur
07866da9b1 py3.8 compat 2023-11-10 15:44:39 -08:00
Bagatur
f29df0d432 generalize 2023-11-10 15:41:30 -08:00
Bagatur
017fcf9a50 fix 2023-11-10 14:27:53 -08:00
Bagatur
98c103112b rfc 2023-11-10 14:25:43 -08:00
9 changed files with 308 additions and 50 deletions

View File

@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Literal, Set
from langchain.schema.messages import BaseMessage, HumanMessage
from langchain.schema.prompt import PromptValue
from langchain.schema.prompt_template import BasePromptTemplate
from langchain.types.image import ImageURL
from langchain.utils.formatting import formatter
@@ -165,9 +166,25 @@ class StringPromptValue(PromptValue):
return [HumanMessage(content=self.text)]
class StringPromptTemplate(BasePromptTemplate, ABC):
class StringPromptTemplate(BasePromptTemplate[str], ABC):
"""String prompt that exposes the format method, returning a prompt."""
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
return StringPromptValue(text=self.format(**kwargs))
class ImagePromptValue(PromptValue):
    """Prompt value that wraps a single image reference."""

    image_url: ImageURL
    """Image payload; ``to_string`` reads its ``"url"`` entry."""
    type: Literal["ImagePromptValue"] = "ImagePromptValue"

    def to_messages(self) -> List[BaseMessage]:
        """Return the prompt as a one-element list of messages.

        The single ``HumanMessage`` carries a content list whose only item
        is ``image_url``.
        """
        human_message = HumanMessage(content=[self.image_url])
        return [human_message]

    def to_string(self) -> str:
        """Return the ``"url"`` entry of ``image_url``."""
        image = self.image_url
        return image["url"]

View File

@@ -15,12 +15,16 @@ from typing import (
Type,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import TypedDict
from langchain._api import deprecated
from langchain.load.serializable import Serializable
from langchain.prompts.base import StringPromptTemplate
from langchain.prompts.base import StringPromptTemplate, get_template_variables
from langchain.prompts.image import ImagePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import (
@@ -36,6 +40,7 @@ from langchain.schema.messages import (
SystemMessage,
get_buffer_string,
)
from langchain.types.image import ImageURL
class BaseMessagePromptTemplate(Serializable, ABC):
@@ -118,8 +123,8 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
return [self.variable_name]
MessagePromptTemplateT = TypeVar(
"MessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
_StringMessagePromptTemplateT = TypeVar(
"_StringMessagePromptTemplateT", bound="BaseStringMessagePromptTemplate"
)
"""Type variable for message prompt templates."""
@@ -134,11 +139,11 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
@classmethod
def from_template(
cls: Type[MessagePromptTemplateT],
cls: Type[_StringMessagePromptTemplateT],
template: str,
template_format: str = "f-string",
**kwargs: Any,
) -> MessagePromptTemplateT:
) -> _StringMessagePromptTemplateT:
"""Create a class from a string template.
Args:
@@ -154,11 +159,11 @@ class BaseStringMessagePromptTemplate(BaseMessagePromptTemplate, ABC):
@classmethod
def from_template_file(
cls: Type[MessagePromptTemplateT],
cls: Type[_StringMessagePromptTemplateT],
template_file: Union[str, Path],
input_variables: List[str],
**kwargs: Any,
) -> MessagePromptTemplateT:
) -> _StringMessagePromptTemplateT:
"""Create a class from a template file.
Args:
@@ -226,9 +231,140 @@ class ChatMessagePromptTemplate(BaseStringMessagePromptTemplate):
)
class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
_StringImageMessagePromptTemplateT = TypeVar(
"_StringImageMessagePromptTemplateT", bound="_StringImageMessagePromptTemplate"
)
class _TextTemplateParam(TypedDict, total=False):
    """Dict form of a text part: a mapping with a ``"text"`` entry.

    Declared with ``total=False``, so the key is optional for type checkers.
    """

    text: Union[str, Dict]
class _ImageTemplateParam(TypedDict, total=False):
    """Dict form of an image part: a mapping with an ``"image_url"`` entry.

    Declared with ``total=False``, so the key is optional for type checkers.
    """

    image_url: Union[str, Dict]
class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user."""
prompt: Union[
StringPromptTemplate, List[Union[StringPromptTemplate, ImagePromptTemplate]]
]
"""Prompt template."""
additional_kwargs: dict = Field(default_factory=dict)
"""Additional keyword arguments to pass to the prompt template."""
_msg_class: Type[BaseMessage]
@classmethod
def from_template(
    cls: Type[_StringImageMessagePromptTemplateT],
    template: Union[str, List[Union[str, _TextTemplateParam, _ImageTemplateParam]]],
    template_format: str = "f-string",
    **kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
    """Create an instance from a template string or a list of template parts.

    Args:
        template: either a single template string, or a list whose items are
            template strings, dicts with a ``"text"`` key, or dicts with an
            ``"image_url"`` key (whose value is a string or a dict). A dict
            containing both keys is treated as a text part.
        template_format: format of the text templates.
        **kwargs: keyword arguments to pass to the constructor.

    Returns:
        A new instance of this class.

    Raises:
        ValueError: if ``template`` is neither a string nor a list, if a list
            item is not one of the supported parts, if an ``"image_url"``
            value is neither a string nor a dict, or if an image URL string
            contains more than one template variable.
    """
    if isinstance(template, str):
        return cls(
            prompt=PromptTemplate.from_template(
                template, template_format=template_format
            ),
            **kwargs,
        )
    if not isinstance(template, list):
        raise ValueError(
            "template must be a string or a list of template parts, got "
            f"{type(template).__name__}."
        )

    parts: List[Union[StringPromptTemplate, ImagePromptTemplate]] = []
    for tmpl in template:
        if isinstance(tmpl, str):
            parts.append(
                PromptTemplate.from_template(tmpl, template_format=template_format)
            )
        elif isinstance(tmpl, dict) and "text" in tmpl:
            text = cast(_TextTemplateParam, tmpl)["text"]
            parts.append(
                PromptTemplate.from_template(
                    text,  # type: ignore[arg-type]
                    template_format=template_format,
                )
            )
        elif isinstance(tmpl, dict) and "image_url" in tmpl:
            img_template = cast(_ImageTemplateParam, tmpl)["image_url"]
            if isinstance(img_template, str):
                # NOTE(review): variables in the image URL are always parsed as
                # "f-string", regardless of ``template_format`` -- confirm this
                # is intended.
                url_vars = get_template_variables(img_template, "f-string")
                if url_vars:
                    if len(url_vars) > 1:
                        raise ValueError(
                            "An image_url template string may reference at "
                            f"most one variable, got {url_vars} in "
                            f"{img_template!r}."
                        )
                    # The image is supplied entirely by this input variable
                    # at format time, so no static fields are stored.
                    variable_name = url_vars[0]
                    img_fields: dict = {}
                else:
                    variable_name = None
                    img_fields = {"url": img_template}
            elif isinstance(img_template, dict):
                # Copy before popping so the caller's dict is left untouched.
                img_fields = dict(img_template)
                variable_name = img_fields.pop("variable_name", None)
            else:
                raise ValueError(
                    "The 'image_url' value must be a string or a dict, got "
                    f"{type(img_template).__name__}."
                )
            parts.append(
                ImagePromptTemplate(variable_name=variable_name, template=img_fields)
            )
        else:
            raise ValueError(
                "Each template part must be a string, a dict with a 'text' key, "
                f"or a dict with an 'image_url' key; got {tmpl!r}."
            )
    return cls(prompt=parts, **kwargs)
@classmethod
def from_template_file(
    cls: Type[_StringImageMessagePromptTemplateT],
    template_file: Union[str, Path],
    input_variables: List[str],
    **kwargs: Any,
) -> _StringImageMessagePromptTemplateT:
    """Build an instance from the text stored in a template file.

    The whole file is read as text and handed to ``from_template``.

    Args:
        template_file: location of the template file, as a string or Path.
        input_variables: input variable names, forwarded to ``from_template``
            as the ``input_variables`` keyword argument.
        **kwargs: extra keyword arguments forwarded to ``from_template``.

    Returns:
        A new instance of this class.
    """
    file_name = str(template_file)
    with open(file_name, "r") as handle:
        contents = handle.read()
    return cls.from_template(contents, input_variables=input_variables, **kwargs)
def format_messages(self, **kwargs: Any) -> List[BaseMessage]:
    """Format this template and return the result as a one-element list.

    Args:
        **kwargs: Keyword arguments forwarded unchanged to ``format``.

    Returns:
        A list containing the single message produced by ``format``.
    """
    message = self.format(**kwargs)
    return [message]
@property
def input_variables(self) -> List[str]:
    """
    Input variables for this prompt template.

    ``prompt`` may be a single template or a list of templates; the names of
    all of them are concatenated in order, without de-duplication.

    Returns:
        List of input variable names.
    """
    if isinstance(self.prompt, list):
        prompts = self.prompt
    else:
        prompts = [self.prompt]
    names: List[str] = []
    for prompt in prompts:
        names.extend(prompt.input_variables)
    return names
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
@@ -238,42 +374,44 @@ class HumanMessagePromptTemplate(BaseStringMessagePromptTemplate):
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return HumanMessage(content=text, additional_kwargs=self.additional_kwargs)
if isinstance(self.prompt, StringPromptTemplate):
text = self.prompt.format(**kwargs)
return self._msg_class(
content=text, additional_kwargs=self.additional_kwargs
)
else:
content = []
for prompt in self.prompt:
inputs = {var: kwargs[var] for var in prompt.input_variables}
if isinstance(prompt, StringPromptTemplate):
formatted: Union[str, ImageURL] = prompt.format(**inputs)
content.append({"type": "text", "text": formatted})
elif isinstance(prompt, ImagePromptTemplate):
formatted = prompt.format(**inputs)
content.append({"type": "image_url", "image_url": formatted})
return self._msg_class(
content=content, additional_kwargs=self.additional_kwargs
)
class AIMessagePromptTemplate(BaseStringMessagePromptTemplate):
class HumanMessagePromptTemplate(_StringImageMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user."""
_msg_class: Type[BaseMessage] = HumanMessage
class AIMessagePromptTemplate(_StringImageMessagePromptTemplate):
"""AI message prompt template. This is a message sent from the AI."""
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return AIMessage(content=text, additional_kwargs=self.additional_kwargs)
_msg_class: Type[BaseMessage] = AIMessage
class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
class SystemMessagePromptTemplate(_StringImageMessagePromptTemplate):
"""System message prompt template.
This is a message that is not sent to the user.
"""
def format(self, **kwargs: Any) -> BaseMessage:
"""Format the prompt template.
Args:
**kwargs: Keyword arguments to use for formatting.
Returns:
Formatted message.
"""
text = self.prompt.format(**kwargs)
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
_msg_class: Type[BaseMessage] = SystemMessage
class ChatPromptValue(PromptValue):
@@ -677,7 +815,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
def _create_template_from_message_type(
message_type: str, template: str
message_type: str, template: Union[str, list]
) -> BaseMessagePromptTemplate:
"""Create a message prompt template from a message type and template string.
@@ -693,9 +831,9 @@ def _create_template_from_message_type(
template
)
elif message_type in ("ai", "assistant"):
message = AIMessagePromptTemplate.from_template(template)
message = AIMessagePromptTemplate.from_template(cast(str, template))
elif message_type == "system":
message = SystemMessagePromptTemplate.from_template(template)
message = SystemMessagePromptTemplate.from_template(cast(str, template))
else:
raise ValueError(
f"Unexpected message type: {message_type}. Use one of 'human',"

View File

@@ -0,0 +1,68 @@
from typing import Any, Union
from langchain.prompts.base import ImagePromptValue
from langchain.pydantic_v1 import Field
from langchain.schema import BasePromptTemplate, PromptValue
from langchain.types.image import ImageURL
from langchain.utils.image import image_to_data_url
class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
    """An image prompt template for a language model."""

    variable_name: Union[str, None] = None
    """Name of the input variable that supplies the image at format time.

    The value passed under this name may be a URL string or a dict of image
    fields. ``None`` means the image comes only from ``template`` and from
    the ``url`` / ``path`` / ``detail`` keyword arguments of ``format``.
    """
    template: dict = Field(default_factory=dict)
    """Default image fields (``url``, ``path``, ``detail``) used by ``format``.

    Fields supplied through the input variable override these defaults.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Validate ``variable_name`` and default ``input_variables``.

        Args:
            **kwargs: field values for the template. When ``input_variables``
                is omitted it defaults to ``[variable_name]``, or to an empty
                list when no variable name is given.

        Raises:
            ValueError: if ``variable_name`` is ``"url"``, ``"path"`` or
                ``"detail"``.
        """
        variable_name = kwargs.get("variable_name")
        # "url", "path" and "detail" are read directly from the keyword
        # arguments of ``format``, so they cannot double as the variable name.
        if variable_name in ("url", "path", "detail"):
            raise ValueError(
                f"variable_name {variable_name!r} is reserved; it must not be "
                "one of 'url', 'path' or 'detail'."
            )
        if "input_variables" not in kwargs:
            # Default even when ``variable_name`` was not passed at all, so a
            # purely static template can be constructed.
            kwargs["input_variables"] = [variable_name] if variable_name else []
        super().__init__(**kwargs)

    @property
    def _prompt_type(self) -> str:
        """Return the prompt type key."""
        return "image-prompt"

    def format_prompt(self, **kwargs: Any) -> PromptValue:
        """Format the template and wrap the result in an ``ImagePromptValue``."""
        return ImagePromptValue(image_url=self.format(**kwargs))

    def format(
        self,
        **kwargs: Any,
    ) -> ImageURL:
        """Format the prompt with the inputs.

        Each of ``url``, ``path`` and ``detail`` is resolved in this order:
        the keyword argument of that name, then the value passed under
        ``variable_name``, then ``template``. If no ``url`` is found, the file
        at ``path`` is converted to a data URL.

        Args:
            kwargs: Any arguments to be passed to the prompt template.

        Returns:
            A dict with a ``"url"`` entry and, when a detail level was
            resolved, a ``"detail"`` entry.

        Raises:
            ValueError: if neither a URL nor a path could be resolved.

        Example:

            .. code-block:: python

                prompt.format(variable1="https://example.com/image.png")
        """
        var = kwargs.get(self.variable_name, {}) if self.variable_name else {}
        if isinstance(var, str):
            # A bare string passed for the variable is treated as the URL.
            var = {"url": var}
        var = {**self.template, **var}
        url = kwargs.get("url") or var.get("url")
        path = kwargs.get("path") or var.get("path")
        detail = kwargs.get("detail") or var.get("detail")
        if not url:
            if not path:
                raise ValueError(
                    "Cannot format image prompt: provide either a 'url' or a "
                    "'path'."
                )
            url = image_to_data_url(path)
        output: ImageURL = {"url": url}
        if detail:
            output["detail"] = detail
        return output

View File

@@ -3,7 +3,18 @@ from __future__ import annotations
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Mapping,
Optional,
Type,
TypeVar,
Union,
)
import yaml
@@ -13,8 +24,12 @@ from langchain.schema.output_parser import BaseOutputParser
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import RunnableConfig, RunnableSerializable
FormatOutputType = TypeVar("FormatOutputType")
class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
class BasePromptTemplate(
RunnableSerializable[Dict, PromptValue], Generic[FormatOutputType], ABC
):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]
@@ -111,7 +126,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
return {**partial_kwargs, **kwargs}
@abstractmethod
def format(self, **kwargs: Any) -> str:
def format(self, **kwargs: Any) -> FormatOutputType:
"""Format the prompt with the inputs.
Args:
@@ -179,7 +194,7 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
raise ValueError(f"{save_path} must be json or yaml")
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
def format_document(doc: Document, prompt: BasePromptTemplate[str]) -> str:
"""Format a document into a string based on a prompt template.
First, this pulls information from the document from two sources:

View File

@@ -0,0 +1,11 @@
from typing import Literal
from typing_extensions import TypedDict
class ImageURL(TypedDict, total=False):
    """Image reference: a URL plus an optional detail level.

    Declared with ``total=False``, so neither key is required by type checkers.
    """

    detail: Literal["auto", "low", "high"]
    """Detail level of the image: ``"auto"``, ``"low"`` or ``"high"``."""
    url: str
    """Either a URL of the image or the base64 encoded image data."""

View File

@@ -0,0 +1,14 @@
import base64
import mimetypes
def encode_image(image_path: str) -> str:
    """Read the file at ``image_path`` and return its bytes as base64 text."""
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def image_to_data_url(image_path: str) -> str:
    """Return the file at ``image_path`` as a base64 ``data:`` URL.

    The MIME type is guessed from the file name. When it cannot be guessed
    (unknown or missing extension), ``application/octet-stream`` is used so
    the result is never the malformed ``data:None;base64,...``.

    Args:
        image_path: path of the file to embed.

    Returns:
        A string of the form ``data:<mime type>;base64,<encoded bytes>``.
    """
    encoding = encode_image(image_path)
    mime_type = mimetypes.guess_type(image_path)[0]
    if mime_type is None:
        # guess_type returns (None, None) when the type is not recognised.
        mime_type = "application/octet-stream"
    return f"data:{mime_type};base64,{encoding}"

View File

@@ -1,6 +1,5 @@
from __future__ import annotations
import base64
import logging
import uuid
from typing import (
@@ -21,6 +20,7 @@ from langchain.docstore.document import Document
from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore
from langchain.utils import xor_args
from langchain.utils.image import encode_image
from langchain.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
@@ -161,11 +161,6 @@ class Chroma(VectorStore):
**kwargs,
)
def encode_image(self, uri: str) -> str:
"""Get base64 string from image URI."""
with open(uri, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def add_images(
self,
uris: List[str],
@@ -184,7 +179,7 @@ class Chroma(VectorStore):
List[str]: List of IDs of the added images.
"""
# Map from uris to b64 encoded strings
b64_texts = [self.encode_image(uri=uri) for uri in uris]
b64_texts = [encode_image(uri) for uri in uris]
# Populate IDs
if ids is None:
ids = [str(uuid.uuid1()) for _ in uris]

View File

@@ -16,7 +16,7 @@ git grep '^from langchain' langchain/callbacks | grep -vE 'from langchain.(pydan
# TODO: it's probably not great that so many other modules depend on `langchain.utilities`, because there can be a lot of imports there
git grep '^from langchain' langchain/utilities | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|utilities)' && errors=$((errors+1))
git grep '^from langchain' langchain/storage | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|storage|utilities)' && errors=$((errors+1))
git grep '^from langchain' langchain/prompts | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api)' && errors=$((errors+1))
git grep '^from langchain' langchain/prompts | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api|types)' && errors=$((errors+1))
git grep '^from langchain' langchain/output_parsers | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|_api|output_parsers)' && errors=$((errors+1))
git grep '^from langchain' langchain/llms | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|prompts|llms|utilities|globals)' && errors=$((errors+1))
git grep '^from langchain' langchain/chat_models | grep -vE 'from langchain.(pydantic_v1|utils|schema|load|callbacks|env|llms|prompts|adapters|chat_models|utilities|globals)' && errors=$((errors+1))