Compare commits

...

2 Commits

Author SHA1 Message Date
Lance Martin
5c0dc9d154 fmt 2023-12-06 12:01:20 -08:00
Lance Martin
7f4ddbb0ee Update to use multi-modal prompt template 2023-12-06 11:52:18 -08:00
2 changed files with 92 additions and 61 deletions

View File

@@ -112,7 +112,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "c59df23c-86f7-4e5d-8b8c-de92a92f6637",
"id": "6d4472dc-f89a-4326-a6a7-60ce5c52e553",
"metadata": {},
"outputs": [],
"source": [
@@ -269,8 +269,7 @@
"source": [
"import base64\n",
"import os\n",
"\n",
"from langchain.schema.messages import HumanMessage\n",
"from langchain_core.prompts.chat import ChatPromptTemplate\n",
"\n",
"\n",
"def encode_image(image_path):\n",
@@ -278,25 +277,38 @@
" with open(image_path, \"rb\") as image_file:\n",
" return base64.b64encode(image_file.read()).decode(\"utf-8\")\n",
"\n",
"\n",
"def image_summarize(img_base64, prompt):\n",
"def image_summarize(img_base64):\n",
" \"\"\"Make image summary\"\"\"\n",
" chat = ChatOpenAI(model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
"\n",
" msg = chat.invoke(\n",
" # Prompt\n",
" prompt = \"\"\"You are an assistant tasked with summarizing images for retrieval. \\\n",
" These summaries will be embedded and used to retrieve the raw image. \\\n",
" Give a concise summary of the image that is well optimized for retrieval.\"\"\"\n",
"\n",
" summarization_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" HumanMessage(\n",
" content=[\n",
" {\"type\": \"text\", \"text\": prompt},\n",
" (\n",
" \"system\",\n",
" \"You are an analyst tasked with summarizing images. \\n\"\n",
"            \"You will be given an image to summarize.\\n\",\n",
" ),\n",
" (\n",
" \"human\",\n",
" [\n",
" {\"type\": \"text\", \"text\": \"{prompt}\"},\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\"url\": f\"data:image/jpeg;base64,{img_base64}\"},\n",
" \"image_url\": \"data:image/jpeg;base64,{img}\",\n",
" },\n",
" ]\n",
" )\n",
" ],\n",
" ),\n",
" ]\n",
" )\n",
" return msg.content\n",
"\n",
" llm = ChatOpenAI(model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
" chain = summarization_prompt | llm\n",
" summary = chain.invoke({\"prompt\": prompt, \"img\": img_base64})\n",
" return summary\n",
"\n",
"\n",
"def generate_img_summaries(path):\n",
@@ -322,7 +334,7 @@
" img_path = os.path.join(path, img_file)\n",
" base64_image = encode_image(img_path)\n",
" img_base64_list.append(base64_image)\n",
" image_summaries.append(image_summarize(base64_image, prompt))\n",
" image_summaries.append(image_summarize(base64_image))\n",
"\n",
" return img_base64_list, image_summaries\n",
"\n",
@@ -515,38 +527,45 @@
" texts.append(doc)\n",
" return {\"images\": b64_images, \"texts\": texts}\n",
"\n",
"def img_prompt_func(data_dict, num_images=2):\n",
" \"\"\"\n",
" GPT-4V prompt for image analysis.\n",
" \"\"\"\n",
"\n",
"def img_prompt_func(data_dict):\n",
" \"\"\"\n",
" Join the context into a single string\n",
" \"\"\"\n",
" # Text\n",
" formatted_texts = \"\\n\".join(data_dict[\"context\"][\"texts\"])\n",
" messages = []\n",
"\n",
" # Adding image(s) to the messages if present\n",
" if data_dict[\"context\"][\"images\"]:\n",
" for image in data_dict[\"context\"][\"images\"]:\n",
" image_message = {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\"url\": f\"data:image/jpeg;base64,{image}\"},\n",
" }\n",
" messages.append(image_message)\n",
"\n",
" # Adding the text for analysis\n",
" text_message = {\n",
" \"type\": \"text\",\n",
" \"text\": (\n",
"            \"You are a financial analyst tasked with providing investment advice.\\n\"\n",
"            \"You will be given a mix of text, tables, and image(s) usually of charts or graphs.\\n\"\n",
" \"Use this information to provide investment advice related to the user question. \\n\"\n",
" f\"User-provided question: {data_dict['question']}\\n\\n\"\n",
" \"Text and / or tables:\\n\"\n",
" f\"{formatted_texts}\"\n",
" # Base template\n",
" template_messages = [\n",
" (\n",
" \"system\",\n",
" \"You are an analyst tasked with answering questions about visual content. \\n\"\n",
" \"You will be given a set of image(s) from a slide deck / presentation.\\n\",\n",
" ),\n",
" }\n",
" messages.append(text_message)\n",
" return [HumanMessage(content=messages)]\n",
" (\n",
" \"human\",\n",
" [\n",
" {\"type\": \"text\", \"text\": \"Answer the question using the images. Question: {question}\"},\n",
" {\"type\": \"text\", \"text\": f\"Text and / or tables: {formatted_texts}\"},\n",
" )\n",
" ]\n",
"\n",
" # Add images\n",
" images = data_dict[\"context\"][\"images\"]\n",
" for i in range(min(num_images, len(images))):\n",
" image_message = {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\"url\": f\"data:image/jpeg;base64,{images[i]}\"},\n",
" }\n",
" template_messages[1][1].append(image_message)\n",
"\n",
" # Format\n",
" rag_prompt = ChatPromptTemplate.from_messages(template_messages)\n",
" rag_prompt_formatted = rag_prompt.format_messages(\n",
" question=data_dict[\"question\"],\n",
" )\n",
"\n",
" return rag_prompt_formatted\n",
"\n",
"def multi_modal_rag_chain(retriever):\n",
" \"\"\"\n",

View File

@@ -5,10 +5,10 @@ from pathlib import Path
from langchain.chat_models import ChatOpenAI
from langchain.pydantic_v1 import BaseModel
from langchain.schema.document import Document
from langchain.schema.messages import HumanMessage
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain.vectorstores import Chroma
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_experimental.open_clip import OpenCLIPEmbeddings
from PIL import Image
@@ -53,25 +53,37 @@ def img_prompt_func(data_dict, num_images=2):
:param num_images: Number of images to include in the prompt.
:return: A list containing message objects for each image and the text prompt.
"""
messages = []
if data_dict["context"]["images"]:
for image in data_dict["context"]["images"][:num_images]:
image_message = {
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{image}"},
}
messages.append(image_message)
text_message = {
"type": "text",
"text": (
"You are an analyst tasked with answering questions about visual content.\n"
"You will be given a set of image(s) from a slide deck / presentation.\n"
"Use this information to answer the user question. \n"
f"User-provided question: {data_dict['question']}\n\n"
# Base template
template_messages = [
("system", "Answer questions using images. \n"),
(
"human",
[
{
"type": "text",
"text": "Answer the question with the given images. Question: {question}",
}
],
),
}
messages.append(text_message)
return [HumanMessage(content=messages)]
]
# Add images
images = data_dict["context"]["images"]
for i in range(min(num_images, len(images))):
image_message = {
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{images[i]}"},
}
template_messages[1][1].append(image_message)
# Format
rag_prompt = ChatPromptTemplate.from_messages(template_messages)
rag_prompt_formatted = rag_prompt.format_messages(
question=data_dict["question"],
)
return rag_prompt_formatted
def multi_modal_rag_chain(retriever):