notebook fmt (#12498)

This commit is contained in:
Bagatur
2023-10-29 15:50:09 -07:00
committed by GitHub
parent 56cc5b847c
commit 2424fff3f1
342 changed files with 8261 additions and 6796 deletions

View File

@@ -98,22 +98,24 @@
"from unstructured.partition.pdf import partition_pdf\n",
"\n",
"# Get elements\n",
"raw_pdf_elements = partition_pdf(filename=path+\"LLaVA.pdf\",\n",
" # Using pdf format to find embedded image blocks\n",
" extract_images_in_pdf=True,\n",
" # Use layout model (YOLOX) to get bounding boxes (for tables) and find titles\n",
" # Titles are any sub-section of the document \n",
" infer_table_structure=True, \n",
" # Post processing to aggregate text once we have the title \n",
" chunking_strategy=\"by_title\",\n",
" # Chunking params to aggregate text blocks\n",
" # Attempt to create a new chunk 3800 chars\n",
" # Attempt to keep chunks > 2000 chars \n",
" # Hard max on chunks\n",
" max_characters=4000, \n",
" new_after_n_chars=3800, \n",
" combine_text_under_n_chars=2000,\n",
" image_output_dir_path=path)"
"raw_pdf_elements = partition_pdf(\n",
" filename=path + \"LLaVA.pdf\",\n",
" # Using pdf format to find embedded image blocks\n",
" extract_images_in_pdf=True,\n",
" # Use layout model (YOLOX) to get bounding boxes (for tables) and find titles\n",
" # Titles are any sub-section of the document\n",
" infer_table_structure=True,\n",
" # Post processing to aggregate text once we have the title\n",
" chunking_strategy=\"by_title\",\n",
" # Chunking params to aggregate text blocks\n",
" # Attempt to create a new chunk 3800 chars\n",
" # Attempt to keep chunks > 2000 chars\n",
" # Hard max on chunks\n",
" max_characters=4000,\n",
" new_after_n_chars=3800,\n",
" combine_text_under_n_chars=2000,\n",
" image_output_dir_path=path,\n",
")"
]
},
{
@@ -170,6 +172,7 @@
" type: str\n",
" text: Any\n",
"\n",
"\n",
"# Categorize by type\n",
"categorized_elements = []\n",
"for element in raw_pdf_elements:\n",
@@ -220,14 +223,14 @@
"metadata": {},
"outputs": [],
"source": [
"# Prompt \n",
"prompt_text=\"\"\"You are an assistant tasked with summarizing tables and text. \\ \n",
"# Prompt\n",
"prompt_text = \"\"\"You are an assistant tasked with summarizing tables and text. \\ \n",
"Give a concise summary of the table or text. Table or text chunk: {element} \"\"\"\n",
"prompt = ChatPromptTemplate.from_template(prompt_text) \n",
"prompt = ChatPromptTemplate.from_template(prompt_text)\n",
"\n",
"# Summary chain \n",
"model = ChatOpenAI(temperature=0,model=\"gpt-4\")\n",
"summarize_chain = {\"element\": lambda x:x} | prompt | model | StrOutputParser()"
"# Summary chain\n",
"model = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
"summarize_chain = {\"element\": lambda x: x} | prompt | model | StrOutputParser()"
]
},
{
@@ -342,11 +345,11 @@
"# Read each file and store its content in a list\n",
"img_summaries = []\n",
"for file_path in file_paths:\n",
" with open(file_path, 'r') as file:\n",
" with open(file_path, \"r\") as file:\n",
" img_summaries.append(file.read())\n",
"\n",
"# Remove any logging prior to summary\n",
"logging_header=\"clip_model_load: total allocated memory: 201.27 MB\\n\\n\"\n",
"logging_header = \"clip_model_load: total allocated memory: 201.27 MB\\n\\n\"\n",
"cleaned_img_summary = [s.split(logging_header, 1)[1].strip() for s in img_summaries]"
]
},
@@ -375,10 +378,7 @@
"from langchain.retrievers.multi_vector import MultiVectorRetriever\n",
"\n",
"# The vectorstore to use to index the child chunks\n",
"vectorstore = Chroma(\n",
" collection_name=\"summaries\",\n",
" embedding_function=OpenAIEmbeddings()\n",
")\n",
"vectorstore = Chroma(collection_name=\"summaries\", embedding_function=OpenAIEmbeddings())\n",
"\n",
"# The storage layer for the parent documents\n",
"store = InMemoryStore()\n",
@@ -386,20 +386,26 @@
"\n",
"# The retriever (empty to start)\n",
"retriever = MultiVectorRetriever(\n",
" vectorstore=vectorstore, \n",
" docstore=store, \n",
" vectorstore=vectorstore,\n",
" docstore=store,\n",
" id_key=id_key,\n",
")\n",
"\n",
"# Add texts\n",
"doc_ids = [str(uuid.uuid4()) for _ in texts]\n",
"summary_texts = [Document(page_content=s,metadata={id_key: doc_ids[i]}) for i, s in enumerate(text_summaries)]\n",
"summary_texts = [\n",
" Document(page_content=s, metadata={id_key: doc_ids[i]})\n",
" for i, s in enumerate(text_summaries)\n",
"]\n",
"retriever.vectorstore.add_documents(summary_texts)\n",
"retriever.docstore.mset(list(zip(doc_ids, texts)))\n",
"\n",
"# Add tables\n",
"table_ids = [str(uuid.uuid4()) for _ in tables]\n",
"summary_tables = [Document(page_content=s,metadata={id_key: table_ids[i]}) for i, s in enumerate(table_summaries)]\n",
"summary_tables = [\n",
" Document(page_content=s, metadata={id_key: table_ids[i]})\n",
" for i, s in enumerate(table_summaries)\n",
"]\n",
"retriever.vectorstore.add_documents(summary_tables)\n",
"retriever.docstore.mset(list(zip(table_ids, tables)))"
]
@@ -423,9 +429,12 @@
"source": [
"# Add image summaries\n",
"img_ids = [str(uuid.uuid4()) for _ in cleaned_img_summary]\n",
"summary_img = [Document(page_content=s,metadata={id_key: img_ids[i]}) for i, s in enumerate(cleaned_img_summary)]\n",
"summary_img = [\n",
" Document(page_content=s, metadata={id_key: img_ids[i]})\n",
" for i, s in enumerate(cleaned_img_summary)\n",
"]\n",
"retriever.vectorstore.add_documents(summary_img)\n",
"retriever.docstore.mset(list(zip(img_ids, cleaned_img_summary))) "
"retriever.docstore.mset(list(zip(img_ids, cleaned_img_summary)))"
]
},
{
@@ -449,10 +458,19 @@
"source": [
"# Add images\n",
"img_ids = [str(uuid.uuid4()) for _ in cleaned_img_summary]\n",
"summary_img = [Document(page_content=s,metadata={id_key: img_ids[i]}) for i, s in enumerate(cleaned_img_summary)]\n",
"summary_img = [\n",
" Document(page_content=s, metadata={id_key: img_ids[i]})\n",
" for i, s in enumerate(cleaned_img_summary)\n",
"]\n",
"retriever.vectorstore.add_documents(summary_img)\n",
"### Fetch images\n",
"retriever.docstore.mset(list(zip(img_ids, ### image ### ))) "
"retriever.docstore.mset(\n",
" list(\n",
" zip(\n",
" img_ids,\n",
" )\n",
" )\n",
")"
]
},
{
@@ -542,7 +560,9 @@
],
"source": [
"# We can retrieve this table\n",
"retriever.get_relevant_documents(\"What are results for LLaMA across across domains / subjects?\")[1]"
"retriever.get_relevant_documents(\n",
" \"What are results for LLaMA across across domains / subjects?\"\n",
")[1]"
]
},
{
@@ -592,7 +612,9 @@
}
],
"source": [
"retriever.get_relevant_documents(\"Images / figures with playful and creative examples\")[1]"
"retriever.get_relevant_documents(\"Images / figures with playful and creative examples\")[\n",
" 1\n",
"]"
]
},
{
@@ -633,15 +655,15 @@
"prompt = ChatPromptTemplate.from_template(template)\n",
"\n",
"# Option 1: LLM\n",
"model = ChatOpenAI(temperature=0,model=\"gpt-4\")\n",
"model = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
"# Option 2: Multi-modal LLM\n",
"# model = GPT4-V or LLaVA\n",
"\n",
"# RAG pipeline\n",
"chain = (\n",
" {\"context\": retriever, \"question\": RunnablePassthrough()} \n",
" | prompt \n",
" | model \n",
" {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
" | prompt\n",
" | model\n",
" | StrOutputParser()\n",
")"
]
@@ -664,7 +686,9 @@
}
],
"source": [
"chain.invoke(\"What is the performance of LLaVa across across multiple image domains / subjects?\")"
"chain.invoke(\n",
" \"What is the performance of LLaVa across across multiple image domains / subjects?\"\n",
")"
]
},
{
@@ -713,7 +737,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.1"
}
},
"nbformat": 4,