Bagatur/docs smith context (#13139)

Author: Bagatur
Date: 2023-11-09 10:22:49 -08:00
Committed by: GitHub
Parent: 58da6e0d47
Commit: 8b2a82b5ce
18 changed files with 386 additions and 283 deletions

View File

@@ -108,6 +108,7 @@
"outputs": [],
"source": [
"from unstructured.partition.pdf import partition_pdf\n",
"\n",
"# Extract images, tables, and chunk text\n",
"raw_pdf_elements = partition_pdf(\n",
" filename=path + \"wildfire_stats.pdf\",\n",
@@ -189,8 +190,8 @@
"outputs": [],
"source": [
"# Apply to text\n",
"# Typically this is reccomended only if you have large text chunks \n",
"text_summaries = texts # Skip it\n",
"# Typically this is reccomended only if you have large text chunks\n",
"text_summaries = texts # Skip it\n",
"\n",
"# Apply to tables\n",
"table_summaries = summarize_chain.batch(tables, {\"max_concurrency\": 5})"
@@ -228,26 +229,25 @@
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema.messages import HumanMessage, SystemMessage\n",
"\n",
"def encode_image(image_path):\n",
" ''' Getting the base64 string '''\n",
" with open(image_path, \"rb\") as image_file:\n",
" return base64.b64encode(image_file.read()).decode('utf-8')\n",
"\n",
"def image_summarize(img_base64,prompt):\n",
" ''' Image summary '''\n",
" chat = ChatOpenAI(model=\"gpt-4-vision-preview\",\n",
" max_tokens=1024)\n",
" \n",
"def encode_image(image_path):\n",
" \"\"\"Getting the base64 string\"\"\"\n",
" with open(image_path, \"rb\") as image_file:\n",
" return base64.b64encode(image_file.read()).decode(\"utf-8\")\n",
"\n",
"\n",
"def image_summarize(img_base64, prompt):\n",
" \"\"\"Image summary\"\"\"\n",
" chat = ChatOpenAI(model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
"\n",
" msg = chat.invoke(\n",
" [\n",
" HumanMessage(\n",
" content=[\n",
" {\"type\": \"text\", \"text\":prompt},\n",
" {\"type\": \"text\", \"text\": prompt},\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": f\"data:image/jpeg;base64,{img_base64}\"\n",
" },\n",
" \"image_url\": {\"url\": f\"data:image/jpeg;base64,{img_base64}\"},\n",
" },\n",
" ]\n",
" )\n",
@@ -255,6 +255,7 @@
" )\n",
" return msg.content\n",
"\n",
"\n",
"# Store base64 encoded images\n",
"img_base64_list = []\n",
"\n",
@@ -262,15 +263,15 @@
"image_summaries = []\n",
"\n",
"# Prompt\n",
"prompt = \"Describe the image in detail. Be specific about graphs, such as bar plots.\" \n",
"prompt = \"Describe the image in detail. Be specific about graphs, such as bar plots.\"\n",
"\n",
"# Read images, encode to base64 strings\n",
"for img_file in sorted(os.listdir(path)):\n",
" if img_file.endswith('.jpg'):\n",
" if img_file.endswith(\".jpg\"):\n",
" img_path = os.path.join(path, img_file)\n",
" base64_image = encode_image(img_path)\n",
" img_base64_list.append(base64_image)\n",
" image_summaries.append(image_summarize(base64_image,prompt))"
" image_summaries.append(image_summarize(base64_image, prompt))"
]
},
{
@@ -295,14 +296,15 @@
"source": [
"from IPython.display import display, HTML\n",
"\n",
"def plt_img_base64(img_base64):\n",
"\n",
"def plt_img_base64(img_base64):\n",
" # Create an HTML img tag with the base64 string as the source\n",
" image_html = f'<img src=\"data:image/jpeg;base64,{img_base64}\" />'\n",
" \n",
"\n",
" # Display the image by rendering the HTML\n",
" display(HTML(image_html))\n",
"\n",
"\n",
"plt_img_base64(img_base64_list[1])"
]
},
@@ -352,8 +354,9 @@
"from langchain.retrievers.multi_vector import MultiVectorRetriever\n",
"\n",
"# The vectorstore to use to index the child chunks\n",
"vectorstore = Chroma(collection_name=\"multi_modal_rag\", \n",
" embedding_function=OpenAIEmbeddings())\n",
"vectorstore = Chroma(\n",
" collection_name=\"multi_modal_rag\", embedding_function=OpenAIEmbeddings()\n",
")\n",
"\n",
"# The storage layer for the parent documents\n",
"store = InMemoryStore()\n",
@@ -478,8 +481,10 @@
],
"source": [
"from base64 import b64decode\n",
"\n",
"\n",
"def split_image_text_types(docs):\n",
" ''' Split base64-encoded images and texts '''\n",
" \"\"\"Split base64-encoded images and texts\"\"\"\n",
" b64 = []\n",
" text = []\n",
" for doc in docs:\n",
@@ -488,10 +493,9 @@
" b64.append(doc)\n",
" except Exception as e:\n",
" text.append(doc)\n",
" return {\n",
" \"images\": b64,\n",
" \"texts\": text\n",
" }\n",
" return {\"images\": b64, \"texts\": text}\n",
"\n",
"\n",
"docs_by_type = split_image_text_types(docs)\n",
"plt_img_base64(docs_by_type[\"images\"][0])"
]
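The base64 check inside the loop falls between the two hunks above; presumably it is a try/except around b64decode. A self-contained sketch of the whole helper under that assumption:

from base64 import b64decode


def split_image_text_types(docs):
    """Split base64-encoded images and texts"""
    b64 = []
    text = []
    for doc in docs:
        try:
            # If the string decodes cleanly, treat it as an image payload (assumed check)
            b64decode(doc)
            b64.append(doc)
        except Exception:
            text.append(doc)
    return {"images": b64, "texts": text}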
@@ -522,27 +526,40 @@
"from operator import itemgetter\n",
"from langchain.schema.runnable import RunnablePassthrough, RunnableLambda\n",
"\n",
"\n",
"def prompt_func(dict):\n",
" format_texts = \"\\n\".join(dict[\"context\"][\"texts\"])\n",
" return [\n",
" HumanMessage(\n",
" content=[\n",
" {\"type\": \"text\", \"text\": f\"\"\"Answer the question based only on the following context, which can include text, tables, and the below image:\n",
" {\n",
" \"type\": \"text\",\n",
" \"text\": f\"\"\"Answer the question based only on the following context, which can include text, tables, and the below image:\n",
"Question: {dict[\"question\"]}\n",
"\n",
"Text and tables:\n",
"{format_texts}\n",
"\"\"\"},\n",
" {\"type\": \"image_url\", \"image_url\": {\"url\": f\"data:image/jpeg;base64,{dict['context']['images'][0]}\"}},\n",
"\"\"\",\n",
" },\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": f\"data:image/jpeg;base64,{dict['context']['images'][0]}\"\n",
" },\n",
" },\n",
" ]\n",
" )\n",
" ]\n",
"\n",
"\n",
"model = ChatOpenAI(temperature=0, model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
"\n",
"# RAG pipeline\n",
"chain = (\n",
" {\"context\": retriever | RunnableLambda(split_image_text_types), \"question\": RunnablePassthrough()}\n",
" {\n",
" \"context\": retriever | RunnableLambda(split_image_text_types),\n",
" \"question\": RunnablePassthrough(),\n",
" }\n",
" | RunnableLambda(prompt_func)\n",
" | model\n",
" | StrOutputParser()\n",
@@ -574,9 +591,7 @@
}
],
"source": [
"chain.invoke(\n",
" \"What is the change in wild fires from 1993 to 2022?\"\n",
")"
"chain.invoke(\"What is the change in wild fires from 1993 to 2022?\")"
]
},
{

View File

@@ -46,6 +46,7 @@
"# Pydantic is an easy way to define a schema\n",
"class Person(BaseModel):\n",
" \"\"\"Information about people to extract.\"\"\"\n",
"\n",
" name: str\n",
" age: Optional[int] = None"
]
@@ -91,6 +92,7 @@
"# Let's define another element\n",
"class Class(BaseModel):\n",
" \"\"\"Information about classes to extract.\"\"\"\n",
"\n",
" teacher: str\n",
" students: List[str]"
]
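How these schemas are wired into an extraction chain is outside the hunks above. One hedged sketch, reusing the tool-calling pattern that appears in the last file of this commit; the import paths and the example input are assumptions:

from langchain.chat_models import ChatOpenAI
from langchain.output_parsers.openai_tools import PydanticToolsParser  # assumed import path
from langchain.utils.openai_functions import convert_pydantic_to_openai_tool  # assumed import path

model = ChatOpenAI(model="gpt-3.5-turbo-1106").bind(
    tools=[convert_pydantic_to_openai_tool(Person), convert_pydantic_to_openai_tool(Class)]
)
chain = model | PydanticToolsParser(tools=[Person, Class])

# Hypothetical input; the parser returns Person/Class instances, one per tool call
chain.invoke("Dana, 32, teaches algebra to Sam and Ana")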

View File

@@ -7,7 +7,7 @@
"metadata": {},
"outputs": [],
"source": [
"! pip install \"openai>=1\" \"langchain>=0.0.331rc2\" matplotlib pillow "
"! pip install \"openai>=1\" \"langchain>=0.0.331rc2\" matplotlib pillow"
]
},
{
@@ -47,22 +47,24 @@
"from PIL import Image\n",
"from IPython.display import display, HTML\n",
"\n",
"\n",
"def encode_image(image_path):\n",
" ''' Getting the base64 string '''\n",
" \n",
" \"\"\"Getting the base64 string\"\"\"\n",
"\n",
" with open(image_path, \"rb\") as image_file:\n",
" return base64.b64encode(image_file.read()).decode('utf-8')\n",
" return base64.b64encode(image_file.read()).decode(\"utf-8\")\n",
"\n",
"\n",
"def plt_img_base64(img_base64):\n",
" ''' Display the base64 image '''\n",
" \"\"\"Display the base64 image\"\"\"\n",
"\n",
" # Create an HTML img tag with the base64 string as the source\n",
" image_html = f'<img src=\"data:image/jpeg;base64,{img_base64}\" />'\n",
" \n",
"\n",
" # Display the image by rendering the HTML\n",
" display(HTML(image_html))\n",
"\n",
"\n",
"# Image for QA\n",
"path = \"/Users/rlm/Desktop/Multimodal_Eval/qa/llm_strategies.jpeg\"\n",
"img_base64 = encode_image(path)\n",
@@ -99,19 +101,19 @@
"metadata": {},
"outputs": [],
"source": [
"chat = ChatOpenAI(model=\"gpt-4-vision-preview\",\n",
" max_tokens=1024)\n",
"chat = ChatOpenAI(model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
"\n",
"msg = chat.invoke(\n",
" [\n",
" HumanMessage(\n",
" content=[\n",
" {\"type\": \"text\", \"text\":\"Based on the image, what is the difference in training strategy between a small and a large base model?\"},\n",
" {\n",
" \"type\": \"text\",\n",
" \"text\": \"Based on the image, what is the difference in training strategy between a small and a large base model?\",\n",
" },\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": f\"data:image/jpeg;base64,{img_base64}\"\n",
" },\n",
" \"image_url\": {\"url\": f\"data:image/jpeg;base64,{img_base64}\"},\n",
" },\n",
" ]\n",
" )\n",

View File

@@ -134,7 +134,7 @@
" name=\"langchain assistant\",\n",
" instructions=\"You are a personal math tutor. Write and run code to answer math questions.\",\n",
" tools=[{\"type\": \"code_interpreter\"}],\n",
" model=\"gpt-4-1106-preview\"\n",
" model=\"gpt-4-1106-preview\",\n",
")\n",
"output = interpreter_assistant.invoke({\"content\": \"What's 10 - 4 raised to the 2.7\"})\n",
"output"
@@ -184,7 +184,7 @@
" instructions=\"You are a personal math tutor. Write and run code to answer math questions. You can also search the internet.\",\n",
" tools=tools,\n",
" model=\"gpt-4-1106-preview\",\n",
" as_agent=True\n",
" as_agent=True,\n",
")"
]
},
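The tools list bound to this agent is defined outside the hunk. Given that the instructions mention searching the internet, one hypothetical definition would be:

from langchain.tools import DuckDuckGoSearchRun

tools = [DuckDuckGoSearchRun()]  # hypothetical; requires the duckduckgo-search package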
@@ -241,7 +241,7 @@
" instructions=\"You are a personal math tutor. Write and run code to answer math questions.\",\n",
" tools=tools,\n",
" model=\"gpt-4-1106-preview\",\n",
" as_agent=True\n",
" as_agent=True,\n",
")"
]
},
@@ -254,6 +254,7 @@
"source": [
"from langchain.schema.agent import AgentFinish\n",
"\n",
"\n",
"def execute_agent(agent, tools, input):\n",
" tool_map = {tool.name: tool for tool in tools}\n",
" response = agent.invoke(input)\n",
@@ -262,9 +263,17 @@
" for action in response:\n",
" tool_output = tool_map[action.tool].invoke(action.tool_input)\n",
" print(action.tool, action.tool_input, tool_output, end=\"\\n\\n\")\n",
" tool_outputs.append({\"output\": tool_output, \"tool_call_id\": action.tool_call_id})\n",
" response = agent.invoke({\"tool_outputs\": tool_outputs, \"run_id\": action.run_id, \"thread_id\": action.thread_id})\n",
" \n",
" tool_outputs.append(\n",
" {\"output\": tool_output, \"tool_call_id\": action.tool_call_id}\n",
" )\n",
" response = agent.invoke(\n",
" {\n",
" \"tool_outputs\": tool_outputs,\n",
" \"run_id\": action.run_id,\n",
" \"thread_id\": action.thread_id,\n",
" }\n",
" )\n",
"\n",
" return response"
]
},
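The first call to execute_agent, implied by the next hunk's reuse of response.thread_id, looks like this:

response = execute_agent(
    agent, tools, {"content": "What's 10 - 4 raised to the 2.7"}
)
print(response.return_values["output"])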
@@ -306,7 +315,9 @@
}
],
"source": [
"next_response = execute_agent(agent, tools, {\"content\": \"now add 17.241\", \"thread_id\": response.thread_id})\n",
"next_response = execute_agent(\n",
" agent, tools, {\"content\": \"now add 17.241\", \"thread_id\": response.thread_id}\n",
")\n",
"print(next_response.return_values[\"output\"])"
]
},
@@ -449,16 +460,22 @@
"from langchain.prompts import ChatPromptTemplate\n",
"from langchain.pydantic_v1 import BaseModel, Field\n",
"\n",
"\n",
"class GetCurrentWeather(BaseModel):\n",
" \"\"\"Get the current weather in a location.\"\"\"\n",
"\n",
" location: str = Field(description=\"The city and state, e.g. San Francisco, CA\")\n",
" unit: Literal[\"celsius\", \"fahrenheit\"] = Field(default=\"fahrenheit\", description=\"The temperature unit, default to fahrenheit\")\n",
" \n",
"prompt = ChatPromptTemplate.from_messages([\n",
" (\"system\", \"You are a helpful assistant\"),\n",
" (\"user\", \"{input}\")\n",
"])\n",
"model = ChatOpenAI(model=\"gpt-3.5-turbo-1106\").bind(tools=[convert_pydantic_to_openai_tool(GetCurrentWeather)])\n",
" unit: Literal[\"celsius\", \"fahrenheit\"] = Field(\n",
" default=\"fahrenheit\", description=\"The temperature unit, default to fahrenheit\"\n",
" )\n",
"\n",
"\n",
"prompt = ChatPromptTemplate.from_messages(\n",
" [(\"system\", \"You are a helpful assistant\"), (\"user\", \"{input}\")]\n",
")\n",
"model = ChatOpenAI(model=\"gpt-3.5-turbo-1106\").bind(\n",
" tools=[convert_pydantic_to_openai_tool(GetCurrentWeather)]\n",
")\n",
"chain = prompt | model | PydanticToolsParser(tools=[GetCurrentWeather])\n",
"\n",
"chain.invoke({\"input\": \"what's the weather in NYC, LA, and SF\"})"