Mirror of https://github.com/hwchase17/langchain.git
Bagatur/docs smith context (#13139)
@@ -108,6 +108,7 @@
"outputs": [],
"source": [
"from unstructured.partition.pdf import partition_pdf\n",
"\n",
"# Extract images, tables, and chunk text\n",
"raw_pdf_elements = partition_pdf(\n",
" filename=path + \"wildfire_stats.pdf\",\n",
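Outside this hunk, the notebook separates the elements returned by partition_pdf into tables and text chunks before summarizing them. A minimal sketch of that step, assuming the default unstructured element classes (Table, CompositeElement) and the raw_pdf_elements variable above:

# Sketch only: split unstructured elements by type. `raw_pdf_elements` comes
# from partition_pdf above; the class-name check is an assumption about the
# element types unstructured returns when chunking by title.
tables, texts = [], []
for element in raw_pdf_elements:
    if "Table" in str(type(element)):
        tables.append(str(element))
    elif "CompositeElement" in str(type(element)):
        texts.append(str(element))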
@@ -189,8 +190,8 @@
"outputs": [],
"source": [
"# Apply to text\n",
"# Typically this is reccomended only if you have large text chunks \n",
"text_summaries = texts # Skip it\n",
"# Typically this is reccomended only if you have large text chunks\n",
"text_summaries = texts # Skip it\n",
"\n",
"# Apply to tables\n",
"table_summaries = summarize_chain.batch(tables, {\"max_concurrency\": 5})"
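The summarize_chain referenced here is defined earlier in the notebook and is not part of this hunk. A hedged sketch of the usual shape of such a chain (prompt, chat model, string output parser); the prompt wording and model name below are illustrative assumptions:

from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.output_parser import StrOutputParser

# Illustrative summarization chain; prompt text and model are assumptions.
prompt = ChatPromptTemplate.from_template(
    "You are an assistant tasked with summarizing tables and text. "
    "Give a concise summary of the following input: {element}"
)
model = ChatOpenAI(temperature=0, model="gpt-4")
summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

Calling summarize_chain.batch(tables, {"max_concurrency": 5}) then runs this chain over every table with at most five concurrent requests.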
@@ -228,26 +229,25 @@
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema.messages import HumanMessage, SystemMessage\n",
"\n",
"def encode_image(image_path):\n",
" ''' Getting the base64 string '''\n",
" with open(image_path, \"rb\") as image_file:\n",
" return base64.b64encode(image_file.read()).decode('utf-8')\n",
"\n",
"def image_summarize(img_base64,prompt):\n",
" ''' Image summary '''\n",
" chat = ChatOpenAI(model=\"gpt-4-vision-preview\",\n",
" max_tokens=1024)\n",
" \n",
"def encode_image(image_path):\n",
" \"\"\"Getting the base64 string\"\"\"\n",
" with open(image_path, \"rb\") as image_file:\n",
" return base64.b64encode(image_file.read()).decode(\"utf-8\")\n",
"\n",
"\n",
"def image_summarize(img_base64, prompt):\n",
" \"\"\"Image summary\"\"\"\n",
" chat = ChatOpenAI(model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
"\n",
" msg = chat.invoke(\n",
" [\n",
" HumanMessage(\n",
" content=[\n",
" {\"type\": \"text\", \"text\":prompt},\n",
" {\"type\": \"text\", \"text\": prompt},\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": f\"data:image/jpeg;base64,{img_base64}\"\n",
" },\n",
" \"image_url\": {\"url\": f\"data:image/jpeg;base64,{img_base64}\"},\n",
" },\n",
" ]\n",
" )\n",
@@ -255,6 +255,7 @@
" )\n",
" return msg.content\n",
"\n",
"\n",
"# Store base64 encoded images\n",
"img_base64_list = []\n",
"\n",
@@ -262,15 +263,15 @@
"image_summaries = []\n",
"\n",
"# Prompt\n",
"prompt = \"Describe the image in detail. Be specific about graphs, such as bar plots.\" \n",
"prompt = \"Describe the image in detail. Be specific about graphs, such as bar plots.\"\n",
"\n",
"# Read images, encode to base64 strings\n",
"for img_file in sorted(os.listdir(path)):\n",
" if img_file.endswith('.jpg'):\n",
" if img_file.endswith(\".jpg\"):\n",
" img_path = os.path.join(path, img_file)\n",
" base64_image = encode_image(img_path)\n",
" img_base64_list.append(base64_image)\n",
" image_summaries.append(image_summarize(base64_image,prompt))"
" image_summaries.append(image_summarize(base64_image, prompt))"
]
},
{
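This loop leans on imports and a path variable set earlier in the notebook rather than in this diff; roughly the following, where the directory is purely illustrative:

import base64
import os

# Illustrative: directory that holds wildfire_stats.pdf and the .jpg figures
# written out during PDF partitioning. The actual location is set earlier in
# the notebook and is not shown in this diff.
path = "./docs/"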
@@ -295,14 +296,15 @@
"source": [
"from IPython.display import display, HTML\n",
"\n",
"def plt_img_base64(img_base64):\n",
"\n",
"def plt_img_base64(img_base64):\n",
" # Create an HTML img tag with the base64 string as the source\n",
" image_html = f'<img src=\"data:image/jpeg;base64,{img_base64}\" />'\n",
" \n",
"\n",
" # Display the image by rendering the HTML\n",
" display(HTML(image_html))\n",
"\n",
"\n",
"plt_img_base64(img_base64_list[1])"
]
},
@@ -352,8 +354,9 @@
"from langchain.retrievers.multi_vector import MultiVectorRetriever\n",
"\n",
"# The vectorstore to use to index the child chunks\n",
"vectorstore = Chroma(collection_name=\"multi_modal_rag\", \n",
" embedding_function=OpenAIEmbeddings())\n",
"vectorstore = Chroma(\n",
" collection_name=\"multi_modal_rag\", embedding_function=OpenAIEmbeddings()\n",
")\n",
"\n",
"# The storage layer for the parent documents\n",
"store = InMemoryStore()\n",
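This hunk only reformats the vector store construction; the retriever itself and the step that links summaries back to their source documents sit outside the diff. A sketch of that wiring under the notebook's usual multi-vector pattern, where the id_key value and the uuid-based IDs are assumptions:

import uuid

from langchain.schema.document import Document

# Assumed wiring: summaries are indexed in the vectorstore, the raw text and
# table content lives in the docstore, joined by a shared doc_id.
id_key = "doc_id"
retriever = MultiVectorRetriever(vectorstore=vectorstore, docstore=store, id_key=id_key)

doc_ids = [str(uuid.uuid4()) for _ in texts]
summary_docs = [
    Document(page_content=summary, metadata={id_key: doc_ids[i]})
    for i, summary in enumerate(text_summaries)
]
retriever.vectorstore.add_documents(summary_docs)
retriever.docstore.mset(list(zip(doc_ids, texts)))

The same pattern would be repeated for table summaries and image summaries, each keyed back to its original element.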
@@ -478,8 +481,10 @@
],
"source": [
"from base64 import b64decode\n",
"\n",
"\n",
"def split_image_text_types(docs):\n",
" ''' Split base64-encoded images and texts '''\n",
" \"\"\"Split base64-encoded images and texts\"\"\"\n",
" b64 = []\n",
" text = []\n",
" for doc in docs:\n",
@@ -488,10 +493,9 @@
" b64.append(doc)\n",
" except Exception as e:\n",
" text.append(doc)\n",
" return {\n",
" \"images\": b64,\n",
" \"texts\": text\n",
" }\n",
" return {\"images\": b64, \"texts\": text}\n",
"\n",
"\n",
"docs_by_type = split_image_text_types(docs)\n",
"plt_img_base64(docs_by_type[\"images\"][0])"
]
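The docs passed to split_image_text_types come from a retrieval call that is not shown in this hunk; a sketch, reusing the question that appears later in the notebook:

# Illustrative retrieval step feeding split_image_text_types above.
docs = retriever.get_relevant_documents(
    "What is the change in wild fires from 1993 to 2022?"
)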
@@ -522,27 +526,40 @@
"from operator import itemgetter\n",
"from langchain.schema.runnable import RunnablePassthrough, RunnableLambda\n",
"\n",
"\n",
"def prompt_func(dict):\n",
" format_texts = \"\\n\".join(dict[\"context\"][\"texts\"])\n",
" return [\n",
" HumanMessage(\n",
" content=[\n",
" {\"type\": \"text\", \"text\": f\"\"\"Answer the question based only on the following context, which can include text, tables, and the below image:\n",
" {\n",
" \"type\": \"text\",\n",
" \"text\": f\"\"\"Answer the question based only on the following context, which can include text, tables, and the below image:\n",
"Question: {dict[\"question\"]}\n",
"\n",
"Text and tables:\n",
"{format_texts}\n",
"\"\"\"},\n",
" {\"type\": \"image_url\", \"image_url\": {\"url\": f\"data:image/jpeg;base64,{dict['context']['images'][0]}\"}},\n",
"\"\"\",\n",
" },\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": f\"data:image/jpeg;base64,{dict['context']['images'][0]}\"\n",
" },\n",
" },\n",
" ]\n",
" )\n",
" ]\n",
"\n",
"\n",
"model = ChatOpenAI(temperature=0, model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
"\n",
"# RAG pipeline\n",
"chain = (\n",
" {\"context\": retriever | RunnableLambda(split_image_text_types), \"question\": RunnablePassthrough()}\n",
" {\n",
" \"context\": retriever | RunnableLambda(split_image_text_types),\n",
" \"question\": RunnablePassthrough(),\n",
" }\n",
" | RunnableLambda(prompt_func)\n",
" | model\n",
" | StrOutputParser()\n",
@@ -574,9 +591,7 @@
}
],
"source": [
"chain.invoke(\n",
" \"What is the change in wild fires from 1993 to 2022?\"\n",
")"
"chain.invoke(\"What is the change in wild fires from 1993 to 2022?\")"
]
},
{